br: add log backup/restore encryption support (#55757)
close pingcap/tidb#55834
@@ -74,14 +74,14 @@ func runRestoreCommand(command *cobra.Command, cmdName string) error {
 	if err := task.RunRestore(GetDefaultContext(), tidbGlue, cmdName, &cfg); err != nil {
 		log.Error("failed to restore", zap.Error(err))
-		printWorkaroundOnFullRestoreError(command, err)
+		printWorkaroundOnFullRestoreError(err)
 		return errors.Trace(err)
 	}
 	return nil
 }

 // print workaround when we met not fresh or incompatible cluster error on full cluster restore
-func printWorkaroundOnFullRestoreError(command *cobra.Command, err error) {
+func printWorkaroundOnFullRestoreError(err error) {
 	if !errors.ErrorEqual(err, berrors.ErrRestoreNotFreshCluster) &&
 		!errors.ErrorEqual(err, berrors.ErrRestoreIncompatibleSys) {
 		return
@@ -33,6 +33,7 @@ import (
 	"github.com/pingcap/tidb/br/pkg/rtree"
 	"github.com/pingcap/tidb/br/pkg/storage"
 	"github.com/pingcap/tidb/br/pkg/summary"
+	"github.com/pingcap/tidb/br/pkg/utils"
 	"github.com/pingcap/tidb/pkg/util"
 	"go.uber.org/zap"
 	"golang.org/x/sync/errgroup"
@@ -561,7 +562,7 @@ func parseCheckpointData[K KeyType, V ValueType](
 		*pastDureTime = checkpointData.DureTime
 	}
 	for _, meta := range checkpointData.RangeGroupMetas {
-		decryptContent, err := metautil.Decrypt(meta.RangeGroupsEncriptedData, cipher, meta.CipherIv)
+		decryptContent, err := utils.Decrypt(meta.RangeGroupsEncriptedData, cipher, meta.CipherIv)
 		if err != nil {
 			return errors.Trace(err)
 		}
@@ -299,8 +299,8 @@ func (mgr *Mgr) Close() {
 	mgr.PdController.Close()
 }

-// GetTS gets current ts from pd.
-func (mgr *Mgr) GetTS(ctx context.Context) (uint64, error) {
+// GetCurrentTsFromPD gets current ts from PD.
+func (mgr *Mgr) GetCurrentTsFromPD(ctx context.Context) (uint64, error) {
 	p, l, err := mgr.GetPDClient().GetTS(ctx)
 	if err != nil {
 		return 0, errors.Trace(err)
br/pkg/encryption/BUILD.bazel (new file, 15 lines)
@@ -0,0 +1,15 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")

go_library(
    name = "encryption",
    srcs = ["manager.go"],
    importpath = "github.com/pingcap/tidb/br/pkg/encryption",
    visibility = ["//visibility:public"],
    deps = [
        "//br/pkg/encryption/master_key",
        "//br/pkg/utils",
        "@com_github_pingcap_errors//:errors",
        "@com_github_pingcap_kvproto//pkg/brpb",
        "@com_github_pingcap_kvproto//pkg/encryptionpb",
    ],
)
br/pkg/encryption/manager.go (new file, 93 lines)
@@ -0,0 +1,93 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.

package encryption

import (
	"context"

	"github.com/pingcap/errors"
	backuppb "github.com/pingcap/kvproto/pkg/brpb"
	"github.com/pingcap/kvproto/pkg/encryptionpb"
	encryption "github.com/pingcap/tidb/br/pkg/encryption/master_key"
	"github.com/pingcap/tidb/br/pkg/utils"
)

type Manager struct {
	cipherInfo        *backuppb.CipherInfo
	masterKeyBackends *encryption.MultiMasterKeyBackend
	encryptionMethod  *encryptionpb.EncryptionMethod
}

func NewManager(cipherInfo *backuppb.CipherInfo, masterKeyConfigs *backuppb.MasterKeyConfig) (*Manager, error) {
	// should never happen since config has default
	if cipherInfo == nil || masterKeyConfigs == nil {
		return nil, errors.New("cipherInfo or masterKeyConfigs is nil")
	}

	if utils.IsEffectiveEncryptionMethod(cipherInfo.CipherType) {
		return &Manager{
			cipherInfo:        cipherInfo,
			masterKeyBackends: nil,
			encryptionMethod:  nil,
		}, nil
	}
	if utils.IsEffectiveEncryptionMethod(masterKeyConfigs.EncryptionType) {
		masterKeyBackends, err := encryption.NewMultiMasterKeyBackend(masterKeyConfigs.GetMasterKeys())
		if err != nil {
			return nil, errors.Trace(err)
		}
		return &Manager{
			cipherInfo:        nil,
			masterKeyBackends: masterKeyBackends,
			encryptionMethod:  &masterKeyConfigs.EncryptionType,
		}, nil
	}
	return nil, nil
}

func (m *Manager) Decrypt(ctx context.Context, content []byte, fileEncryptionInfo *encryptionpb.FileEncryptionInfo) (
	[]byte, error) {
	switch mode := fileEncryptionInfo.Mode.(type) {
	case *encryptionpb.FileEncryptionInfo_PlainTextDataKey:
		if m.cipherInfo == nil {
			return nil, errors.New("plaintext data key info is required but not set")
		}
		decryptedContent, err := utils.Decrypt(content, m.cipherInfo, fileEncryptionInfo.FileIv)
		if err != nil {
			return nil, errors.Annotate(err, "failed to decrypt content using plaintext data key")
		}
		return decryptedContent, nil
	case *encryptionpb.FileEncryptionInfo_MasterKeyBased:
		encryptedContents := fileEncryptionInfo.GetMasterKeyBased().DataKeyEncryptedContent
		if len(encryptedContents) == 0 {
			return nil, errors.New("should contain at least one encrypted data key")
		}
		// pick the first one; the list allows future expansion to multiple encrypted data keys
		// produced by different master key backends
		encryptedContent := encryptedContents[0]
		decryptedDataKey, err := m.masterKeyBackends.Decrypt(ctx, encryptedContent)
		if err != nil {
			return nil, errors.Annotate(err, "failed to decrypt data key using master key")
		}

		cipherInfo := backuppb.CipherInfo{
			CipherType: fileEncryptionInfo.EncryptionMethod,
			CipherKey:  decryptedDataKey,
		}
		decryptedContent, err := utils.Decrypt(content, &cipherInfo, fileEncryptionInfo.FileIv)
		if err != nil {
			return nil, errors.Annotate(err, "failed to decrypt content using decrypted data key")
		}
		return decryptedContent, nil
	default:
		return nil, errors.Errorf("internal error: unsupported encryption mode type %T", mode)
	}
}

func (m *Manager) Close() {
	if m == nil {
		return
	}
	if m.masterKeyBackends != nil {
		m.masterKeyBackends.Close()
	}
}
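For context, here is a minimal usage sketch (not part of the commit) of how a restore path might drive the new Manager. The zeroed key and the ciphertext bytes are hypothetical placeholders, and the empty FileEncryptionInfo simply exercises the "unsupported encryption mode" error path.

package main

import (
	"context"
	"fmt"

	backuppb "github.com/pingcap/kvproto/pkg/brpb"
	"github.com/pingcap/kvproto/pkg/encryptionpb"
	"github.com/pingcap/tidb/br/pkg/encryption"
)

func main() {
	// Plaintext-data-key configuration: CipherInfo carries the data key directly.
	cipherInfo := &backuppb.CipherInfo{
		CipherType: encryptionpb.EncryptionMethod_AES256_CTR,
		CipherKey:  make([]byte, 32), // placeholder key
	}
	mgr, err := encryption.NewManager(cipherInfo, &backuppb.MasterKeyConfig{})
	if err != nil {
		panic(err)
	}
	defer mgr.Close()

	// FileEncryptionInfo normally comes from the log file metadata; leaving
	// Mode unset makes Decrypt return the internal "unsupported mode" error.
	info := &encryptionpb.FileEncryptionInfo{}
	_, err = mgr.Decrypt(context.Background(), []byte("ciphertext"), info)
	fmt.Println(err)
}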
br/pkg/encryption/master_key/BUILD.bazel (new file, 44 lines)
@@ -0,0 +1,44 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

go_library(
    name = "master_key",
    srcs = [
        "common.go",
        "file_backend.go",
        "kms_backend.go",
        "master_key.go",
        "mem_backend.go",
        "multi_master_key_backend.go",
    ],
    importpath = "github.com/pingcap/tidb/br/pkg/encryption/master_key",
    visibility = ["//visibility:public"],
    deps = [
        "//br/pkg/kms:aws",
        "//br/pkg/utils",
        "@com_github_pingcap_errors//:errors",
        "@com_github_pingcap_kvproto//pkg/encryptionpb",
        "@com_github_pingcap_log//:log",
        "@org_uber_go_multierr//:multierr",
        "@org_uber_go_zap//:zap",
    ],
)

go_test(
    name = "master_key_test",
    timeout = "short",
    srcs = [
        "file_backend_test.go",
        "kms_backend_test.go",
        "mem_backend_test.go",
        "multi_master_key_backend_test.go",
    ],
    embed = [":master_key"],
    flaky = True,
    shard_count = 11,
    deps = [
        "@com_github_pingcap_errors//:errors",
        "@com_github_pingcap_kvproto//pkg/encryptionpb",
        "@com_github_stretchr_testify//mock",
        "@com_github_stretchr_testify//require",
    ],
)
br/pkg/encryption/master_key/common.go (new file, 60 lines)
@@ -0,0 +1,60 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.

package encryption

import (
	"crypto/rand"

	"github.com/pingcap/errors"
)

// must keep these the same as the constants in the TiKV implementation
const (
	MetadataKeyMethod           string = "method"
	MetadataKeyIv               string = "iv"
	MetadataKeyAesGcmTag        string = "aes_gcm_tag"
	MetadataKeyKmsVendor        string = "kms_vendor"
	MetadataKeyKmsCiphertextKey string = "kms_ciphertext_key"
	MetadataMethodAes256Gcm     string = "aes256-gcm"
)

const (
	GcmIv12 = 12
	CtrIv16 = 16
)

type IvType int

const (
	IvTypeGcm IvType = iota
	IvTypeCtr
)

type IV struct {
	Type IvType
	Data []byte
}

func NewIVGcm() (IV, error) {
	iv := make([]byte, GcmIv12)
	_, err := rand.Read(iv)
	if err != nil {
		return IV{}, err
	}
	return IV{Type: IvTypeGcm, Data: iv}, nil
}

func NewIVFromSlice(src []byte) (IV, error) {
	switch len(src) {
	case CtrIv16:
		return IV{Type: IvTypeCtr, Data: append([]byte(nil), src...)}, nil
	case GcmIv12:
		return IV{Type: IvTypeGcm, Data: append([]byte(nil), src...)}, nil
	default:
		return IV{}, errors.Errorf("invalid IV length, must be 12 or 16 bytes, got %d", len(src))
	}
}

func (iv IV) AsSlice() []byte {
	return iv.Data
}
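A small usage sketch (not part of the commit) for the IV helpers above, showing the slice round-trip used when an IV is stored in EncryptedContent metadata; the IV length (12 vs 16 bytes) selects the type on the way back in.

package main

import (
	"fmt"

	masterkey "github.com/pingcap/tidb/br/pkg/encryption/master_key"
)

func main() {
	iv, err := masterkey.NewIVGcm() // fresh random 12-byte GCM IV
	if err != nil {
		panic(err)
	}
	// Store as raw bytes, then rebuild the typed IV from the slice.
	again, err := masterkey.NewIVFromSlice(iv.AsSlice())
	fmt.Println(len(again.AsSlice()), err) // 12 <nil>
}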
br/pkg/encryption/master_key/file_backend.go (new file, 68 lines)
@@ -0,0 +1,68 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.

package encryption

import (
	"context"
	"encoding/hex"
	"os"

	"github.com/pingcap/errors"
	"github.com/pingcap/kvproto/pkg/encryptionpb"
)

const AesGcmKeyLen = 32 // AES-256 key length

// FileBackend is ported from TiKV FileBackend
type FileBackend struct {
	memCache *MemAesGcmBackend
}

func createFileBackend(keyPath string) (*FileBackend, error) {
	// FileBackend uses AES-256-GCM
	keyLen := AesGcmKeyLen

	content, err := os.ReadFile(keyPath)
	if err != nil {
		return nil, errors.Annotate(err, "failed to read master key file from disk")
	}

	fileLen := len(content)
	expectedLen := keyLen*2 + 1 // hex-encoded key + newline

	if fileLen != expectedLen {
		return nil, errors.Errorf("mismatch master key file size, expected %d, actual %d", expectedLen, fileLen)
	}

	if content[fileLen-1] != '\n' {
		return nil, errors.Errorf("master key file should end with newline")
	}

	key, err := hex.DecodeString(string(content[:fileLen-1]))
	if err != nil {
		return nil, errors.Annotate(err, "failed to decode hex format master key from file")
	}

	backend, err := NewMemAesGcmBackend(key)
	if err != nil {
		return nil, errors.Annotate(err, "failed to create MemAesGcmBackend")
	}

	return &FileBackend{memCache: backend}, nil
}

func (f *FileBackend) Encrypt(ctx context.Context, plaintext []byte) (*encryptionpb.EncryptedContent, error) {
	iv, err := NewIVGcm()
	if err != nil {
		return nil, err
	}
	return f.memCache.EncryptContent(ctx, plaintext, iv)
}

func (f *FileBackend) Decrypt(ctx context.Context, content *encryptionpb.EncryptedContent) ([]byte, error) {
	return f.memCache.DecryptContent(ctx, content)
}

func (f *FileBackend) Close() {
	// nothing to close
}
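To illustrate the key-file format enforced above (exactly 64 hex characters, i.e. a 32-byte AES-256 key, plus a single trailing newline), here is a hypothetical generator sketch, not part of the commit; the output path is a placeholder.

package main

import (
	"crypto/rand"
	"encoding/hex"
	"fmt"
	"os"
)

func main() {
	key := make([]byte, 32) // AesGcmKeyLen bytes
	if _, err := rand.Read(key); err != nil {
		panic(err)
	}
	// createFileBackend expects keyLen*2 hex characters followed by '\n'.
	if err := os.WriteFile("master.key", []byte(hex.EncodeToString(key)+"\n"), 0o600); err != nil {
		panic(err)
	}
	fmt.Println("wrote 65-byte master key file")
}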
br/pkg/encryption/master_key/file_backend_test.go (new file, 105 lines)
@@ -0,0 +1,105 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.

package encryption

import (
	"context"
	"encoding/hex"
	"os"
	"testing"

	"github.com/stretchr/testify/require"
)

// TempKeyFile represents a temporary key file for testing
type TempKeyFile struct {
	Path string
	file *os.File
}

// Cleanup closes and removes the temporary file
func (tkf *TempKeyFile) Cleanup() {
	if tkf.file != nil {
		tkf.file.Close()
	}
	os.Remove(tkf.Path)
}

// createMasterKeyFile creates a temporary master key file for testing
func createMasterKeyFile() (*TempKeyFile, error) {
	tempFile, err := os.CreateTemp("", "test_key_*.txt")
	if err != nil {
		return nil, err
	}

	_, err = tempFile.WriteString("c3d99825f2181f4808acd2068eac7441a65bd428f14d2aab43fefc0129091139\n")
	if err != nil {
		tempFile.Close()
		os.Remove(tempFile.Name())
		return nil, err
	}

	return &TempKeyFile{
		Path: tempFile.Name(),
		file: tempFile,
	}, nil
}

func TestFileBackendAes256Gcm(t *testing.T) {
	pt, err := hex.DecodeString("25431587e9ecffc7c37f8d6d52a9bc3310651d46fb0e3bad2726c8f2db653749")
	require.NoError(t, err)
	ct, err := hex.DecodeString("84e5f23f95648fa247cb28eef53abec947dbf05ac953734618111583840bd980")
	require.NoError(t, err)
	ivBytes, err := hex.DecodeString("cafabd9672ca6c79a2fbdc22")
	require.NoError(t, err)

	tempKeyFile, err := createMasterKeyFile()
	require.NoError(t, err)
	defer tempKeyFile.Cleanup()

	backend, err := createFileBackend(tempKeyFile.Path)
	require.NoError(t, err)

	ctx := context.Background()
	iv, err := NewIVFromSlice(ivBytes)
	require.NoError(t, err)

	encryptedContent, err := backend.memCache.EncryptContent(ctx, pt, iv)
	require.NoError(t, err)
	require.Equal(t, ct, encryptedContent.Content)

	plaintext, err := backend.Decrypt(ctx, encryptedContent)
	require.NoError(t, err)
	require.Equal(t, pt, plaintext)
}

func TestFileBackendAuthenticate(t *testing.T) {
	pt := []byte{1, 2, 3}

	tempKeyFile, err := createMasterKeyFile()
	require.NoError(t, err)
	defer tempKeyFile.Cleanup()

	backend, err := createFileBackend(tempKeyFile.Path)
	require.NoError(t, err)

	ctx := context.Background()
	encryptedContent, err := backend.Encrypt(ctx, pt)
	require.NoError(t, err)

	plaintext, err := backend.Decrypt(ctx, encryptedContent)
	require.NoError(t, err)
	require.Equal(t, pt, plaintext)

	// Test checksum mismatch
	encryptedContent1 := *encryptedContent
	encryptedContent1.Metadata[MetadataKeyAesGcmTag][0] ^= 0xFF
	_, err = backend.Decrypt(ctx, &encryptedContent1)
	require.ErrorContains(t, err, wrongMasterKey)

	// Test checksum not found
	encryptedContent2 := *encryptedContent
	delete(encryptedContent2.Metadata, MetadataKeyAesGcmTag)
	_, err = backend.Decrypt(ctx, &encryptedContent2)
	require.ErrorContains(t, err, gcmTagNotFound)
}
br/pkg/encryption/master_key/kms_backend.go (new file, 88 lines)
@@ -0,0 +1,88 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.

package encryption

import (
	"context"
	"sync"

	"github.com/pingcap/errors"
	"github.com/pingcap/kvproto/pkg/encryptionpb"
	"github.com/pingcap/tidb/br/pkg/kms"
	"github.com/pingcap/tidb/br/pkg/utils"
)

type CachedKeys struct {
	encryptionBackend   *MemAesGcmBackend
	cachedCiphertextKey *kms.EncryptedKey
}

type KmsBackend struct {
	state struct {
		sync.Mutex
		cached *CachedKeys
	}
	kmsProvider kms.Provider
}

func NewKmsBackend(kmsProvider kms.Provider) (*KmsBackend, error) {
	return &KmsBackend{
		kmsProvider: kmsProvider,
	}, nil
}

func (k *KmsBackend) Decrypt(ctx context.Context, content *encryptionpb.EncryptedContent) ([]byte, error) {
	vendorName := k.kmsProvider.Name()
	if val, ok := content.Metadata[MetadataKeyKmsVendor]; !ok {
		return nil, errors.New("wrong master key: missing KMS vendor")
	} else if string(val) != vendorName {
		return nil, errors.Errorf("KMS vendor mismatch expect %s got %s", vendorName, string(val))
	}

	ciphertextKeyBytes, ok := content.Metadata[MetadataKeyKmsCiphertextKey]
	if !ok {
		return nil, errors.New("KMS ciphertext key not found")
	}
	ciphertextKey, err := kms.NewEncryptedKey(ciphertextKeyBytes)
	if err != nil {
		return nil, errors.Annotate(err, "failed to create encrypted key")
	}

	k.state.Lock()
	defer k.state.Unlock()

	if k.state.cached != nil && k.state.cached.cachedCiphertextKey.Equal(&ciphertextKey) {
		return k.state.cached.encryptionBackend.DecryptContent(ctx, content)
	}

	// piggyback on NewDownloadSSTBackoffer; a refactor is ongoing to remove all the backoffers
	// so users don't need to write a backoffer for every type
	decryptedKey, err :=
		utils.WithRetryV2(ctx, utils.NewDownloadSSTBackoffer(), func(ctx context.Context) ([]byte, error) {
			return k.kmsProvider.DecryptDataKey(ctx, ciphertextKey)
		})
	if err != nil {
		return nil, errors.Annotate(err, "decrypt encrypted key failed")
	}

	plaintextKey, err := kms.NewPlainKey(decryptedKey, kms.CryptographyTypeAesGcm256)
	if err != nil {
		return nil, errors.Annotate(err, "decrypt encrypted key failed")
	}

	backend, err := NewMemAesGcmBackend(plaintextKey.Key())
	if err != nil {
		return nil, errors.Annotate(err, "failed to create MemAesGcmBackend")
	}

	k.state.cached = &CachedKeys{
		encryptionBackend:   backend,
		cachedCiphertextKey: &ciphertextKey,
	}

	return k.state.cached.encryptionBackend.DecryptContent(ctx, content)
}

func (k *KmsBackend) Close() {
	k.kmsProvider.Close()
}
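For reference, a sketch (not part of the commit) of the minimal metadata contract KmsBackend.Decrypt checks before contacting the provider; the string keys mirror MetadataKeyKmsVendor and MetadataKeyKmsCiphertextKey above, and the values are placeholders.

package main

import (
	"fmt"

	"github.com/pingcap/kvproto/pkg/encryptionpb"
)

func main() {
	content := &encryptionpb.EncryptedContent{
		Metadata: map[string][]byte{
			"kms_vendor":         []byte("AWS"),              // must match Provider.Name()
			"kms_ciphertext_key": []byte("wrapped-data-key"), // as returned by the KMS at encryption time
		},
		Content: []byte("ciphertext"),
	}
	fmt.Println(len(content.Metadata)) // 2
}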
br/pkg/encryption/master_key/kms_backend_test.go (new file, 113 lines)
@@ -0,0 +1,113 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.

package encryption

import (
	"context"
	"crypto/rand"
	"testing"

	"github.com/pingcap/kvproto/pkg/encryptionpb"
	"github.com/stretchr/testify/require"
)

type mockKmsProvider struct {
	name           string
	decryptCounter int
}

func (m *mockKmsProvider) Name() string {
	return m.name
}

func (m *mockKmsProvider) DecryptDataKey(_ctx context.Context, _encryptedKey []byte) ([]byte, error) {
	m.decryptCounter++
	key := make([]byte, 32) // 256 bits = 32 bytes
	_, err := rand.Read(key)
	if err != nil {
		return nil, err
	}
	return key, nil
}

func (m *mockKmsProvider) Close() {
	// do nothing
}

func TestKmsBackendDecrypt(t *testing.T) {
	ctx := context.Background()
	mockProvider := &mockKmsProvider{name: "mock_kms"}
	backend, err := NewKmsBackend(mockProvider)
	require.NoError(t, err)

	ciphertextKey := []byte("ciphertext_key")
	content := &encryptionpb.EncryptedContent{
		Metadata: map[string][]byte{
			MetadataKeyKmsVendor:        []byte("mock_kms"),
			MetadataKeyKmsCiphertextKey: ciphertextKey,
		},
		Content: []byte("encrypted_content"),
	}

	// First decryption
	_, _ = backend.Decrypt(ctx, content)
	require.Equal(t, 1, mockProvider.decryptCounter, "KMS provider should be called once")

	// Second decryption with the same ciphertext key (should use cache)
	_, _ = backend.Decrypt(ctx, content)
	require.Equal(t, 1, mockProvider.decryptCounter, "KMS provider should not be called again")

	// Third decryption with a different ciphertext key
	content.Metadata[MetadataKeyKmsCiphertextKey] = []byte("new_ciphertext_key")
	_, _ = backend.Decrypt(ctx, content)
	require.Equal(t, 2, mockProvider.decryptCounter, "KMS provider should be called again for a new key")
}

func TestKmsBackendDecryptErrors(t *testing.T) {
	ctx := context.Background()
	mockProvider := &mockKmsProvider{name: "mock_kms"}
	backend, err := NewKmsBackend(mockProvider)
	require.NoError(t, err)

	testCases := []struct {
		name    string
		content *encryptionpb.EncryptedContent
		errMsg  string
	}{
		{
			name: "missing KMS vendor",
			content: &encryptionpb.EncryptedContent{
				Metadata: map[string][]byte{
					MetadataKeyKmsCiphertextKey: []byte("ciphertext_key"),
				},
			},
			errMsg: "wrong master key: missing KMS vendor",
		},
		{
			name: "KMS vendor mismatch",
			content: &encryptionpb.EncryptedContent{
				Metadata: map[string][]byte{
					MetadataKeyKmsVendor:        []byte("wrong_kms"),
					MetadataKeyKmsCiphertextKey: []byte("ciphertext_key"),
				},
			},
			errMsg: "KMS vendor mismatch expect mock_kms got wrong_kms",
		},
		{
			name: "missing ciphertext key",
			content: &encryptionpb.EncryptedContent{
				Metadata: map[string][]byte{
					MetadataKeyKmsVendor: []byte("mock_kms"),
				},
			},
			errMsg: "KMS ciphertext key not found",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := backend.Decrypt(ctx, tc.content)
			require.ErrorContains(t, err, tc.errMsg)
		})
	}
}
br/pkg/encryption/master_key/master_key.go (new file, 77 lines)
@@ -0,0 +1,77 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.

package encryption

import (
	"context"

	"github.com/pingcap/errors"
	"github.com/pingcap/kvproto/pkg/encryptionpb"
	"github.com/pingcap/log"
	"github.com/pingcap/tidb/br/pkg/kms"
	"go.uber.org/zap"
)

const (
	StorageVendorNameAWS   = "aws"
	StorageVendorNameAzure = "azure"
	StorageVendorNameGCP   = "gcp"
)

// Backend is an interface that defines the methods required for an encryption backend.
type Backend interface {
	// Decrypt takes an EncryptedContent and returns the decrypted plaintext as a byte slice or an error.
	Decrypt(ctx context.Context, ciphertext *encryptionpb.EncryptedContent) ([]byte, error)
	Close()
}

func CreateBackend(config *encryptionpb.MasterKey) (Backend, error) {
	if config == nil {
		return nil, errors.Errorf("master key config is nil")
	}

	switch backend := config.Backend.(type) {
	case *encryptionpb.MasterKey_Plaintext:
		// the plaintext type should never reach here; the caller guards against it
		return nil, errors.New("should not create plaintext master key")
	case *encryptionpb.MasterKey_File:
		fileBackend, err := createFileBackend(backend.File.Path)
		if err != nil {
			return nil, errors.Annotate(err, "failed to create file backend")
		}
		return fileBackend, nil
	case *encryptionpb.MasterKey_Kms:
		return createCloudBackend(backend.Kms)
	default:
		return nil, errors.New("unknown master key backend type")
	}
}

func createCloudBackend(config *encryptionpb.MasterKeyKms) (Backend, error) {
	log.Info("creating cloud KMS backend",
		zap.String("region", config.GetRegion()),
		zap.String("endpoint", config.GetEndpoint()),
		zap.String("key_id", config.GetKeyId()),
		zap.String("vendor", config.GetVendor()))

	switch config.Vendor {
	case StorageVendorNameAWS:
		kmsProvider, err := kms.NewAwsKms(config)
		if err != nil {
			return nil, errors.Annotate(err, "new AWS KMS")
		}
		return NewKmsBackend(kmsProvider)
	case StorageVendorNameAzure:
		return nil, errors.Errorf("not implemented Azure KMS")
	case StorageVendorNameGCP:
		kmsProvider, err := kms.NewGcpKms(config)
		if err != nil {
			return nil, errors.Annotate(err, "new GCP KMS")
		}
		return NewKmsBackend(kmsProvider)
	default:
		return nil, errors.Errorf("vendor not found: %s", config.Vendor)
	}
}
br/pkg/encryption/master_key/mem_backend.go (new file, 103 lines)
@@ -0,0 +1,103 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.

package encryption

import (
	"context"
	"crypto/aes"
	"crypto/cipher"

	"github.com/pingcap/errors"
	"github.com/pingcap/kvproto/pkg/encryptionpb"
	"github.com/pingcap/tidb/br/pkg/kms"
)

const (
	gcmTagNotFound = "aes gcm tag not found"
	wrongMasterKey = "wrong master key"
)

type MemAesGcmBackend struct {
	key *kms.PlainKey
}

func NewMemAesGcmBackend(key []byte) (*MemAesGcmBackend, error) {
	plainKey, err := kms.NewPlainKey(key, kms.CryptographyTypeAesGcm256)
	if err != nil {
		return nil, errors.Annotate(err, "failed to create new mem aes gcm backend")
	}
	return &MemAesGcmBackend{
		key: plainKey,
	}, nil
}

func (m *MemAesGcmBackend) EncryptContent(_ctx context.Context, plaintext []byte, iv IV) (
	*encryptionpb.EncryptedContent, error) {
	content := encryptionpb.EncryptedContent{
		Metadata: make(map[string][]byte),
	}
	content.Metadata[MetadataKeyMethod] = []byte(MetadataMethodAes256Gcm)
	content.Metadata[MetadataKeyIv] = iv.AsSlice()

	block, err := aes.NewCipher(m.key.Key())
	if err != nil {
		return nil, err
	}
	aesgcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}

	// The Seal function in AES-GCM mode appends the authentication tag to the ciphertext.
	// We need to separate the actual ciphertext from the tag for storage and later verification.
	// Reference: https://pkg.go.dev/crypto/cipher#AEAD
	ciphertext := aesgcm.Seal(nil, iv.AsSlice(), plaintext, nil)
	content.Content = ciphertext[:len(ciphertext)-aesgcm.Overhead()]
	content.Metadata[MetadataKeyAesGcmTag] = ciphertext[len(ciphertext)-aesgcm.Overhead():]

	return &content, nil
}

func (m *MemAesGcmBackend) DecryptContent(_ctx context.Context, content *encryptionpb.EncryptedContent) (
	[]byte, error) {
	method, ok := content.Metadata[MetadataKeyMethod]
	if !ok {
		return nil, errors.Errorf("metadata %s not found", MetadataKeyMethod)
	}
	if string(method) != MetadataMethodAes256Gcm {
		return nil, errors.Errorf("encryption method mismatch, expected %s vs actual %s",
			MetadataMethodAes256Gcm, method)
	}

	ivValue, ok := content.Metadata[MetadataKeyIv]
	if !ok {
		return nil, errors.Errorf("metadata %s not found", MetadataKeyIv)
	}

	iv, err := NewIVFromSlice(ivValue)
	if err != nil {
		return nil, err
	}

	tag, ok := content.Metadata[MetadataKeyAesGcmTag]
	if !ok {
		return nil, errors.New(gcmTagNotFound)
	}

	block, err := aes.NewCipher(m.key.Key())
	if err != nil {
		return nil, err
	}
	aesgcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}

	ciphertext := append(content.Content, tag...)
	plaintext, err := aesgcm.Open(nil, iv.AsSlice(), ciphertext, nil)
	if err != nil {
		return nil, errors.Annotate(err, wrongMasterKey+": decrypt in GCM mode failed")
	}

	return plaintext, nil
}
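A round-trip sketch (not part of the commit) of the in-memory backend above, using a zeroed demo key; note that the GCM tag travels in the metadata map, separate from the ciphertext, matching TiKV's layout.

package main

import (
	"context"
	"fmt"

	masterkey "github.com/pingcap/tidb/br/pkg/encryption/master_key"
)

func main() {
	backend, err := masterkey.NewMemAesGcmBackend(make([]byte, 32)) // demo AES-256 key
	if err != nil {
		panic(err)
	}
	iv, _ := masterkey.NewIVGcm()
	ctx := context.Background()

	enc, _ := backend.EncryptContent(ctx, []byte("hello"), iv)
	dec, _ := backend.DecryptContent(ctx, enc)
	fmt.Println(string(dec)) // hello
}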
br/pkg/encryption/master_key/mem_backend_test.go (new file, 115 lines)
@@ -0,0 +1,115 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.

package encryption

import (
	"bytes"
	"context"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestNewMemAesGcmBackend(t *testing.T) {
	key := make([]byte, 32) // 256-bit key
	_, err := NewMemAesGcmBackend(key)
	require.NoError(t, err, "Failed to create MemAesGcmBackend")

	shortKey := make([]byte, 16)
	_, err = NewMemAesGcmBackend(shortKey)
	require.Error(t, err, "Expected error for short key")
}

func TestEncryptDecrypt(t *testing.T) {
	key := make([]byte, 32)
	backend, err := NewMemAesGcmBackend(key)
	require.NoError(t, err, "Failed to create MemAesGcmBackend")

	plaintext := []byte("Hello, World!")

	iv, err := NewIVGcm()
	require.NoError(t, err, "failed to create gcm iv")

	ctx := context.Background()
	encrypted, err := backend.EncryptContent(ctx, plaintext, iv)
	require.NoError(t, err, "Encryption failed")

	decrypted, err := backend.DecryptContent(ctx, encrypted)
	require.NoError(t, err, "Decryption failed")

	require.Equal(t, plaintext, decrypted, "Decrypted text doesn't match original")
}

func TestDecryptWithWrongKey(t *testing.T) {
	key1 := make([]byte, 32)
	key2 := make([]byte, 32)
	for i := range key2 {
		key2[i] = 1 // Different from key1
	}

	backend1, _ := NewMemAesGcmBackend(key1)
	backend2, _ := NewMemAesGcmBackend(key2)

	plaintext := []byte("Hello, World!")

	iv, err := NewIVGcm()
	require.NoError(t, err, "failed to create gcm iv")

	ctx := context.Background()
	encrypted, _ := backend1.EncryptContent(ctx, plaintext, iv)
	_, err = backend2.DecryptContent(ctx, encrypted)
	require.Error(t, err, "Expected decryption with wrong key to fail")
}

func TestDecryptWithTamperedCiphertext(t *testing.T) {
	key := make([]byte, 32)
	backend, _ := NewMemAesGcmBackend(key)

	plaintext := []byte("Hello, World!")

	iv, err := NewIVGcm()
	require.NoError(t, err, "failed to create gcm iv")

	ctx := context.Background()
	encrypted, _ := backend.EncryptContent(ctx, plaintext, iv)
	encrypted.Content[0] ^= 1 // Tamper with the ciphertext

	_, err = backend.DecryptContent(ctx, encrypted)
	require.Error(t, err, "Expected decryption of tampered ciphertext to fail")
}

func TestDecryptWithMissingMetadata(t *testing.T) {
	key := make([]byte, 32)
	backend, _ := NewMemAesGcmBackend(key)

	plaintext := []byte("Hello, World!")

	iv, err := NewIVGcm()
	require.NoError(t, err, "failed to create gcm iv")

	ctx := context.Background()
	encrypted, _ := backend.EncryptContent(ctx, plaintext, iv)
	delete(encrypted.Metadata, MetadataKeyMethod)

	_, err = backend.DecryptContent(ctx, encrypted)
	require.Error(t, err, "Expected decryption with missing metadata to fail")
}

func TestEncryptDecryptLargeData(t *testing.T) {
	key := make([]byte, 32)
	backend, _ := NewMemAesGcmBackend(key)

	plaintext := make([]byte, 1000000) // 1 MB of data

	iv, err := NewIVGcm()
	require.NoError(t, err, "failed to create gcm iv")

	ctx := context.Background()
	encrypted, err := backend.EncryptContent(ctx, plaintext, iv)
	require.NoError(t, err, "Encryption of large data failed")

	decrypted, err := backend.DecryptContent(ctx, encrypted)
	require.NoError(t, err, "Decryption of large data failed")

	require.True(t, bytes.Equal(plaintext, decrypted), "Decrypted large data doesn't match original")
}
br/pkg/encryption/master_key/multi_master_key_backend.go (new file, 64 lines)
@@ -0,0 +1,64 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.

package encryption

import (
	"context"

	"github.com/pingcap/errors"
	"github.com/pingcap/kvproto/pkg/encryptionpb"
	"go.uber.org/multierr"
)

const (
	defaultBackendCapacity = 5
)

// MultiMasterKeyBackend can contain multiple master key backends.
// If any one of those backends successfully decrypts the data, the data is returned.
// The main purpose of this backend is to provide high availability for master keys in the future.
// Right now only one master key backend is used to encrypt/decrypt data.
type MultiMasterKeyBackend struct {
	backends []Backend
}

func NewMultiMasterKeyBackend(masterKeysProto []*encryptionpb.MasterKey) (*MultiMasterKeyBackend, error) {
	if len(masterKeysProto) == 0 {
		return nil, errors.New("must provide at least one master key")
	}
	var backends = make([]Backend, 0, defaultBackendCapacity)
	for _, masterKeyProto := range masterKeysProto {
		backend, err := CreateBackend(masterKeyProto)
		if err != nil {
			return nil, errors.Trace(err)
		}
		backends = append(backends, backend)
	}
	return &MultiMasterKeyBackend{
		backends: backends,
	}, nil
}

func (m *MultiMasterKeyBackend) Decrypt(ctx context.Context, encryptedContent *encryptionpb.EncryptedContent) (
	[]byte, error) {
	if len(m.backends) == 0 {
		return nil, errors.New("internal error: should always contain at least one backend")
	}

	var err error
	for _, masterKeyBackend := range m.backends {
		res, decryptErr := masterKeyBackend.Decrypt(ctx, encryptedContent)
		if decryptErr == nil {
			return res, nil
		}
		err = multierr.Append(err, decryptErr)
	}

	return nil, errors.Wrap(err, "failed to decrypt in multi master key backend")
}

func (m *MultiMasterKeyBackend) Close() {
	for _, backend := range m.backends {
		backend.Close()
	}
}
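A construction sketch (not part of the commit) for the multi-backend; the file path and the encryptionpb.MasterKeyFile wrapper shown here are assumptions based on the kvproto master-key config, and the path must point at a key file in the format file_backend.go expects.

package main

import (
	"fmt"

	"github.com/pingcap/kvproto/pkg/encryptionpb"
	masterkey "github.com/pingcap/tidb/br/pkg/encryption/master_key"
)

func main() {
	keys := []*encryptionpb.MasterKey{{
		Backend: &encryptionpb.MasterKey_File{
			File: &encryptionpb.MasterKeyFile{Path: "/path/to/master.key"}, // hypothetical path
		},
	}}
	backend, err := masterkey.NewMultiMasterKeyBackend(keys)
	if err != nil {
		panic(err) // fails if the key file is missing or malformed
	}
	defer backend.Close()
	fmt.Println("multi master key backend ready")
}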
br/pkg/encryption/master_key/multi_master_key_backend_test.go (new file, 105 lines)
@@ -0,0 +1,105 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.

package encryption

import (
	"context"
	"testing"

	"github.com/pingcap/errors"
	"github.com/pingcap/kvproto/pkg/encryptionpb"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
)

// MockBackend is a mock implementation of the Backend interface
type MockBackend struct {
	mock.Mock
}

func (m *MockBackend) Decrypt(ctx context.Context, encryptedContent *encryptionpb.EncryptedContent) ([]byte, error) {
	args := m.Called(ctx, encryptedContent)
	// The first return value should be []byte or nil
	if ret := args.Get(0); ret != nil {
		return ret.([]byte), args.Error(1)
	}
	return nil, args.Error(1)
}

func (m *MockBackend) Close() {
	// do nothing
}

func TestMultiMasterKeyBackendDecrypt(t *testing.T) {
	ctx := context.Background()
	encryptedContent := &encryptionpb.EncryptedContent{Content: []byte("encrypted")}

	t.Run("success first backend", func(t *testing.T) {
		mock1 := new(MockBackend)
		mock1.On("Decrypt", ctx, encryptedContent).Return([]byte("decrypted"), nil)

		mock2 := new(MockBackend)

		backend := &MultiMasterKeyBackend{
			backends: []Backend{mock1, mock2},
		}

		result, err := backend.Decrypt(ctx, encryptedContent)
		require.NoError(t, err)
		require.Equal(t, []byte("decrypted"), result)

		mock1.AssertExpectations(t)
		mock2.AssertNotCalled(t, "Decrypt")
	})

	t.Run("success second backend", func(t *testing.T) {
		mock1 := new(MockBackend)
		mock1.On("Decrypt", ctx, encryptedContent).Return(nil, errors.New("failed"))

		mock2 := new(MockBackend)
		mock2.On("Decrypt", ctx, encryptedContent).Return([]byte("decrypted"), nil)

		backend := &MultiMasterKeyBackend{
			backends: []Backend{mock1, mock2},
		}

		result, err := backend.Decrypt(ctx, encryptedContent)
		require.NoError(t, err)
		require.Equal(t, []byte("decrypted"), result)

		mock1.AssertExpectations(t)
		mock2.AssertExpectations(t)
	})

	t.Run("all backends fail", func(t *testing.T) {
		mock1 := new(MockBackend)
		mock1.On("Decrypt", ctx, encryptedContent).Return(nil, errors.New("failed1"))

		mock2 := new(MockBackend)
		mock2.On("Decrypt", ctx, encryptedContent).Return(nil, errors.New("failed2"))

		backend := &MultiMasterKeyBackend{
			backends: []Backend{mock1, mock2},
		}

		result, err := backend.Decrypt(ctx, encryptedContent)
		require.Error(t, err)
		require.Nil(t, result)
		require.Contains(t, err.Error(), "failed1")
		require.Contains(t, err.Error(), "failed2")

		mock1.AssertExpectations(t)
		mock2.AssertExpectations(t)
	})

	t.Run("no backends", func(t *testing.T) {
		backend := &MultiMasterKeyBackend{
			backends: []Backend{},
		}

		result, err := backend.Decrypt(ctx, encryptedContent)
		require.Error(t, err)
		require.Nil(t, result)
		require.Contains(t, err.Error(), "internal error")
	})
}
br/pkg/kms/BUILD.bazel (new file, 27 lines)
@@ -0,0 +1,27 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")

go_library(
    name = "aws",
    srcs = [
        "aws.go",
        "common.go",
        "gcp.go",
        "kms.go",
    ],
    importpath = "github.com/pingcap/tidb/br/pkg/kms",
    visibility = ["//visibility:public"],
    deps = [
        "@com_github_aws_aws_sdk_go//aws",
        "@com_github_aws_aws_sdk_go//aws/credentials",
        "@com_github_aws_aws_sdk_go//aws/session",
        "@com_github_aws_aws_sdk_go//service/kms",
        "@com_github_pingcap_errors//:errors",
        "@com_github_pingcap_kvproto//pkg/encryptionpb",
        "@com_github_pingcap_log//:log",
        "@com_google_cloud_go_kms//apiv1",
        "@com_google_cloud_go_kms//apiv1/kmspb",
        "@org_golang_google_api//option",
        "@org_golang_google_protobuf//types/known/wrapperspb",
        "@org_uber_go_zap//:zap",
    ],
)
br/pkg/kms/aws.go (new file, 92 lines)
@@ -0,0 +1,92 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.

package kms

import (
	"context"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/kms"
	pErrors "github.com/pingcap/errors"
	"github.com/pingcap/kvproto/pkg/encryptionpb"
)

const (
	// must be kept exactly the same as ENCRYPTION_VENDOR_NAME_AWS_KMS in TiKV
	EncryptionVendorNameAwsKms = "AWS"
)

type AwsKms struct {
	client       *kms.KMS
	currentKeyID string
	region       string
	endpoint     string
}

func NewAwsKms(masterKeyConfig *encryptionpb.MasterKeyKms) (*AwsKms, error) {
	config := &aws.Config{
		Region:   aws.String(masterKeyConfig.Region),
		Endpoint: aws.String(masterKeyConfig.Endpoint),
	}

	// Only use static credentials if both access key and secret key are provided
	if masterKeyConfig.AwsKms != nil &&
		masterKeyConfig.AwsKms.AccessKey != "" &&
		masterKeyConfig.AwsKms.SecretAccessKey != "" {
		config.Credentials = credentials.NewStaticCredentials(
			masterKeyConfig.AwsKms.AccessKey,
			masterKeyConfig.AwsKms.SecretAccessKey,
			"",
		)
	}

	sess, err := session.NewSession(config)
	if err != nil {
		return nil, pErrors.Annotate(err, "failed to create AWS session")
	}

	return &AwsKms{
		client:       kms.New(sess),
		currentKeyID: masterKeyConfig.KeyId,
		region:       masterKeyConfig.Region,
		endpoint:     masterKeyConfig.Endpoint,
	}, nil
}

func (a *AwsKms) Name() string {
	return EncryptionVendorNameAwsKms
}

func (a *AwsKms) DecryptDataKey(ctx context.Context, dataKey []byte) ([]byte, error) {
	input := &kms.DecryptInput{
		CiphertextBlob: dataKey,
		KeyId:          aws.String(a.currentKeyID),
	}

	result, err := a.client.DecryptWithContext(ctx, input)
	if err != nil {
		return nil, classifyDecryptError(err)
	}

	return result.Plaintext, nil
}

func (a *AwsKms) Close() {
	// nothing to close manually for the AWS session
}

// classifyDecryptError maps AWS SDK v1 error types to annotated BR errors
func classifyDecryptError(err error) error {
	switch err := err.(type) {
	case *kms.NotFoundException, *kms.InvalidKeyUsageException:
		return pErrors.Annotate(err, "wrong master key")
	case *kms.DependencyTimeoutException:
		return pErrors.Annotate(err, "API timeout")
	case *kms.InternalException:
		return pErrors.Annotate(err, "API internal error")
	default:
		return pErrors.Annotate(err, "KMS error")
	}
}
br/pkg/kms/common.go (new file, 65 lines)
@@ -0,0 +1,65 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.

package kms

import (
	"bytes"

	"github.com/pingcap/errors"
)

// EncryptedKey is used to mark data as an encrypted key
type EncryptedKey []byte

func NewEncryptedKey(key []byte) (EncryptedKey, error) {
	if len(key) == 0 {
		return nil, errors.New("encrypted key cannot be empty")
	}
	return key, nil
}

// Equal method for EncryptedKey
func (e EncryptedKey) Equal(other *EncryptedKey) bool {
	return bytes.Equal(e, *other)
}

// CryptographyType represents different cryptography methods
type CryptographyType int

const (
	CryptographyTypePlain CryptographyType = iota
	CryptographyTypeAesGcm256
)

func (c CryptographyType) TargetKeySize() int {
	switch c {
	case CryptographyTypePlain:
		return 0 // Plain text has no limitation
	case CryptographyTypeAesGcm256:
		return 32
	default:
		return 0
	}
}

// PlainKey is used to mark a byte slice as a plaintext key
type PlainKey struct {
	tag CryptographyType
	key []byte
}

func NewPlainKey(key []byte, t CryptographyType) (*PlainKey, error) {
	limitation := t.TargetKeySize()
	if limitation > 0 && len(key) != limitation {
		return nil, errors.Errorf("encryption method and key length mismatch, expect %d got %d", limitation, len(key))
	}
	return &PlainKey{key: key, tag: t}, nil
}

func (p *PlainKey) KeyTag() CryptographyType {
	return p.tag
}

func (p *PlainKey) Key() []byte {
	return p.key
}
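A quick sketch (not part of the commit) of the key-length validation above: AES-256-GCM demands exactly 32 bytes, and any other length is rejected.

package main

import (
	"fmt"

	"github.com/pingcap/tidb/br/pkg/kms"
)

func main() {
	// 16 bytes is rejected for AES-256-GCM.
	_, err := kms.NewPlainKey(make([]byte, 16), kms.CryptographyTypeAesGcm256)
	fmt.Println(err)

	// 32 bytes is accepted.
	key, _ := kms.NewPlainKey(make([]byte, 32), kms.CryptographyTypeAesGcm256)
	fmt.Println(len(key.Key())) // 32
}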
br/pkg/kms/gcp.go (new file, 104 lines)
@@ -0,0 +1,104 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.

package kms

import (
	"context"
	"hash/crc32"
	"strings"

	"cloud.google.com/go/kms/apiv1"
	"cloud.google.com/go/kms/apiv1/kmspb"
	"github.com/pingcap/errors"
	"github.com/pingcap/kvproto/pkg/encryptionpb"
	"github.com/pingcap/log"
	"go.uber.org/zap"
	"google.golang.org/api/option"
	"google.golang.org/protobuf/types/known/wrapperspb"
)

const (
	// must be kept exactly the same as STORAGE_VENDOR_NAME_GCP in TiKV
	StorageVendorNameGcp = "gcp"
)

type GcpKms struct {
	config *encryptionpb.MasterKeyKms
	// the location prefix of the key id,
	// format: projects/{project_name}/locations/{location}
	location string
	client   *kms.KeyManagementClient
}

func NewGcpKms(config *encryptionpb.MasterKeyKms) (*GcpKms, error) {
	if config.GcpKms == nil {
		return nil, errors.New("GCP config is missing")
	}

	// the config string pattern is verified at the flag-parsing phase, so the string is valid at this stage.
	config.KeyId = strings.TrimSuffix(config.KeyId, "/")

	// join the first 4 parts of the key id to get the location
	location := strings.Join(strings.Split(config.KeyId, "/")[:4], "/")

	ctx := context.Background()

	var clientOpt option.ClientOption
	if config.GcpKms.Credential != "" {
		clientOpt = option.WithCredentialsFile(config.GcpKms.Credential)
	}

	client, err := kms.NewKeyManagementClient(ctx, clientOpt)
	if err != nil {
		return nil, errors.Errorf("failed to create GCP KMS client: %v", err)
	}

	return &GcpKms{
		config:   config,
		location: location,
		client:   client,
	}, nil
}

func (g *GcpKms) Name() string {
	return StorageVendorNameGcp
}

func (g *GcpKms) DecryptDataKey(ctx context.Context, dataKey []byte) ([]byte, error) {
	req := &kmspb.DecryptRequest{
		Name:             g.config.KeyId,
		Ciphertext:       dataKey,
		CiphertextCrc32C: wrapperspb.Int64(int64(g.calculateCRC32C(dataKey))),
	}

	resp, err := g.client.Decrypt(ctx, req)
	if err != nil {
		return nil, errors.Annotate(err, "gcp kms decrypt request failed")
	}

	if int64(g.calculateCRC32C(resp.Plaintext)) != resp.PlaintextCrc32C.Value {
		return nil, errors.New("response corrupted in-transit")
	}

	return resp.Plaintext, nil
}

func (g *GcpKms) checkCRC32(data []byte, expected int64) error {
	crc := int64(g.calculateCRC32C(data))
	if crc != expected {
		return errors.Errorf("crc32c mismatch, expected: %d, got: %d", expected, crc)
	}
	return nil
}

func (g *GcpKms) calculateCRC32C(data []byte) uint32 {
	t := crc32.MakeTable(crc32.Castagnoli)
	return crc32.Checksum(data, t)
}

func (g *GcpKms) Close() {
	err := g.client.Close()
	if err != nil {
		log.Error("failed to close gcp kms client", zap.Error(err))
	}
}
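The CRC32C used above is the Castagnoli polynomial from the Go standard library; GCP KMS requests and responses carry this checksum to detect in-transit corruption. A standalone sketch (not part of the commit):

package main

import (
	"fmt"
	"hash/crc32"
)

func main() {
	data := []byte("data key bytes")
	// Same computation as GcpKms.calculateCRC32C.
	sum := crc32.Checksum(data, crc32.MakeTable(crc32.Castagnoli))
	fmt.Println(sum)
}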
br/pkg/kms/kms.go (new file, 13 lines)
@@ -0,0 +1,13 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.

package kms

import "context"

// Provider is an interface for key management service providers;
// an EncryptDataKey method can be added in the future if needed
type Provider interface {
	DecryptDataKey(ctx context.Context, dataKey []byte) ([]byte, error)
	Name() string
	Close()
}
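A toy Provider implementation (not part of the commit) showing the shape of the interface; it returns the ciphertext unchanged and is purely illustrative, while real providers live in aws.go and gcp.go.

package main

import (
	"context"
	"fmt"

	"github.com/pingcap/tidb/br/pkg/kms"
)

// identityProvider is a hypothetical no-op provider for illustration only.
type identityProvider struct{}

func (identityProvider) DecryptDataKey(_ context.Context, dataKey []byte) ([]byte, error) {
	return dataKey, nil
}
func (identityProvider) Name() string { return "identity" }
func (identityProvider) Close()       {}

func main() {
	var p kms.Provider = identityProvider{}
	out, _ := p.DecryptDataKey(context.Background(), []byte("key"))
	fmt.Println(p.Name(), string(out))
}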
@@ -14,6 +14,7 @@ go_library(
 		"//br/pkg/logutil",
 		"//br/pkg/storage",
 		"//br/pkg/summary",
+		"//br/pkg/utils",
 		"//pkg/meta/model",
 		"//pkg/statistics/handle",
 		"//pkg/statistics/handle/types",
@@ -47,6 +48,7 @@ go_test(
 	shard_count = 9,
 	deps = [
 		"//br/pkg/storage",
+		"//br/pkg/utils",
 		"//pkg/meta/model",
 		"//pkg/parser/model",
 		"//pkg/statistics/handle/types",
@@ -23,6 +23,7 @@ import (
 	"github.com/pingcap/tidb/br/pkg/logutil"
 	"github.com/pingcap/tidb/br/pkg/storage"
 	"github.com/pingcap/tidb/br/pkg/summary"
+	"github.com/pingcap/tidb/br/pkg/utils"
 	"github.com/pingcap/tidb/pkg/meta/model"
 	"github.com/pingcap/tidb/pkg/statistics/handle/util"
 	"github.com/pingcap/tidb/pkg/tablecodec"
@@ -87,22 +88,17 @@ func Encrypt(content []byte, cipher *backuppb.CipherInfo) (encryptedContent, iv
 	}
 }

-// Decrypt decrypts the content according to CipherInfo and IV.
-func Decrypt(content []byte, cipher *backuppb.CipherInfo, iv []byte) ([]byte, error) {
-	if len(content) == 0 || cipher == nil {
-		return content, nil
+func DecryptFullBackupMetaIfNeeded(metaData []byte, cipherInfo *backuppb.CipherInfo) ([]byte, error) {
+	if cipherInfo == nil || !utils.IsEffectiveEncryptionMethod(cipherInfo.CipherType) {
+		return metaData, nil
 	}
-
-	switch cipher.CipherType {
-	case encryptionpb.EncryptionMethod_PLAINTEXT:
-		return content, nil
-	case encryptionpb.EncryptionMethod_AES128_CTR,
-		encryptionpb.EncryptionMethod_AES192_CTR,
-		encryptionpb.EncryptionMethod_AES256_CTR:
-		return encrypt.AESDecryptWithCTR(content, cipher.CipherKey, iv)
-	default:
-		return content, errors.Annotate(berrors.ErrInvalidArgument, "cipher type invalid")
+	// the prefix of the backup meta file is the iv (16 bytes) for CTR mode when the encryption method is valid
+	iv := metaData[:CrypterIvLen]
+	decryptBackupMeta, err := utils.Decrypt(metaData[len(iv):], cipherInfo, iv)
+	if err != nil {
+		return nil, errors.Annotate(err, "decrypt failed with wrong key")
 	}
+	return decryptBackupMeta, nil
 }

 // walkLeafMetaFile walks the leaves of the given metafile, and deal with it by calling the function `output`.
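A pass-through sketch (not part of the commit) of the new helper above: with a nil or non-effective CipherInfo the meta bytes are returned untouched, and otherwise the first 16 bytes are treated as the CTR IV.

package main

import (
	"fmt"

	"github.com/pingcap/tidb/br/pkg/metautil"
)

func main() {
	metaData := []byte("unencrypted backupmeta bytes")
	// nil CipherInfo => no decryption attempted, input returned as-is.
	out, err := metautil.DecryptFullBackupMetaIfNeeded(metaData, nil)
	fmt.Println(string(out), err)
}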
@@ -130,7 +126,7 @@ func walkLeafMetaFile(
 		return errors.Trace(err)
 	}

-	decryptContent, err := Decrypt(content, cipher, node.CipherIv)
+	decryptContent, err := utils.Decrypt(content, cipher, node.CipherIv)
 	if err != nil {
 		return errors.Trace(err)
 	}
@@ -11,6 +11,7 @@ import (
 	backuppb "github.com/pingcap/kvproto/pkg/brpb"
 	"github.com/pingcap/kvproto/pkg/encryptionpb"
 	"github.com/pingcap/tidb/br/pkg/storage"
+	"github.com/pingcap/tidb/br/pkg/utils"
 	"github.com/stretchr/testify/require"
 )
@@ -203,14 +204,14 @@ func TestEncryptAndDecrypt(t *testing.T) {
 			require.NoError(t, err)
 			require.Equal(t, originalData, encryptData)

-			decryptData, err := Decrypt(encryptData, &cipher, iv)
+			decryptData, err := utils.Decrypt(encryptData, &cipher, iv)
 			require.NoError(t, err)
 			require.Equal(t, decryptData, originalData)
 		} else {
 			require.NoError(t, err)
 			require.NotEqual(t, originalData, encryptData)

-			decryptData, err := Decrypt(encryptData, &cipher, iv)
+			decryptData, err := utils.Decrypt(encryptData, &cipher, iv)
 			require.NoError(t, err)
 			require.Equal(t, decryptData, originalData)
@@ -218,7 +219,7 @@ func TestEncryptAndDecrypt(t *testing.T) {
 				CipherType: v.method,
 				CipherKey:  []byte(v.wrongKey),
 			}
-			decryptData, err = Decrypt(encryptData, &wrongCipher, iv)
+			decryptData, err = utils.Decrypt(encryptData, &wrongCipher, iv)
 			if len(v.rightKey) != len(v.wrongKey) {
 				require.Error(t, err)
 			} else {
@@ -26,6 +26,7 @@ import (
 	backuppb "github.com/pingcap/kvproto/pkg/brpb"
 	berrors "github.com/pingcap/tidb/br/pkg/errors"
 	"github.com/pingcap/tidb/br/pkg/storage"
+	"github.com/pingcap/tidb/br/pkg/utils"
 	"github.com/pingcap/tidb/pkg/meta/model"
 	"github.com/pingcap/tidb/pkg/statistics/handle"
 	statstypes "github.com/pingcap/tidb/pkg/statistics/handle/types"
@@ -201,7 +202,7 @@ func downloadStats(
 		return errors.Trace(err)
 	}

-	decryptContent, err := Decrypt(content, cipher, statsFile.CipherIv)
+	decryptContent, err := utils.Decrypt(content, cipher, statsFile.CipherIv)
 	if err != nil {
 		return errors.Trace(err)
 	}
@@ -16,6 +16,7 @@ go_library(
 		"//br/pkg/checksum",
 		"//br/pkg/conn",
 		"//br/pkg/conn/util",
+		"//br/pkg/encryption",
 		"//br/pkg/errors",
 		"//br/pkg/glue",
 		"//br/pkg/logutil",
@@ -49,6 +50,7 @@ go_library(
 		"@com_github_pingcap_errors//:errors",
 		"@com_github_pingcap_failpoint//:failpoint",
 		"@com_github_pingcap_kvproto//pkg/brpb",
+		"@com_github_pingcap_kvproto//pkg/encryptionpb",
 		"@com_github_pingcap_kvproto//pkg/errorpb",
 		"@com_github_pingcap_kvproto//pkg/import_sstpb",
 		"@com_github_pingcap_kvproto//pkg/kvrpcpb",
@@ -113,6 +115,7 @@ go_test(
 		"@com_github_pingcap_errors//:errors",
 		"@com_github_pingcap_failpoint//:failpoint",
 		"@com_github_pingcap_kvproto//pkg/brpb",
+		"@com_github_pingcap_kvproto//pkg/encryptionpb",
 		"@com_github_pingcap_kvproto//pkg/errorpb",
 		"@com_github_pingcap_kvproto//pkg/import_sstpb",
 		"@com_github_pingcap_kvproto//pkg/metapb",
@@ -33,11 +33,13 @@ import (
 	"github.com/pingcap/errors"
 	"github.com/pingcap/failpoint"
 	backuppb "github.com/pingcap/kvproto/pkg/brpb"
+	"github.com/pingcap/kvproto/pkg/encryptionpb"
 	"github.com/pingcap/log"
 	"github.com/pingcap/tidb/br/pkg/checkpoint"
 	"github.com/pingcap/tidb/br/pkg/checksum"
 	"github.com/pingcap/tidb/br/pkg/conn"
 	"github.com/pingcap/tidb/br/pkg/conn/util"
+	"github.com/pingcap/tidb/br/pkg/encryption"
 	"github.com/pingcap/tidb/br/pkg/glue"
 	"github.com/pingcap/tidb/br/pkg/logutil"
 	"github.com/pingcap/tidb/br/pkg/metautil"
@@ -293,13 +295,15 @@ func (rc *LogClient) InitCheckpointMetadataForLogRestore(
 	return gcRatio, nil
 }

-func (rc *LogClient) InstallLogFileManager(ctx context.Context, startTS, restoreTS uint64, metadataDownloadBatchSize uint) error {
+func (rc *LogClient) InstallLogFileManager(ctx context.Context, startTS, restoreTS uint64, metadataDownloadBatchSize uint,
+	encryptionManager *encryption.Manager) error {
 	init := LogFileManagerInit{
 		StartTS:   startTS,
 		RestoreTS: restoreTS,
 		Storage:   rc.storage,

 		MetadataDownloadBatchSize: metadataDownloadBatchSize,
+		EncryptionManager:         encryptionManager,
 	}
 	var err error
 	rc.LogFileManager, err = CreateLogFileManager(ctx, init)
@@ -436,7 +440,7 @@ func ApplyKVFilesWithBatchMethod(
 	return nil
 }

-func ApplyKVFilesWithSingelMethod(
+func ApplyKVFilesWithSingleMethod(
 	ctx context.Context,
 	files LogIter,
 	applyFunc func(file []*LogDataFileInfo, kvCount int64, size uint64),
@@ -477,6 +481,8 @@ func (rc *LogClient) RestoreKVFiles(
 	pitrBatchSize uint32,
 	updateStats func(kvCount uint64, size uint64),
 	onProgress func(cnt int64),
+	cipherInfo *backuppb.CipherInfo,
+	masterKeys []*encryptionpb.MasterKey,
 ) error {
 	var (
 		err error
@@ -488,7 +494,7 @@ func (rc *LogClient) RestoreKVFiles(
 	defer func() {
 		if err == nil {
 			elapsed := time.Since(start)
-			log.Info("Restore KV files", zap.Duration("take", elapsed))
+			log.Info("Restored KV files", zap.Duration("take", elapsed))
 			summary.CollectSuccessUnit("files", fileCount, elapsed)
 		}
 	}()
@ -548,7 +554,8 @@ func (rc *LogClient) RestoreKVFiles(
|
||||
}
|
||||
}()
|
||||
|
||||
return rc.fileImporter.ImportKVFiles(ectx, files, rule, rc.shiftStartTS, rc.startTS, rc.restoreTS, supportBatch)
|
||||
return rc.fileImporter.ImportKVFiles(ectx, files, rule, rc.shiftStartTS, rc.startTS, rc.restoreTS,
|
||||
supportBatch, cipherInfo, masterKeys)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -557,7 +564,7 @@ func (rc *LogClient) RestoreKVFiles(
|
||||
if supportBatch {
|
||||
err = ApplyKVFilesWithBatchMethod(ectx, logIter, int(pitrBatchCount), uint64(pitrBatchSize), applyFunc, &applyWg)
|
||||
} else {
|
||||
err = ApplyKVFilesWithSingelMethod(ectx, logIter, applyFunc, &applyWg)
|
||||
err = ApplyKVFilesWithSingleMethod(ectx, logIter, applyFunc, &applyWg)
|
||||
}
|
||||
return errors.Trace(err)
|
||||
})
|
||||
@ -619,18 +626,25 @@ func initFullBackupTables(
|
||||
ctx context.Context,
|
||||
s storage.ExternalStorage,
|
||||
tableFilter filter.Filter,
|
||||
cipherInfo *backuppb.CipherInfo,
|
||||
) (map[int64]*metautil.Table, error) {
|
||||
metaData, err := s.ReadFile(ctx, metautil.MetaFile)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
backupMetaBytes, err := metautil.DecryptFullBackupMetaIfNeeded(metaData, cipherInfo)
|
||||
if err != nil {
|
||||
return nil, errors.Annotate(err, "decrypt failed with wrong key")
|
||||
}
|
||||
|
||||
backupMeta := &backuppb.BackupMeta{}
|
||||
if err = backupMeta.Unmarshal(metaData); err != nil {
|
||||
if err = backupMeta.Unmarshal(backupMetaBytes); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
// read full backup databases to get map[table]table.Info
|
||||
reader := metautil.NewMetaReader(backupMeta, s, nil)
|
||||
reader := metautil.NewMetaReader(backupMeta, s, cipherInfo)
|
||||
|
||||
databases, err := metautil.LoadBackupTables(ctx, reader, false)
|
||||
if err != nil {
|
||||
@ -684,6 +698,7 @@ const UnsafePITRLogRestoreStartBeforeAnyUpstreamUserDDL = "UNSAFE_PITR_LOG_RESTO
|
||||
func (rc *LogClient) generateDBReplacesFromFullBackupStorage(
|
||||
ctx context.Context,
|
||||
cfg *InitSchemaConfig,
|
||||
cipherInfo *backuppb.CipherInfo,
|
||||
) (map[stream.UpstreamID]*stream.DBReplace, error) {
|
||||
dbReplaces := make(map[stream.UpstreamID]*stream.DBReplace)
|
||||
if cfg.FullBackupStorage == nil {
|
||||
@ -698,7 +713,7 @@ func (rc *LogClient) generateDBReplacesFromFullBackupStorage(
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
fullBackupTables, err := initFullBackupTables(ctx, s, cfg.TableFilter)
|
||||
fullBackupTables, err := initFullBackupTables(ctx, s, cfg.TableFilter, cipherInfo)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
@ -741,6 +756,7 @@ func (rc *LogClient) generateDBReplacesFromFullBackupStorage(
|
||||
func (rc *LogClient) InitSchemasReplaceForDDL(
|
||||
ctx context.Context,
|
||||
cfg *InitSchemaConfig,
|
||||
cipherInfo *backuppb.CipherInfo,
|
||||
) (*stream.SchemasReplace, error) {
|
||||
var (
|
||||
err error
|
||||
@ -785,7 +801,7 @@ func (rc *LogClient) InitSchemasReplaceForDDL(
|
||||
if len(dbMaps) <= 0 {
|
||||
log.Info("no id maps, build the table replaces from cluster and full backup schemas")
|
||||
needConstructIdMap = true
|
||||
dbReplaces, err = rc.generateDBReplacesFromFullBackupStorage(ctx, cfg)
|
||||
dbReplaces, err = rc.generateDBReplacesFromFullBackupStorage(ctx, cfg, cipherInfo)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
@ -866,7 +866,7 @@ func TestApplyKVFilesWithSingelMethod(t *testing.T) {
}
}

logclient.ApplyKVFilesWithSingelMethod(
logclient.ApplyKVFilesWithSingleMethod(
context.TODO(),
toLogDataFileInfoIter(iter.FromSlice(ds)),
applyFunc,
@ -1293,7 +1293,7 @@ func TestApplyKVFilesWithBatchMethod5(t *testing.T) {
require.Equal(t, backuppb.FileType_Delete, types[len(types)-1])

types = make([]backuppb.FileType, 0)
logclient.ApplyKVFilesWithSingelMethod(
logclient.ApplyKVFilesWithSingleMethod(
context.TODO(),
toLogDataFileInfoIter(iter.FromSlice(ds)),
applyFunc,
@ -1377,7 +1377,7 @@ func TestInitSchemasReplaceForDDL(t *testing.T) {
{
client := logclient.TEST_NewLogClient(123, 1, 2, 1, domain.NewMockDomain(), fakeSession{})
cfg := &logclient.InitSchemaConfig{IsNewTask: false}
_, err := client.InitSchemasReplaceForDDL(ctx, cfg)
_, err := client.InitSchemasReplaceForDDL(ctx, cfg, nil)
require.Error(t, err)
require.Regexp(t, "failed to get pitr id map from mysql.tidb_pitr_id_map.* [2, 1]", err.Error())
}
@ -1385,7 +1385,7 @@ func TestInitSchemasReplaceForDDL(t *testing.T) {
{
client := logclient.TEST_NewLogClient(123, 1, 2, 1, domain.NewMockDomain(), fakeSession{})
cfg := &logclient.InitSchemaConfig{IsNewTask: true}
_, err := client.InitSchemasReplaceForDDL(ctx, cfg)
_, err := client.InitSchemasReplaceForDDL(ctx, cfg, nil)
require.Error(t, err)
require.Regexp(t, "failed to get pitr id map from mysql.tidb_pitr_id_map.* [1, 1]", err.Error())
}
@ -1399,7 +1399,7 @@ func TestInitSchemasReplaceForDDL(t *testing.T) {
require.NoError(t, err)
client := logclient.TEST_NewLogClient(123, 1, 2, 1, domain.NewMockDomain(), se)
cfg := &logclient.InitSchemaConfig{IsNewTask: true}
_, err = client.InitSchemasReplaceForDDL(ctx, cfg)
_, err = client.InitSchemasReplaceForDDL(ctx, cfg, nil)
require.Error(t, err)
require.Contains(t, err.Error(), "miss upstream table information at `start-ts`(1) but the full backup path is not specified")
}

@ -19,6 +19,7 @@ import (

"github.com/pingcap/errors"
backuppb "github.com/pingcap/kvproto/pkg/brpb"
"github.com/pingcap/kvproto/pkg/encryptionpb"
"github.com/pingcap/tidb/br/pkg/glue"
"github.com/pingcap/tidb/br/pkg/storage"
"github.com/pingcap/tidb/br/pkg/stream"
@ -90,6 +91,7 @@ func (helper *FakeStreamMetadataHelper) ReadFile(
length uint64,
compressionType backuppb.CompressionType,
storage storage.ExternalStorage,
encryptionInfo *encryptionpb.FileEncryptionInfo,
) ([]byte, error) {
return helper.Data[offset : offset+length], nil
}

@ -23,6 +23,7 @@ import (

"github.com/pingcap/errors"
backuppb "github.com/pingcap/kvproto/pkg/brpb"
"github.com/pingcap/kvproto/pkg/encryptionpb"
"github.com/pingcap/kvproto/pkg/import_sstpb"
"github.com/pingcap/kvproto/pkg/kvrpcpb"
"github.com/pingcap/kvproto/pkg/metapb"
@ -101,6 +102,8 @@ func (importer *LogFileImporter) ImportKVFiles(
startTS uint64,
restoreTS uint64,
supportBatch bool,
cipherInfo *backuppb.CipherInfo,
masterKeys []*encryptionpb.MasterKey,
) error {
var (
startKey []byte
@ -111,7 +114,7 @@ func (importer *LogFileImporter) ImportKVFiles(

if !supportBatch && len(files) > 1 {
return errors.Annotatef(berrors.ErrInvalidArgument,
"do not support batch apply but files count:%v > 1", len(files))
"do not support batch apply, file count: %v > 1", len(files))
}
log.Debug("import kv files", zap.Int("batch file count", len(files)))

@ -143,7 +146,8 @@ func (importer *LogFileImporter) ImportKVFiles(
if len(subfiles) == 0 {
return RPCResultOK()
}
return importer.importKVFileForRegion(ctx, subfiles, rule, shiftStartTS, startTS, restoreTS, r, supportBatch)
return importer.importKVFileForRegion(ctx, subfiles, rule, shiftStartTS, startTS, restoreTS, r, supportBatch,
cipherInfo, masterKeys)
})
return errors.Trace(err)
}
@ -184,9 +188,11 @@ func (importer *LogFileImporter) importKVFileForRegion(
restoreTS uint64,
info *split.RegionInfo,
supportBatch bool,
cipherInfo *backuppb.CipherInfo,
masterKeys []*encryptionpb.MasterKey,
) RPCResult {
// Try to download file.
result := importer.downloadAndApplyKVFile(ctx, files, rule, info, shiftStartTS, startTS, restoreTS, supportBatch)
result := importer.downloadAndApplyKVFile(ctx, files, rule, info, shiftStartTS, startTS, restoreTS, supportBatch, cipherInfo, masterKeys)
if !result.OK() {
errDownload := result.Err
for _, e := range multierr.Errors(errDownload) {
@ -216,7 +222,8 @@ func (importer *LogFileImporter) downloadAndApplyKVFile(
startTS uint64,
restoreTS uint64,
supportBatch bool,
) RPCResult {
cipherInfo *backuppb.CipherInfo,
masterKeys []*encryptionpb.MasterKey) RPCResult {
leader := regionInfo.Leader
if leader == nil {
return RPCResultFromError(errors.Annotatef(berrors.ErrPDLeaderNotFound,
@ -251,11 +258,12 @@ func (importer *LogFileImporter) downloadAndApplyKVFile(
}
return startTS
}(),
RestoreTs: restoreTS,
StartKey: regionInfo.Region.GetStartKey(),
EndKey: regionInfo.Region.GetEndKey(),
Sha256: file.GetSha256(),
CompressionType: file.CompressionType,
RestoreTs: restoreTS,
StartKey: regionInfo.Region.GetStartKey(),
EndKey: regionInfo.Region.GetEndKey(),
Sha256: file.GetSha256(),
CompressionType: file.CompressionType,
FileEncryptionInfo: file.FileEncryptionInfo,
}

metas = append(metas, meta)
@ -276,6 +284,8 @@ func (importer *LogFileImporter) downloadAndApplyKVFile(
RewriteRules: rewriteRules,
Context: reqCtx,
StorageCacheId: importer.cacheKey,
CipherInfo: cipherInfo,
MasterKeys: masterKeys,
}
} else {
req = &import_sstpb.ApplyRequest{
@ -284,16 +294,18 @@ func (importer *LogFileImporter) downloadAndApplyKVFile(
RewriteRule: *rewriteRules[0],
Context: reqCtx,
StorageCacheId: importer.cacheKey,
CipherInfo: cipherInfo,
MasterKeys: masterKeys,
}
}

log.Debug("apply kv file", logutil.Leader(leader))
log.Debug("applying kv file", logutil.Leader(leader))
resp, err := importer.importClient.ApplyKVFile(ctx, leader.GetStoreId(), req)
if err != nil {
return RPCResultFromError(errors.Trace(err))
}
if resp.GetError() != nil {
logutil.CL(ctx).Warn("import meet error", zap.Stringer("error", resp.GetError()))
logutil.CL(ctx).Warn("import has error", zap.Stringer("error", resp.GetError()))
return RPCResultFromPBError(resp.GetError())
}
return RPCResultOK()

@ -60,6 +60,7 @@ func TestImportKVFiles(t *testing.T) {
startTS,
restoreTS,
false,
nil, nil,
)
require.True(t, berrors.ErrInvalidArgument.Equal(err))
}
@ -268,9 +269,9 @@ func TestFileImporter(t *testing.T) {
require.NoError(t, err)

rewriteRules, encodeKeyFiles := prepareData()
err = importer.ImportKVFiles(ctx, encodeKeyFiles, rewriteRules, 1, 1, 1, true)
err = importer.ImportKVFiles(ctx, encodeKeyFiles, rewriteRules, 1, 1, 1, true, nil, nil)
require.NoError(t, err)

err = importer.ImportKVFiles(ctx, encodeKeyFiles, rewriteRules, 1, 1, 1, false)
err = importer.ImportKVFiles(ctx, encodeKeyFiles, rewriteRules, 1, 1, 1, false, nil, nil)
require.NoError(t, err)
}

@ -12,7 +12,9 @@ import (

"github.com/pingcap/errors"
backuppb "github.com/pingcap/kvproto/pkg/brpb"
"github.com/pingcap/kvproto/pkg/encryptionpb"
"github.com/pingcap/log"
"github.com/pingcap/tidb/br/pkg/encryption"
berrors "github.com/pingcap/tidb/br/pkg/errors"
"github.com/pingcap/tidb/br/pkg/storage"
"github.com/pingcap/tidb/br/pkg/stream"
@ -54,6 +56,7 @@ type streamMetadataHelper interface {
length uint64,
compressionType backuppb.CompressionType,
storage storage.ExternalStorage,
encryptionInfo *encryptionpb.FileEncryptionInfo,
) ([]byte, error)
ParseToMetadata(rawMetaData []byte) (*backuppb.Metadata, error)
}
@ -85,6 +88,7 @@ type LogFileManagerInit struct {
Storage storage.ExternalStorage

MetadataDownloadBatchSize uint
EncryptionManager *encryption.Manager
}

type DDLMetaGroup struct {
@ -99,7 +103,7 @@ func CreateLogFileManager(ctx context.Context, init LogFileManagerInit) (*LogFil
startTS: init.StartTS,
restoreTS: init.RestoreTS,
storage: init.Storage,
helper: stream.NewMetadataHelper(),
helper: stream.NewMetadataHelper(stream.WithEncryptionManager(init.EncryptionManager)),

metadataDownloadBatchSize: init.MetadataDownloadBatchSize,
}
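For illustration, a caller wires the encryption manager through LogFileManagerInit roughly as follows (a minimal sketch using only the fields and constructor shown above; values are placeholders and error handling is elided):

	// Sketch: constructing a log file manager with decryption support.
	init := LogFileManagerInit{
		StartTS:                   startTS,
		RestoreTS:                 restoreTS,
		Storage:                   extStorage,
		MetadataDownloadBatchSize: 128,
		// may be backed by a plaintext data key or by master keys
		EncryptionManager: encryptionManager,
	}
	mgr, err := CreateLogFileManager(ctx, init)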
@ -329,7 +333,8 @@ func (rc *LogFileManager) ReadAllEntries(
kvEntries := make([]*KvEntryWithTS, 0)
nextKvEntries := make([]*KvEntryWithTS, 0)

buff, err := rc.helper.ReadFile(ctx, file.Path, file.RangeOffset, file.RangeLength, file.CompressionType, rc.storage)
buff, err := rc.helper.ReadFile(ctx, file.Path, file.RangeOffset, file.RangeLength, file.CompressionType,
rc.storage, file.FileEncryptionInfo)
if err != nil {
return nil, nil, errors.Trace(err)
}

@ -647,7 +647,7 @@ func (rc *SnapClient) CreateDatabases(ctx context.Context, dbs []*metautil.Datab
return nil
}

log.Info("create databases in db pool", zap.Int("pool size", len(rc.dbPool)))
log.Info("create databases in db pool", zap.Int("pool size", len(rc.dbPool)), zap.Int("number of db", len(dbs)))
eg, ectx := errgroup.WithContext(ctx)
workers := tidbutil.NewWorkerPool(uint(len(rc.dbPool)), "DB DDL workers")
for _, db_ := range dbs {

@ -15,6 +15,7 @@ go_library(
importpath = "github.com/pingcap/tidb/br/pkg/stream",
visibility = ["//visibility:public"],
deps = [
"//br/pkg/encryption",
"//br/pkg/errors",
"//br/pkg/glue",
"//br/pkg/httputil",
@ -36,6 +37,7 @@ go_library(
"@com_github_klauspost_compress//zstd",
"@com_github_pingcap_errors//:errors",
"@com_github_pingcap_kvproto//pkg/brpb",
"@com_github_pingcap_kvproto//pkg/encryptionpb",
"@com_github_pingcap_kvproto//pkg/metapb",
"@com_github_pingcap_log//:log",
"@com_github_tikv_client_go_v2//oracle",

@ -16,12 +16,16 @@ package stream

import (
"context"
"crypto/sha256"
"encoding/hex"
"strings"

"github.com/klauspost/compress/zstd"
"github.com/pingcap/errors"
backuppb "github.com/pingcap/kvproto/pkg/brpb"
"github.com/pingcap/kvproto/pkg/encryptionpb"
"github.com/pingcap/log"
"github.com/pingcap/tidb/br/pkg/encryption"
"github.com/pingcap/tidb/br/pkg/storage"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/meta"
@ -157,16 +161,31 @@ type ContentRef struct {

// MetadataHelper make restore/truncate compatible with metadataV1 and metadataV2.
type MetadataHelper struct {
cache map[string]*ContentRef
decoder *zstd.Decoder
cache map[string]*ContentRef
decoder *zstd.Decoder
encryptionManager *encryption.Manager
}

func NewMetadataHelper() *MetadataHelper {
type MetadataHelperOption func(*MetadataHelper)

func WithEncryptionManager(manager *encryption.Manager) MetadataHelperOption {
return func(mh *MetadataHelper) {
mh.encryptionManager = manager
}
}

func NewMetadataHelper(opts ...MetadataHelperOption) *MetadataHelper {
decoder, _ := zstd.NewReader(nil)
return &MetadataHelper{
helper := &MetadataHelper{
cache: make(map[string]*ContentRef),
decoder: decoder,
}

for _, opt := range opts {
opt(helper)
}

return helper
}

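A quick illustration of the new functional-option constructor; both call forms appear verbatim elsewhere in this diff:

	// With encryption support, e.g. during PITR restore:
	helper := stream.NewMetadataHelper(stream.WithEncryptionManager(encryptionManager))

	// Without encryption; truncate and other existing call sites keep the old behavior:
	plainHelper := stream.NewMetadataHelper()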
func (m *MetadataHelper) InitCacheEntry(path string, ref int) {
@ -191,6 +210,35 @@ func (m *MetadataHelper) decodeCompressedData(data []byte, compressionType backu
"failed to decode compressed data: compression type is unimplemented. type id is %d", compressionType)
}

func (m *MetadataHelper) verifyChecksumAndDecryptIfNeeded(ctx context.Context, data []byte,
encryptionInfo *encryptionpb.FileEncryptionInfo) ([]byte, error) {
// no need to decrypt
if encryptionInfo == nil {
return data, nil
}

if m.encryptionManager == nil {
return nil, errors.New("need to decrypt data but encryption manager not set")
}

// Verify checksum before decryption
if encryptionInfo.Checksum != nil {
actualChecksum := sha256.Sum256(data)
expectedChecksumHex := hex.EncodeToString(encryptionInfo.Checksum)
actualChecksumHex := hex.EncodeToString(actualChecksum[:])
if expectedChecksumHex != actualChecksumHex {
return nil, errors.Errorf("checksum mismatch before decryption, expected %s, actual %s",
expectedChecksumHex, actualChecksumHex)
}
}

decryptedContent, err := m.encryptionManager.Decrypt(ctx, data, encryptionInfo)
if err != nil {
return nil, errors.Trace(err)
}
return decryptedContent, nil
}

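Note that the SHA-256 checksum above is computed over the still-encrypted bytes, so corruption is caught before any decryption work. A minimal standalone sketch of the same check (self-contained; not taken from the source):

	package main

	import (
		"bytes"
		"crypto/sha256"
		"fmt"
	)

	// verify mirrors the pre-decryption check: hash the encrypted
	// payload and compare against the checksum recorded at backup time.
	func verify(data, expected []byte) error {
		actual := sha256.Sum256(data)
		if !bytes.Equal(actual[:], expected) {
			return fmt.Errorf("checksum mismatch: expected %x, actual %x", expected, actual)
		}
		return nil
	}

	func main() {
		payload := []byte("encrypted-bytes")
		sum := sha256.Sum256(payload)
		fmt.Println(verify(payload, sum[:])) // <nil>
	}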
func (m *MetadataHelper) ReadFile(
ctx context.Context,
path string,
@ -198,6 +246,7 @@ func (m *MetadataHelper) ReadFile(
length uint64,
compressionType backuppb.CompressionType,
storage storage.ExternalStorage,
encryptionInfo *encryptionpb.FileEncryptionInfo,
) ([]byte, error) {
var err error
cref, exist := m.cache[path]
@ -212,7 +261,12 @@ func (m *MetadataHelper) ReadFile(
if err != nil {
return nil, errors.Trace(err)
}
return m.decodeCompressedData(data, compressionType)
// decrypt if needed
decryptedData, err := m.verifyChecksumAndDecryptIfNeeded(ctx, data, encryptionInfo)
if err != nil {
return nil, errors.Trace(err)
}
return m.decodeCompressedData(decryptedData, compressionType)
}

cref.ref -= 1
@ -223,8 +277,12 @@ func (m *MetadataHelper) ReadFile(
return nil, errors.Trace(err)
}
}

buf, err := m.decodeCompressedData(cref.data[offset:offset+length], compressionType)
// decrypt if needed
decryptedData, err := m.verifyChecksumAndDecryptIfNeeded(ctx, cref.data[offset:offset+length], encryptionInfo)
if err != nil {
return nil, errors.Trace(err)
}
buf, err := m.decodeCompressedData(decryptedData, compressionType)

if cref.ref <= 0 {
// need reset reference information.

@ -66,14 +66,14 @@ func TestMetadataHelperReadFile(t *testing.T) {
require.NoError(t, err)

helper.InitCacheEntry(filename2, 2)
get_data, err := helper.ReadFile(ctx, filename1, 0, 0, backuppb.CompressionType_UNKNOWN, s)
get_data, err := helper.ReadFile(ctx, filename1, 0, 0, backuppb.CompressionType_UNKNOWN, s, nil)
require.NoError(t, err)
require.Equal(t, data1, get_data)
get_data, err = helper.ReadFile(ctx, filename2, 0, uint64(len(data1)), backuppb.CompressionType_UNKNOWN, s)
get_data, err = helper.ReadFile(ctx, filename2, 0, uint64(len(data1)), backuppb.CompressionType_UNKNOWN, s, nil)
require.NoError(t, err)
require.Equal(t, data1, get_data)
get_data, err = helper.ReadFile(ctx, filename2, uint64(len(data1)), uint64(len(data2)),
backuppb.CompressionType_ZSTD, s)
backuppb.CompressionType_ZSTD, s, nil)
require.NoError(t, err)
require.Equal(t, data1, get_data)
}

@ -8,6 +8,7 @@ go_library(
"backup_raw.go",
"backup_txn.go",
"common.go",
"encryption.go",
"restore.go",
"restore_data.go",
"restore_ebs_meta.go",
@ -27,6 +28,7 @@ go_library(
"//br/pkg/config",
"//br/pkg/conn",
"//br/pkg/conn/util",
"//br/pkg/encryption",
"//br/pkg/errors",
"//br/pkg/glue",
"//br/pkg/httputil",
@ -106,12 +108,13 @@ go_test(
"backup_test.go",
"common_test.go",
"config_test.go",
"encryption_test.go",
"restore_test.go",
"stream_test.go",
],
embed = [":task"],
flaky = True,
shard_count = 33,
shard_count = 38,
deps = [
"//br/pkg/backup",
"//br/pkg/config",

@ -72,7 +72,6 @@ const (
TableBackupCmd = "Table Backup"
RawBackupCmd = "Raw Backup"
TxnBackupCmd = "Txn Backup"
EBSBackupCmd = "EBS Backup"
)

// CompressionConfig is the configuration for sst file compression.

@ -91,9 +91,13 @@ const (
defaultGRPCKeepaliveTimeout = 3 * time.Second
defaultCloudAPIConcurrency = 8

flagCipherType = "crypter.method"
flagCipherKey = "crypter.key"
flagCipherKeyFile = "crypter.key-file"
flagFullBackupCipherType = "crypter.method"
flagFullBackupCipherKey = "crypter.key"
flagFullBackupCipherKeyFile = "crypter.key-file"

flagLogBackupCipherType = "log.crypter.method"
flagLogBackupCipherKey = "log.crypter.key"
flagLogBackupCipherKeyFile = "log.crypter.key-file"

flagMetadataDownloadBatchSize = "metadata-download-batch-size"
defaultMetadataDownloadBatchSize = 128
@ -104,6 +108,10 @@ const (
crypterAES256KeyLen = 32

flagFullBackupType = "type"

masterKeysDelimiter = ","
flagMasterKeyConfig = "master-key"
flagMasterKeyCipherType = "master-key-crypter-method"
)

const (
@ -260,8 +268,17 @@ type Config struct {
// GrpcKeepaliveTimeout is the max time a grpc conn can keep idel before killed.
GRPCKeepaliveTimeout time.Duration `json:"grpc-keepalive-timeout" toml:"grpc-keepalive-timeout"`

// Plaintext data key mainly used for full/snapshot backup and restore.
CipherInfo backuppb.CipherInfo `json:"-" toml:"-"`

// Could be used in log backup and restore but not recommended in a serious environment since data key is stored
// in PD in plaintext.
LogBackupCipherInfo backuppb.CipherInfo `json:"-" toml:"-"`

// Master key based encryption for log restore.
// More than one can be specified for log restore if master key rotated during log backup.
MasterKeyConfig backuppb.MasterKeyConfig `json:"master-key-config" toml:"master-key-config"`

// whether there's explicit filter
ExplicitFilter bool `json:"-" toml:"-"`

@ -310,17 +327,34 @@ func DefineCommonFlags(flags *pflag.FlagSet) {
flags.BoolP(flagSkipCheckPath, "", false, "Skip path verification")
_ = flags.MarkHidden(flagSkipCheckPath)

flags.String(flagCipherType, "plaintext", "Encrypt/decrypt method, "+
flags.String(flagFullBackupCipherType, "plaintext", "Encrypt/decrypt method, "+
"be one of plaintext|aes128-ctr|aes192-ctr|aes256-ctr case-insensitively, "+
"\"plaintext\" represents no encrypt/decrypt")
flags.String(flagCipherKey, "",
flags.String(flagFullBackupCipherKey, "",
"aes-crypter key, used to encrypt/decrypt the data "+
"by the hexadecimal string, eg: \"0123456789abcdef0123456789abcdef\"")
flags.String(flagCipherKeyFile, "", "FilePath, its content is used as the cipher-key")
flags.String(flagFullBackupCipherKeyFile, "", "FilePath, its content is used as the cipher-key")

flags.Uint(flagMetadataDownloadBatchSize, defaultMetadataDownloadBatchSize,
"the batch size of downloading metadata, such as log restore metadata for truncate or restore")

// log backup plaintext key flags
flags.String(flagLogBackupCipherType, "plaintext", "Encrypt/decrypt method, "+
"be one of plaintext|aes128-ctr|aes192-ctr|aes256-ctr case-insensitively, "+
"\"plaintext\" represents no encrypt/decrypt")
flags.String(flagLogBackupCipherKey, "",
"aes-crypter key, used to encrypt/decrypt the data "+
"by the hexadecimal string, eg: \"0123456789abcdef0123456789abcdef\"")
flags.String(flagLogBackupCipherKeyFile, "", "FilePath, its content is used as the cipher-key")

// master key config
flags.String(flagMasterKeyCipherType, "plaintext", "Encrypt/decrypt method, "+
"be one of plaintext|aes128-ctr|aes192-ctr|aes256-ctr case-insensitively, "+
"\"plaintext\" represents no encrypt/decrypt")
flags.String(flagMasterKeyConfig, "", "Master key config for point in time restore "+
"examples: \"local:///path/to/master/key/file,"+
"aws-kms:///{key-id}?AWS_ACCESS_KEY_ID={access-key}&AWS_SECRET_ACCESS_KEY={secret-key}&REGION={region},"+
"gcp-kms:///projects/{project-id}/locations/{location}/keyRings/{keyring}/cryptoKeys/{key-name}?AUTH=specified&CREDENTIALS={credentials}\"")
_ = flags.MarkHidden(flagMetadataDownloadBatchSize)

storage.DefineFlags(flags)
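For illustration, the new log backup cipher flags can be exercised in isolation the same way the tests later in this diff drive the master key flags (a sketch; the flag-name constants and parseLogBackupCipherInfo are defined in this file, the key value is a placeholder 64-hex-char AES-256 key):

	flags := pflag.NewFlagSet("example", pflag.ContinueOnError)
	flags.String(flagLogBackupCipherType, "aes256-ctr", "")
	flags.String(flagLogBackupCipherKey,
		"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", "")
	flags.String(flagLogBackupCipherKeyFile, "", "")

	cfg := &Config{}
	hasPlaintextKey, err := cfg.parseLogBackupCipherInfo(flags)
	// hasPlaintextKey is true here, so a --master-key config would be rejected.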
@ -334,10 +368,15 @@ func HiddenFlagsForStream(flags *pflag.FlagSet) {
_ = flags.MarkHidden(flagRateLimit)
_ = flags.MarkHidden(flagRateLimitUnit)
_ = flags.MarkHidden(flagRemoveTiFlash)
_ = flags.MarkHidden(flagCipherType)
_ = flags.MarkHidden(flagCipherKey)
_ = flags.MarkHidden(flagCipherKeyFile)
_ = flags.MarkHidden(flagFullBackupCipherType)
_ = flags.MarkHidden(flagFullBackupCipherKey)
_ = flags.MarkHidden(flagFullBackupCipherKeyFile)
_ = flags.MarkHidden(flagLogBackupCipherType)
_ = flags.MarkHidden(flagLogBackupCipherKey)
_ = flags.MarkHidden(flagLogBackupCipherKeyFile)
_ = flags.MarkHidden(flagSwitchModeInterval)
_ = flags.MarkHidden(flagMasterKeyConfig)
_ = flags.MarkHidden(flagMasterKeyCipherType)

storage.HiddenFlagsForStream(flags)
}
@ -456,7 +495,7 @@ func checkCipherKeyMatch(cipher *backuppb.CipherInfo) bool {
}

func (cfg *Config) parseCipherInfo(flags *pflag.FlagSet) error {
crypterStr, err := flags.GetString(flagCipherType)
crypterStr, err := flags.GetString(flagFullBackupCipherType)
if err != nil {
return errors.Trace(err)
}
@ -470,12 +509,12 @@ func (cfg *Config) parseCipherInfo(flags *pflag.FlagSet) error {
return nil
}

key, err := flags.GetString(flagCipherKey)
key, err := flags.GetString(flagFullBackupCipherKey)
if err != nil {
return errors.Trace(err)
}

keyFilePath, err := flags.GetString(flagCipherKeyFile)
keyFilePath, err := flags.GetString(flagFullBackupCipherKeyFile)
if err != nil {
return errors.Trace(err)
}
@ -492,6 +531,43 @@ func (cfg *Config) parseCipherInfo(flags *pflag.FlagSet) error {
return nil
}

func (cfg *Config) parseLogBackupCipherInfo(flags *pflag.FlagSet) (bool, error) {
crypterStr, err := flags.GetString(flagLogBackupCipherType)
if err != nil {
return false, errors.Trace(err)
}

cfg.LogBackupCipherInfo.CipherType, err = parseCipherType(crypterStr)
if err != nil {
return false, errors.Trace(err)
}

if !utils.IsEffectiveEncryptionMethod(cfg.LogBackupCipherInfo.CipherType) {
return false, nil
}

key, err := flags.GetString(flagLogBackupCipherKey)
if err != nil {
return false, errors.Trace(err)
}

keyFilePath, err := flags.GetString(flagLogBackupCipherKeyFile)
if err != nil {
return false, errors.Trace(err)
}

cfg.LogBackupCipherInfo.CipherKey, err = GetCipherKeyContent(key, keyFilePath)
if err != nil {
return false, errors.Trace(err)
}

// validate the log backup cipher just parsed, not the snapshot cipher above
if !checkCipherKeyMatch(&cfg.LogBackupCipherInfo) {
return false, errors.Annotate(berrors.ErrInvalidArgument, "log backup encryption method and key length not match")
}

return true, nil
}

func (cfg *Config) normalizePDURLs() error {
for i := range cfg.PD {
var err error
@ -618,7 +694,17 @@ func (cfg *Config) ParseFromFlags(flags *pflag.FlagSet) error {
log.L().Info("--skip-check-path is deprecated, need explicitly set it anymore")
}

if err = cfg.parseCipherInfo(flags); err != nil {
err = cfg.parseCipherInfo(flags)
if err != nil {
return errors.Trace(err)
}

hasLogBackupPlaintextKey, err := cfg.parseLogBackupCipherInfo(flags)
if err != nil {
return errors.Trace(err)
}

if err = cfg.parseAndValidateMasterKeyInfo(hasLogBackupPlaintextKey, flags); err != nil {
return errors.Trace(err)
}

@ -629,6 +715,51 @@ func (cfg *Config) ParseFromFlags(flags *pflag.FlagSet) error {
return cfg.normalizePDURLs()
}

func (cfg *Config) parseAndValidateMasterKeyInfo(hasPlaintextKey bool, flags *pflag.FlagSet) error {
masterKeyString, err := flags.GetString(flagMasterKeyConfig)
if err != nil {
return errors.Errorf("master key flag '%s' is not defined: %v", flagMasterKeyConfig, err)
}

if masterKeyString == "" {
return nil
}

if hasPlaintextKey {
return errors.Errorf("invalid argument: both plaintext data key encryption and master key based encryption are set at the same time")
}

encryptionMethodString, err := flags.GetString(flagMasterKeyCipherType)
if err != nil {
return errors.Errorf("encryption method flag '%s' is not defined: %v", flagMasterKeyCipherType, err)
}

encryptionMethod, err := parseCipherType(encryptionMethodString)
if err != nil {
return errors.Errorf("failed to parse encryption method: %v", err)
}

if !utils.IsEffectiveEncryptionMethod(encryptionMethod) {
return errors.Errorf("invalid encryption method: %s", encryptionMethodString)
}

masterKeyStrings := strings.Split(masterKeyString, masterKeysDelimiter)
cfg.MasterKeyConfig = backuppb.MasterKeyConfig{
EncryptionType: encryptionMethod,
MasterKeys: make([]*encryptionpb.MasterKey, 0, len(masterKeyStrings)),
}

for _, keyString := range masterKeyStrings {
masterKey, err := validateAndParseMasterKeyString(strings.TrimSpace(keyString))
if err != nil {
return errors.Wrapf(err, "invalid master key configuration: %s", keyString)
}
cfg.MasterKeyConfig.MasterKeys = append(cfg.MasterKeyConfig.MasterKeys, &masterKey)
}

return nil
}

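As an illustration, two rotated master keys are passed as one comma-separated flag value and validated URI by URI (a sketch using masterKeysDelimiter above and validateAndParseMasterKeyString from br/pkg/task/encryption.go below; the URIs and credentials are placeholders):

	// Sketch: how a comma-separated --master-key value is split and validated.
	input := "local:///mnt/keys/backup-master-key," +
		"aws-kms:///my-key-id?AWS_ACCESS_KEY_ID=AKIAEXAMPLE&AWS_SECRET_ACCESS_KEY=secret&REGION=us-west-2"

	for _, s := range strings.Split(input, masterKeysDelimiter) {
		mk, err := validateAndParseMasterKeyString(strings.TrimSpace(s))
		if err != nil {
			panic(err) // each URI is validated independently
		}
		_ = mk // appended to cfg.MasterKeyConfig.MasterKeys in parseAndValidateMasterKeyInfo
	}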
// NewMgr creates a new mgr at the given PD address.
func NewMgr(ctx context.Context,
g glue.Glue, pds []string,
@ -726,7 +857,7 @@ func ReadBackupMeta(
if cfg.CipherInfo.CipherType != encryptionpb.EncryptionMethod_PLAINTEXT {
iv = metaData[:metautil.CrypterIvLen]
}
decryptBackupMeta, err := metautil.Decrypt(metaData[len(iv):], &cfg.CipherInfo, iv)
decryptBackupMeta, err := utils.Decrypt(metaData[len(iv):], &cfg.CipherInfo, iv)
if err != nil {
return nil, nil, nil, errors.Annotate(err, "decrypt failed with wrong key")
}

@ -185,6 +185,7 @@ func expectedDefaultConfig() Config {
GRPCKeepaliveTime: 10000000000,
GRPCKeepaliveTimeout: 3000000000,
CipherInfo: backup.CipherInfo{CipherType: 1},
LogBackupCipherInfo: backup.CipherInfo{CipherType: 1},
MetadataDownloadBatchSize: 0x80,
}
}
@ -241,3 +242,132 @@ func TestDefaultRestore(t *testing.T) {
defaultConfig := expectedDefaultRestoreConfig()
require.Equal(t, defaultConfig, def)
}

func TestParseAndValidateMasterKeyInfo(t *testing.T) {
tests := []struct {
name string
input string
expectedKeys []*encryptionpb.MasterKey
expectError bool
}{
{
name: "Empty input",
input: "",
expectedKeys: nil,
expectError: false,
},
{
name: "Single local config",
input: "local:///path/to/key",
expectedKeys: []*encryptionpb.MasterKey{
{
Backend: &encryptionpb.MasterKey_File{
File: &encryptionpb.MasterKeyFile{Path: "/path/to/key"},
},
},
},
expectError: false,
},
{
name: "Single AWS config",
input: "aws-kms:///key-id?AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE&AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY&REGION=us-west-2",
expectedKeys: []*encryptionpb.MasterKey{
{
Backend: &encryptionpb.MasterKey_Kms{
Kms: &encryptionpb.MasterKeyKms{
Vendor: "aws",
KeyId: "key-id",
Region: "us-west-2",
},
},
},
},
expectError: false,
},
{
name: "Single Azure config",
input: "azure-kms:///key-name/key-version?AZURE_TENANT_ID=tenant-id&AZURE_CLIENT_ID=client-id&AZURE_CLIENT_SECRET=client-secret&AZURE_VAULT_NAME=vault-name",
expectedKeys: []*encryptionpb.MasterKey{
{
Backend: &encryptionpb.MasterKey_Kms{
Kms: &encryptionpb.MasterKeyKms{
Vendor: "azure",
KeyId: "key-name/key-version",
AzureKms: &encryptionpb.AzureKms{
TenantId: "tenant-id",
ClientId: "client-id",
ClientSecret: "client-secret",
KeyVaultUrl: "vault-name",
},
},
},
},
},
expectError: false,
},
{
name: "Single GCP config",
input: "gcp-kms:///projects/project-id/locations/global/keyRings/ring-name/cryptoKeys/key-name?CREDENTIALS=credentials",
expectedKeys: []*encryptionpb.MasterKey{
{
Backend: &encryptionpb.MasterKey_Kms{
Kms: &encryptionpb.MasterKeyKms{
Vendor: "gcp",
KeyId: "projects/project-id/locations/global/keyRings/ring-name/cryptoKeys/key-name",
GcpKms: &encryptionpb.GcpKms{
Credential: "credentials",
},
},
},
},
},
expectError: false,
},
{
name: "Multiple configs",
input: "local:///path/to/key," +
"aws-kms:///key-id?AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE&AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY&REGION=us-west-2",
expectedKeys: []*encryptionpb.MasterKey{
{
Backend: &encryptionpb.MasterKey_File{
File: &encryptionpb.MasterKeyFile{Path: "/path/to/key"},
},
},
{
Backend: &encryptionpb.MasterKey_Kms{
Kms: &encryptionpb.MasterKeyKms{
Vendor: "aws",
KeyId: "key-id",
Region: "us-west-2",
},
},
},
},
expectError: false,
},
{
name: "Invalid config",
input: "invalid:///config",
expectedKeys: nil,
expectError: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &Config{}
flags := pflag.NewFlagSet("test", pflag.ContinueOnError)
flags.String(flagMasterKeyConfig, tt.input, "")
flags.String(flagMasterKeyCipherType, "aes256-ctr", "")

err := cfg.parseAndValidateMasterKeyInfo(false, flags)

if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tt.expectedKeys, cfg.MasterKeyConfig.MasterKeys)
}
})
}
}

166
br/pkg/task/encryption.go
Normal file
@ -0,0 +1,166 @@

// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.

package task

import (
"fmt"
"net/url"
"path"
"regexp"

"github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/encryptionpb"
)

const (
SchemeLocal = "local"
SchemeAWS = "aws-kms"
SchemeAzure = "azure-kms"
SchemeGCP = "gcp-kms"

AWSVendor = "aws"
AWSRegion = "REGION"
AWSEndpoint = "ENDPOINT"
AWSAccessKeyId = "AWS_ACCESS_KEY_ID"
AWSSecretKey = "AWS_SECRET_ACCESS_KEY"

AzureVendor = "azure"
AzureTenantID = "AZURE_TENANT_ID"
AzureClientID = "AZURE_CLIENT_ID"
AzureClientSecret = "AZURE_CLIENT_SECRET"
AzureVaultName = "AZURE_VAULT_NAME"

GCPVendor = "gcp"
GCPCredentials = "CREDENTIALS"
)

var (
awsRegex = regexp.MustCompile(`^/([^/]+)$`)
azureRegex = regexp.MustCompile(`^/(.+)$`)
gcpRegex = regexp.MustCompile(`^/projects/([^/]+)/locations/([^/]+)/keyRings/([^/]+)/cryptoKeys/([^/]+)/?$`)
)

func validateAndParseMasterKeyString(keyString string) (encryptionpb.MasterKey, error) {
u, err := url.Parse(keyString)
if err != nil {
return encryptionpb.MasterKey{}, errors.Trace(err)
}

switch u.Scheme {
case SchemeLocal:
return parseLocalDiskConfig(u)
case SchemeAWS:
return parseAwsKmsConfig(u)
case SchemeAzure:
return parseAzureKmsConfig(u)
case SchemeGCP:
return parseGcpKmsConfig(u)
default:
return encryptionpb.MasterKey{}, errors.Errorf("unsupported master key type: %s", u.Scheme)
}
}

func parseLocalDiskConfig(u *url.URL) (encryptionpb.MasterKey, error) {
if !path.IsAbs(u.Path) {
return encryptionpb.MasterKey{}, errors.New("local master key path must be absolute")
}
return encryptionpb.MasterKey{
Backend: &encryptionpb.MasterKey_File{
File: &encryptionpb.MasterKeyFile{
Path: u.Path,
},
},
}, nil
}

func parseAwsKmsConfig(u *url.URL) (encryptionpb.MasterKey, error) {
matches := awsRegex.FindStringSubmatch(u.Path)
if matches == nil {
return encryptionpb.MasterKey{}, errors.New("invalid AWS KMS key ID format")
}
keyID := matches[1]

q := u.Query()
region := q.Get(AWSRegion)
accessKey := q.Get(AWSAccessKeyId)
secretAccessKey := q.Get(AWSSecretKey)

if region == "" {
return encryptionpb.MasterKey{}, errors.New("missing AWS KMS region info")
}

var awsKms *encryptionpb.AwsKms
if accessKey != "" && secretAccessKey != "" {
awsKms = &encryptionpb.AwsKms{
AccessKey: accessKey,
SecretAccessKey: secretAccessKey,
}
}

return encryptionpb.MasterKey{
Backend: &encryptionpb.MasterKey_Kms{
Kms: &encryptionpb.MasterKeyKms{
Vendor: AWSVendor,
KeyId: keyID,
Region: region,
Endpoint: q.Get(AWSEndpoint), // Optional
AwsKms: awsKms, // Optional, can read from env
},
},
}, nil
}

func parseAzureKmsConfig(u *url.URL) (encryptionpb.MasterKey, error) {
matches := azureRegex.FindStringSubmatch(u.Path)
if matches == nil {
return encryptionpb.MasterKey{}, errors.New("invalid Azure KMS path format")
}

keyID := matches[1] // This captures the entire path as the key ID
q := u.Query()

azureKms := &encryptionpb.AzureKms{
TenantId: q.Get(AzureTenantID),
ClientId: q.Get(AzureClientID),
ClientSecret: q.Get(AzureClientSecret),
KeyVaultUrl: q.Get(AzureVaultName),
}

if azureKms.TenantId == "" || azureKms.ClientId == "" || azureKms.ClientSecret == "" || azureKms.KeyVaultUrl == "" {
return encryptionpb.MasterKey{}, errors.New("missing required Azure KMS parameters")
}

return encryptionpb.MasterKey{
Backend: &encryptionpb.MasterKey_Kms{
Kms: &encryptionpb.MasterKeyKms{
Vendor: AzureVendor,
KeyId: keyID,
AzureKms: azureKms,
},
},
}, nil
}

func parseGcpKmsConfig(u *url.URL) (encryptionpb.MasterKey, error) {
matches := gcpRegex.FindStringSubmatch(u.Path)
if matches == nil {
return encryptionpb.MasterKey{}, errors.New("invalid GCP KMS path format")
}

projectID, location, keyRing, keyName := matches[1], matches[2], matches[3], matches[4]
q := u.Query()

gcpKms := &encryptionpb.GcpKms{
Credential: q.Get(GCPCredentials),
}

// credentials are required; the "Missing credentials" test below expects an error here
if gcpKms.Credential == "" {
return encryptionpb.MasterKey{}, errors.New("missing GCP KMS credentials")
}

return encryptionpb.MasterKey{
Backend: &encryptionpb.MasterKey_Kms{
Kms: &encryptionpb.MasterKeyKms{
Vendor: GCPVendor,
KeyId: fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s", projectID, location, keyRing, keyName),
GcpKms: gcpKms,
},
},
}, nil
}
192
br/pkg/task/encryption_test.go
Normal file
@ -0,0 +1,192 @@

// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.

package task

import (
"net/url"
"testing"

"github.com/pingcap/kvproto/pkg/encryptionpb"
"github.com/stretchr/testify/assert"
)

func TestParseLocalDiskConfig(t *testing.T) {
tests := []struct {
name string
input string
expected encryptionpb.MasterKey
expectError bool
}{
{
name: "Valid local path",
input: "local:///path/to/key",
expected: encryptionpb.MasterKey{Backend: &encryptionpb.MasterKey_File{File: &encryptionpb.MasterKeyFile{Path: "/path/to/key"}}},
expectError: false,
},
{
name: "Invalid local path",
input: "local://relative/path",
expectError: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u, _ := url.Parse(tt.input)
result, err := parseLocalDiskConfig(u)

if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}

func TestParseAwsKmsConfig(t *testing.T) {
tests := []struct {
name string
input string
expected encryptionpb.MasterKey
expectError bool
}{
{
name: "Valid AWS config",
input: "aws-kms:///key-id?AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE&AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY&REGION=us-west-2",
expected: encryptionpb.MasterKey{
Backend: &encryptionpb.MasterKey_Kms{
Kms: &encryptionpb.MasterKeyKms{
Vendor: "aws",
KeyId: "key-id",
Region: "us-west-2",
},
},
},
expectError: false,
},
{
name: "Missing key ID",
input: "aws-kms:///?AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE&AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY&REGION=us-west-2",
expectError: true,
},
{
name: "Missing required parameter",
input: "aws-kms:///key-id?AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE&REGION=us-west-2",
expectError: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u, _ := url.Parse(tt.input)
result, err := parseAwsKmsConfig(u)

if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}

func TestParseAzureKmsConfig(t *testing.T) {
tests := []struct {
name string
input string
expected encryptionpb.MasterKey
expectError bool
}{
{
name: "Valid Azure config",
input: "azure-kms:///key-name/key-version?AZURE_TENANT_ID=tenant-id&AZURE_CLIENT_ID=client-id&AZURE_CLIENT_SECRET=client-secret&AZURE_VAULT_NAME=vault-name",
expected: encryptionpb.MasterKey{
Backend: &encryptionpb.MasterKey_Kms{
Kms: &encryptionpb.MasterKeyKms{
Vendor: "azure",
KeyId: "key-name/key-version",
AzureKms: &encryptionpb.AzureKms{
TenantId: "tenant-id",
ClientId: "client-id",
ClientSecret: "client-secret",
KeyVaultUrl: "vault-name",
},
},
},
},
expectError: false,
},
{
name: "Missing required parameter",
input: "azure-kms:///key-name/key-version?AZURE_TENANT_ID=tenant-id&AZURE_CLIENT_ID=client-id&AZURE_VAULT_NAME=vault-name",
expectError: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u, _ := url.Parse(tt.input)
result, err := parseAzureKmsConfig(u)

if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}

func TestParseGcpKmsConfig(t *testing.T) {
tests := []struct {
name string
input string
expected encryptionpb.MasterKey
expectError bool
}{
{
name: "Valid GCP config",
input: "gcp-kms:///projects/project-id/locations/global/keyRings/ring-name/cryptoKeys/key-name?CREDENTIALS=credentials",
expected: encryptionpb.MasterKey{
Backend: &encryptionpb.MasterKey_Kms{
Kms: &encryptionpb.MasterKeyKms{
Vendor: "gcp",
KeyId: "projects/project-id/locations/global/keyRings/ring-name/cryptoKeys/key-name",
GcpKms: &encryptionpb.GcpKms{
Credential: "credentials",
},
},
},
},
expectError: false,
},
{
name: "Invalid path format",
input: "gcp-kms:///invalid/path?CREDENTIALS=credentials",
expectError: true,
},
{
name: "Missing credentials",
input: "gcp-kms:///projects/project-id/locations/global/keyRings/ring-name/cryptoKeys/key-name",
expectError: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u, _ := url.Parse(tt.input)
result, err := parseGcpKmsConfig(u)

if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}

@ -247,7 +247,8 @@ type RestoreConfig struct {
AllowPITRFromIncremental bool `json:"allow-pitr-from-incremental" toml:"allow-pitr-from-incremental"`

// [startTs, RestoreTS] is used to `restore log` from StartTS to RestoreTS.
StartTS uint64 `json:"start-ts" toml:"start-ts"`
StartTS uint64 `json:"start-ts" toml:"start-ts"`
// if not specified system will restore to the max TS available
RestoreTS uint64 `json:"restore-ts" toml:"restore-ts"`
tiflashRecorder *tiflashrec.TiFlashRecorder `json:"-" toml:"-"`
PitrBatchCount uint32 `json:"pitr-batch-count" toml:"pitr-batch-count"`
@ -695,12 +696,12 @@ func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf
if err := version.CheckClusterVersion(c, mgr.GetPDClient(), version.CheckVersionForBRPiTR); err != nil {
return errors.Trace(err)
}
restoreError = RunStreamRestore(c, mgr, g, cmdName, cfg)
restoreError = RunStreamRestore(c, mgr, g, cfg)
} else {
if err := version.CheckClusterVersion(c, mgr.GetPDClient(), version.CheckVersionForBR); err != nil {
return errors.Trace(err)
}
restoreError = runRestore(c, mgr, g, cmdName, cfg, nil)
restoreError = runSnapshotRestore(c, mgr, g, cmdName, cfg, nil)
}
if restoreError != nil {
return errors.Trace(restoreError)
@ -733,12 +734,13 @@ func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf
return nil
}

func runRestore(c context.Context, mgr *conn.Mgr, g glue.Glue, cmdName string, cfg *RestoreConfig, checkInfo *PiTRTaskInfo) error {
func runSnapshotRestore(c context.Context, mgr *conn.Mgr, g glue.Glue, cmdName string, cfg *RestoreConfig, checkInfo *PiTRTaskInfo) error {
cfg.Adjust()
defer summary.Summary(cmdName)
ctx, cancel := context.WithCancel(c)
defer cancel()

log.Info("starting snapshot restore")
if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
span1 := span.Tracer().StartSpan("task.RunRestore", opentracing.ChildOf(span.Context()))
defer span1.Finish()
@ -836,7 +838,7 @@ func runRestore(c context.Context, mgr *conn.Mgr, g glue.Glue, cmdName string, c

if client.IsIncremental() {
// don't support checkpoint for the ddl restore
log.Info("the incremental snapshot restore doesn't support checkpoint mode, so unuse checkpoint.")
log.Info("the incremental snapshot restore doesn't support checkpoint mode, disable checkpoint.")
cfg.UseCheckpoint = false
}

@ -860,7 +862,7 @@ func runRestore(c context.Context, mgr *conn.Mgr, g glue.Glue, cmdName string, c
log.Info("finish removing pd scheduler")
}()

var checkpointFirstRun bool = true
var checkpointFirstRun = true
if cfg.UseCheckpoint {
// if the checkpoint metadata exists in the checkpoint storage, the restore is not
// for the first time.

@ -36,6 +36,7 @@ import (
"github.com/pingcap/tidb/br/pkg/backup"
"github.com/pingcap/tidb/br/pkg/checkpoint"
"github.com/pingcap/tidb/br/pkg/conn"
"github.com/pingcap/tidb/br/pkg/encryption"
berrors "github.com/pingcap/tidb/br/pkg/errors"
"github.com/pingcap/tidb/br/pkg/glue"
"github.com/pingcap/tidb/br/pkg/httputil"
@ -302,8 +303,8 @@ func NewStreamMgr(ctx context.Context, cfg *StreamConfig, g glue.Glue, isStreamS
}
}()

// just stream start need Storage
s := &streamMgr{
// only stream start command needs Storage
streamManager := &streamMgr{
cfg: cfg,
mgr: mgr,
}
@ -323,12 +324,12 @@ func NewStreamMgr(ctx context.Context, cfg *StreamConfig, g glue.Glue, isStreamS
if err = client.SetStorage(ctx, backend, &opts); err != nil {
return nil, errors.Trace(err)
}
s.bc = client
streamManager.bc = client

// create http client to do some requirements check.
s.httpCli = httputil.NewClient(mgr.GetTLSConfig())
streamManager.httpCli = httputil.NewClient(mgr.GetTLSConfig())
}
return s, nil
return streamManager, nil
}

func (s *streamMgr) close() {
@ -346,7 +347,7 @@ func (s *streamMgr) setLock(ctx context.Context) error {
// adjustAndCheckStartTS checks that startTS should be smaller than currentTS,
// and endTS is larger than currentTS.
func (s *streamMgr) adjustAndCheckStartTS(ctx context.Context) error {
currentTS, err := s.mgr.GetTS(ctx)
currentTS, err := s.mgr.GetCurrentTsFromPD(ctx)
if err != nil {
return errors.Trace(err)
}
@ -499,7 +500,7 @@ func RunStreamCommand(
}

if err := commandFn(ctx, g, cmdName, cfg); err != nil {
log.Error("failed to stream", zap.String("command", cmdName), zap.Error(err))
log.Error("failed to run stream command", zap.String("command", cmdName), zap.Error(err))
summary.SetSuccessStatus(false)
summary.CollectFailureUnit(cmdName, err)
return err
@ -547,22 +548,28 @@ func RunStreamStart(
log.Warn("failed to close etcd client", zap.Error(closeErr))
}
}()

// check if any import/restore task is running, it's not allowed to start log backup
// while restore is ongoing.
if err = streamMgr.checkImportTaskRunning(ctx, cli.Client); err != nil {
return errors.Trace(err)
}

// It supports single stream log task currently.
if count, err := cli.GetTaskCount(ctx); err != nil {
return errors.Trace(err)
} else if count > 0 {
return errors.Annotate(berrors.ErrStreamLogTaskExist, "It supports single stream log task currently")
return errors.Annotate(berrors.ErrStreamLogTaskExist, "failed to start the log backup, allow only one running task")
}

exist, err := streamMgr.checkLock(ctx)
// make sure external file lock is available
locked, err := streamMgr.checkLock(ctx)
if err != nil {
return errors.Trace(err)
}
// exist is true, which represents restart a stream task. Or create a new stream task.
if exist {

// locked means this is a stream task restart. Or create a new stream task.
if locked {
logInfo, err := getLogRange(ctx, &cfg.Config)
if err != nil {
return errors.Trace(err)
@ -614,6 +621,7 @@ func RunStreamStart(
return errors.Annotate(berrors.ErrInvalidArgument, "nothing need to observe")
}

securityConfig := generateSecurityConfig(cfg)
ti := streamhelper.TaskInfo{
PBInfo: backuppb.StreamBackupTaskInfo{
Storage: streamMgr.bc.GetStorageBackend(),
@ -622,6 +630,7 @@ func RunStreamStart(
Name: cfg.TaskName,
TableFilter: cfg.FilterStr,
CompressionType: backuppb.CompressionType_ZSTD,
SecurityConfig: &securityConfig,
},
Ranges: ranges,
Pausing: false,
@ -633,6 +642,30 @@ func RunStreamStart(
return nil
}

func generateSecurityConfig(cfg *StreamConfig) backuppb.StreamBackupTaskSecurityConfig {
if len(cfg.LogBackupCipherInfo.CipherKey) > 0 && utils.IsEffectiveEncryptionMethod(cfg.LogBackupCipherInfo.CipherType) {
return backuppb.StreamBackupTaskSecurityConfig{
Encryption: &backuppb.StreamBackupTaskSecurityConfig_PlaintextDataKey{
PlaintextDataKey: &backuppb.CipherInfo{
CipherType: cfg.LogBackupCipherInfo.CipherType,
CipherKey: cfg.LogBackupCipherInfo.CipherKey,
},
},
}
}
if len(cfg.MasterKeyConfig.MasterKeys) > 0 && utils.IsEffectiveEncryptionMethod(cfg.MasterKeyConfig.EncryptionType) {
return backuppb.StreamBackupTaskSecurityConfig{
Encryption: &backuppb.StreamBackupTaskSecurityConfig_MasterKeyConfig{
MasterKeyConfig: &backuppb.MasterKeyConfig{
EncryptionType: cfg.MasterKeyConfig.EncryptionType,
MasterKeys: cfg.MasterKeyConfig.MasterKeys,
},
},
}
}
return backuppb.StreamBackupTaskSecurityConfig{}
}

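A small illustration of how the plaintext-data-key mode maps onto the task's security config (identifiers come from the function above; it assumes StreamConfig embeds the common Config shown earlier, and the CLI already rejects setting both modes at once):

	// Plaintext data key path, driven by --log.crypter.method / --log.crypter.key.
	cfg := &StreamConfig{}
	cfg.LogBackupCipherInfo = backuppb.CipherInfo{
		CipherType: encryptionpb.EncryptionMethod_AES256_CTR,
		CipherKey:  key, // 32-byte key parsed from the flags
	}
	sc := generateSecurityConfig(cfg)
	// sc now carries a StreamBackupTaskSecurityConfig_PlaintextDataKey;
	// with MasterKeyConfig populated instead, the master-key branch is taken.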
func RunStreamMetadata(
c context.Context,
g glue.Glue,
@ -995,13 +1028,13 @@ func RunStreamTruncate(c context.Context, g glue.Glue, cmdName string, cfg *Stre
}

if cfg.Until < sp {
console.Println("According to the log, you have truncated backup data before", em(formatTS(sp)))
console.Println("According to the log, you have truncated log backup data before", em(formatTS(sp)))
if !cfg.SkipPrompt && !console.PromptBool("Continue? ") {
return nil
}
}

readMetaDone := console.ShowTask("Reading Metadata... ", glue.WithTimeCost())
readMetaDone := console.ShowTask("Reading log backup metadata... ", glue.WithTimeCost())
metas := stream.StreamMetadataSet{
MetadataDownloadBatchSize: cfg.MetadataDownloadBatchSize,
Helper: stream.NewMetadataHelper(),
@ -1025,11 +1058,11 @@ func RunStreamTruncate(c context.Context, g glue.Glue, cmdName string, cfg *Stre
kvCount += d.KVCount
return
})
console.Printf("We are going to remove %s files, until %s.\n",
console.Printf("We are going to truncate %s files, up to TS %s.\n",
em(fileCount),
em(formatTS(cfg.Until)),
)
if !cfg.SkipPrompt && !console.PromptBool(warn("Sure? ")) {
if !cfg.SkipPrompt && !console.PromptBool(warn("Are you sure?")) {
return nil
}

@ -1042,7 +1075,7 @@ func RunStreamTruncate(c context.Context, g glue.Glue, cmdName string, cfg *Stre

// begin to remove
p := console.StartProgressBar(
"Clearing Data Files and Metadata", fileCount,
"Truncating Data Files and Metadata", fileCount,
glue.WithTimeCost(),
glue.WithConstExtraField("kv-count", kvCount),
glue.WithConstExtraField("kv-size", fmt.Sprintf("%d(%s)", totalSize, units.HumanSize(float64(totalSize)))),
@ -1113,7 +1146,6 @@ func RunStreamRestore(
c context.Context,
mgr *conn.Mgr,
g glue.Glue,
cmdName string,
cfg *RestoreConfig,
) (err error) {
ctx, cancelFn := context.WithCancel(c)
@ -1132,6 +1164,8 @@ func RunStreamRestore(
if err != nil {
return errors.Trace(err)
}

// if not set by user, restore to the max TS available
if cfg.RestoreTS == 0 {
cfg.RestoreTS = logInfo.logMaxTS
}
@ -1157,7 +1191,7 @@ func RunStreamRestore(
}
}

log.Info("start restore on point",
log.Info("start point in time restore",
zap.Uint64("restore-from", cfg.StartTS), zap.Uint64("restore-to", cfg.RestoreTS),
zap.Uint64("log-min-ts", logInfo.logMinTS), zap.Uint64("log-max-ts", logInfo.logMaxTS))
if err := checkLogRange(cfg.StartTS, cfg.RestoreTS, logInfo.logMinTS, logInfo.logMaxTS); err != nil {
@ -1180,7 +1214,7 @@ func RunStreamRestore(
logStorage := cfg.Config.Storage
cfg.Config.Storage = cfg.FullBackupStorage
// TiFlash replica is restored to down-stream on 'pitr' currently.
if err = runRestore(ctx, mgr, g, FullRestoreCmd, cfg, checkInfo); err != nil {
if err = runSnapshotRestore(ctx, mgr, g, FullRestoreCmd, cfg, checkInfo); err != nil {
return errors.Trace(err)
}
cfg.Config.Storage = logStorage
@ -1191,9 +1225,6 @@ func RunStreamRestore(
}
if checkInfo.CheckpointInfo != nil && checkInfo.CheckpointInfo.Metadata != nil && checkInfo.CheckpointInfo.Metadata.TiFlashItems != nil {
log.Info("load tiflash records of snapshot restore from checkpoint")
if err != nil {
return errors.Trace(err)
}
cfg.tiflashRecorder.Load(checkInfo.CheckpointInfo.Metadata.TiFlashItems)
}
}
@ -1329,7 +1360,12 @@ func restoreStream(
}()
}

err = client.InstallLogFileManager(ctx, cfg.StartTS, cfg.RestoreTS, cfg.MetadataDownloadBatchSize)
encryptionManager, err := encryption.NewManager(&cfg.LogBackupCipherInfo, &cfg.MasterKeyConfig)
if err != nil {
return errors.Annotate(err, "failed to create encryption manager for log restore")
}
defer encryptionManager.Close()
err = client.InstallLogFileManager(ctx, cfg.StartTS, cfg.RestoreTS, cfg.MetadataDownloadBatchSize, encryptionManager)
if err != nil {
return err
}
@ -1345,12 +1381,13 @@ func restoreStream(
newTask = false
}
// get the schemas ID replace information.
// since this targets the full backup storage, the full backup cipher is needed
schemasReplace, err := client.InitSchemasReplaceForDDL(ctx, &logclient.InitSchemaConfig{
IsNewTask: newTask,
|
||||
TableFilter: cfg.TableFilter,
|
||||
TiFlashRecorder: cfg.tiflashRecorder,
|
||||
FullBackupStorage: fullBackupStorage,
|
||||
})
|
||||
}, &cfg.Config.CipherInfo)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
@ -1431,7 +1468,8 @@ func restoreStream(
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
return client.RestoreKVFiles(ctx, rewriteRules, idrules, logFilesIterWithSplit, checkpointRunner, cfg.PitrBatchCount, cfg.PitrBatchSize, updateStats, p.IncBy)
|
||||
return client.RestoreKVFiles(ctx, rewriteRules, idrules, logFilesIterWithSplit, checkpointRunner,
|
||||
cfg.PitrBatchCount, cfg.PitrBatchSize, updateStats, p.IncBy, &cfg.LogBackupCipherInfo, cfg.MasterKeyConfig.MasterKeys)
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Annotate(err, "failed to restore kv files")
|
||||
@ -1547,15 +1585,15 @@ func getExternalStorageOptions(cfg *Config, u *backuppb.StorageBackend) storage.
|
||||
}
|
||||
}
|
||||
|
||||
func checkLogRange(restoreFrom, restoreTo, logMinTS, logMaxTS uint64) error {
|
||||
// serveral ts constraint:
|
||||
// logMinTS <= restoreFrom <= restoreTo <= logMaxTS
|
||||
if logMinTS > restoreFrom || restoreFrom > restoreTo || restoreTo > logMaxTS {
|
||||
func checkLogRange(restoreFromTS, restoreToTS, logMinTS, logMaxTS uint64) error {
|
||||
// several ts constraint:
|
||||
// logMinTS <= restoreFromTS <= restoreToTS <= logMaxTS
|
||||
if logMinTS > restoreFromTS || restoreFromTS > restoreToTS || restoreToTS > logMaxTS {
|
||||
return errors.Annotatef(berrors.ErrInvalidArgument,
|
||||
"restore log from %d(%s) to %d(%s), "+
|
||||
" but the current existed log from %d(%s) to %d(%s)",
|
||||
restoreFrom, oracle.GetTimeFromTS(restoreFrom),
|
||||
restoreTo, oracle.GetTimeFromTS(restoreTo),
|
||||
restoreFromTS, oracle.GetTimeFromTS(restoreFromTS),
|
||||
restoreToTS, oracle.GetTimeFromTS(restoreToTS),
|
||||
logMinTS, oracle.GetTimeFromTS(logMinTS),
|
||||
logMaxTS, oracle.GetTimeFromTS(logMaxTS),
|
||||
)
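
A quick numeric illustration of the constraint chain, using hypothetical values with logs covering [100, 200]:

_ = checkLogRange(100, 200, 100, 200) // nil: exactly the available range
_ = checkLogRange(90, 150, 100, 200)  // error: restoreFromTS precedes logMinTS
_ = checkLogRange(150, 120, 100, 200) // error: restoreFromTS exceeds restoreToTS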

@ -1648,7 +1686,7 @@ func getGlobalCheckpointFromStorage(ctx context.Context, s storage.ExternalStora
return globalCheckPointTS, errors.Trace(err)
}

// getFullBackupTS gets the snapshot-ts of full bakcup
// getFullBackupTS gets the snapshot-ts of full backup
func getFullBackupTS(
ctx context.Context,
cfg *RestoreConfig,
@ -1663,11 +1701,17 @@ func getFullBackupTS(
return 0, 0, errors.Trace(err)
}

backupmeta := &backuppb.BackupMeta{}
if err = backupmeta.Unmarshal(metaData); err != nil {
decryptedMetaData, err := metautil.DecryptFullBackupMetaIfNeeded(metaData, &cfg.CipherInfo)
if err != nil {
return 0, 0, errors.Trace(err)
}

backupmeta := &backuppb.BackupMeta{}
if err = backupmeta.Unmarshal(decryptedMetaData); err != nil {
return 0, 0, errors.Trace(err)
}

// start and end are identical in full backup, pick random one
return backupmeta.GetEndVersion(), backupmeta.GetClusterId(), nil
}
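
The substantive change above is the decrypt step in front of Unmarshal. The diff does not show DecryptFullBackupMetaIfNeeded itself; a hedged sketch of one plausible shape, assuming the encrypted backupmeta carries its IV as a fixed-length prefix (an assumption, not confirmed by this diff; the committed implementation lives in metautil):

// Hedged sketch only, not the committed code.
func decryptFullBackupMetaIfNeeded(meta []byte, cipher *backuppb.CipherInfo) ([]byte, error) {
	if cipher == nil || !utils.IsEffectiveEncryptionMethod(cipher.CipherType) {
		return meta, nil // plaintext meta passes through untouched
	}
	const assumedIvLen = 16 // hypothetical IV-prefix layout
	if len(meta) < assumedIvLen {
		return nil, errors.New("encrypted backupmeta shorter than its IV")
	}
	return utils.Decrypt(meta[assumedIvLen:], cipher, meta[:assumedIvLen])
}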

@ -1757,7 +1801,7 @@ func checkPiTRTaskInfo(
cfg *RestoreConfig,
) (*PiTRTaskInfo, error) {
var (
doFullRestore = (len(cfg.FullBackupStorage) > 0)
doFullRestore = len(cfg.FullBackupStorage) > 0
curTaskInfo *checkpoint.CheckpointTaskInfoForLogRestore
)
checkInfo := &PiTRTaskInfo{}

@ -110,35 +110,6 @@ func TestCheckLogRange(t *testing.T) {
}
}

type fakeResolvedInfo struct {
storeID int64
resolvedTS uint64
}

func fakeMetaFiles(ctx context.Context, tempDir string, infos []fakeResolvedInfo) error {
backupMetaDir := filepath.Join(tempDir, stream.GetStreamBackupMetaPrefix())
s, err := storage.NewLocalStorage(backupMetaDir)
if err != nil {
return errors.Trace(err)
}

for _, info := range infos {
meta := &backuppb.Metadata{
StoreId: info.storeID,
ResolvedTs: info.resolvedTS,
}
buff, err := meta.Marshal()
if err != nil {
return errors.Trace(err)
}
filename := fmt.Sprintf("%d_%d.meta", info.storeID, info.resolvedTS)
if err = s.WriteFile(ctx, filename, buff); err != nil {
return errors.Trace(err)
}
}
return nil
}

func fakeCheckpointFiles(
ctx context.Context,
tmpDir string,
@ -154,7 +125,7 @@ func fakeCheckpointFiles(
for _, info := range infos {
filename := fmt.Sprintf("%v.ts", info.storeID)
buff := make([]byte, 8)
binary.LittleEndian.PutUint64(buff, info.global_checkpoint)
binary.LittleEndian.PutUint64(buff, info.globalCheckpoint)
if _, err := s.Create(ctx, filename, nil); err != nil {
return errors.Trace(err)
}
@ -170,8 +141,8 @@ func fakeCheckpointFiles(
}

type fakeGlobalCheckPoint struct {
storeID int64
global_checkpoint uint64
storeID int64
globalCheckpoint uint64
}

func TestGetGlobalCheckpointFromStorage(t *testing.T) {
@ -182,16 +153,16 @@ func TestGetGlobalCheckpointFromStorage(t *testing.T) {

infos := []fakeGlobalCheckPoint{
{
storeID: 1,
global_checkpoint: 98,
storeID: 1,
globalCheckpoint: 98,
},
{
storeID: 2,
global_checkpoint: 90,
storeID: 2,
globalCheckpoint: 90,
},
{
storeID: 2,
global_checkpoint: 99,
storeID: 2,
globalCheckpoint: 99,
},
}

@ -7,6 +7,7 @@ go_library(
"db.go",
"dyn_pprof_other.go",
"dyn_pprof_unix.go",
"encryption.go",
"error_handling.go",
"json.go",
"key.go",
@ -35,6 +36,7 @@ go_library(
"//pkg/parser/types",
"//pkg/sessionctx",
"//pkg/util",
"//pkg/util/encrypt",
"//pkg/util/logutil",
"//pkg/util/sqlexec",
"@com_github_cheggaaa_pb_v3//:pb",
@ -43,6 +45,7 @@ go_library(
"@com_github_pingcap_errors//:errors",
"@com_github_pingcap_failpoint//:failpoint",
"@com_github_pingcap_kvproto//pkg/brpb",
"@com_github_pingcap_kvproto//pkg/encryptionpb",
"@com_github_pingcap_kvproto//pkg/metapb",
"@com_github_pingcap_log//:log",
"@com_github_tikv_client_go_v2//oracle",

30
br/pkg/utils/encryption.go
Normal file
@ -0,0 +1,30 @@
package utils

import (
"github.com/pingcap/errors"
backuppb "github.com/pingcap/kvproto/pkg/brpb"
"github.com/pingcap/kvproto/pkg/encryptionpb"
berrors "github.com/pingcap/tidb/br/pkg/errors"
"github.com/pingcap/tidb/pkg/util/encrypt"
)

func Decrypt(content []byte, cipher *backuppb.CipherInfo, iv []byte) ([]byte, error) {
if len(content) == 0 || cipher == nil {
return content, nil
}

switch cipher.CipherType {
case encryptionpb.EncryptionMethod_PLAINTEXT:
return content, nil
case encryptionpb.EncryptionMethod_AES128_CTR,
encryptionpb.EncryptionMethod_AES192_CTR,
encryptionpb.EncryptionMethod_AES256_CTR:
return encrypt.AESDecryptWithCTR(content, cipher.CipherKey, iv)
default:
return content, errors.Annotatef(berrors.ErrInvalidArgument, "cipher type invalid %s", cipher.CipherType)
}
}

func IsEffectiveEncryptionMethod(method encryptionpb.EncryptionMethod) bool {
return method != encryptionpb.EncryptionMethod_UNKNOWN && method != encryptionpb.EncryptionMethod_PLAINTEXT
}
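
A minimal round-trip sketch for the new helper. It assumes pkg/util/encrypt also exposes an AESEncryptWithCTR counterpart, which this diff does not show, and that key is a 32-byte AES-256 key already in scope:

// Hedged usage sketch for utils.Decrypt (fragment, not from this commit).
cipher := &backuppb.CipherInfo{
	CipherType: encryptionpb.EncryptionMethod_AES256_CTR,
	CipherKey:  key, // 32-byte AES-256 key
}
iv := make([]byte, 16) // CTR IV; fill from crypto/rand in real use
ciphertext, _ := encrypt.AESEncryptWithCTR(plaintext, cipher.CipherKey, iv) // assumed counterpart
decrypted, err := utils.Decrypt(ciphertext, cipher, iv)
// on success, bytes.Equal(plaintext, decrypted) holds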

@ -33,7 +33,7 @@ This folder contains all tests which relies on external processes such as TiDB.

## Preparations

1. The following 9 executables must be copied or linked into these locations:
1. The following 9 executables must be copied or linked into the `bin` folder under the TiDB root dir:

* `bin/tidb-server`
* `bin/tikv-server`
@ -80,7 +80,7 @@ If you have docker installed, you can skip step 1 and step 2 by running
1. Build `br.test` using `make build_for_br_integration_test`
2. Check that all 9 required executables and `br` executable exist
3. Select the tests to run using `export TEST_NAME="<test_name1> <test_name2> ..."`
3. Execute `tests/run.sh`
4. Execute `br/tests/run.sh`
<!-- 4. To start cluster with tiflash, please run `TIFLASH=1 tests/run.sh` -->

If the first two steps are done before, you could also run `tests/run.sh` directly.

435
br/tests/br_encryption/run.sh
Executable file
@ -0,0 +1,435 @@
#!/bin/sh
#
# Copyright 2024 PingCAP, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -eu
. run_services
CUR=$(cd "$(dirname "$0")" && pwd)

# const value
PREFIX="encryption_backup"
res_file="$TEST_DIR/sql_res.$TEST_NAME.txt"
DB="$TEST_NAME"
TABLE="usertable"
DB_COUNT=3
TASK_NAME="encryption_test"

create_db_with_table() {
for i in $(seq $DB_COUNT); do
run_sql "CREATE DATABASE $DB${i};"
go-ycsb load mysql -P $CUR/workload -p mysql.host=$TIDB_IP -p mysql.port=$TIDB_PORT -p mysql.user=root -p mysql.db=$DB${i} -p recordcount=1000
done
}

start_log_backup() {
_storage=$1
_encryption_args=$2
echo "start log backup task"
run_br --pd "$PD_ADDR" log start --task-name $TASK_NAME -s "$_storage" $_encryption_args
}

drop_db() {
for i in $(seq $DB_COUNT); do
run_sql "DROP DATABASE IF EXISTS $DB${i};"
done
}

insert_additional_data() {
local prefix=$1
for i in $(seq $DB_COUNT); do
go-ycsb load mysql -P $CUR/workload -p mysql.host=$TIDB_IP -p mysql.port=$TIDB_PORT -p mysql.user=root -p mysql.db=$DB${i} -p insertcount=1000 -p insertstart=1000000 -p recordcount=1001000 -p workload=core
done
}

wait_log_checkpoint_advance() {
echo "wait for log checkpoint to advance"
sleep 10
local current_ts=$(python3 -c "import time; print(int(time.time() * 1000) << 18)")
echo "current ts: $current_ts"
i=0
while true; do
# extract the checkpoint ts of the log backup task. If there is some error, the checkpoint ts should be empty
log_backup_status=$(unset BR_LOG_TO_TERM && run_br --skip-goleak --pd $PD_ADDR log status --task-name $TASK_NAME --json 2>br.log)
echo "log backup status: $log_backup_status"
local checkpoint_ts=$(echo "$log_backup_status" | head -n 1 | jq 'if .[0].last_errors | length == 0 then .[0].checkpoint else empty end')
echo "checkpoint ts: $checkpoint_ts"

# check whether the checkpoint ts is a number
if [ $checkpoint_ts -gt 0 ] 2>/dev/null; then
if [ $checkpoint_ts -gt $current_ts ]; then
echo "the checkpoint has advanced"
break
fi
echo "the checkpoint hasn't advanced"
i=$((i+1))
if [ "$i" -gt 50 ]; then
echo 'the checkpoint lag is too large'
exit 1
fi
sleep 10
else
echo "TEST: [$TEST_NAME] failed to wait checkpoint advance!"
exit 1
fi
done
}
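
The current_ts computed above is a TiDB TSO: physical milliseconds shifted left by 18 bits, leaving the low bits for the logical counter. The same conversion in Go, for reference (a hypothetical helper, not part of this commit):

package tsoutil // hypothetical helper package

import "time"

// TSO layout: physical time in milliseconds << 18 | logical counter.
func TSOFromTime(t time.Time) uint64 { return uint64(t.UnixMilli()) << 18 }

// TimeFromTSO recovers wall time from a TSO by dropping the logical bits.
func TimeFromTSO(ts uint64) time.Time { return time.UnixMilli(int64(ts >> 18)) }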

calculate_checksum() {
local db=$1
local checksum=$(run_sql "USE $db; ADMIN CHECKSUM TABLE $TABLE;" | awk '/CHECKSUM/{print $2}')
echo $checksum
}

check_db_consistency() {
fail=false
for i in $(seq $DB_COUNT); do
local original_checksum=${checksum_ori[i]}
local new_checksum=$(calculate_checksum "$DB${i}")

if [ "$original_checksum" != "$new_checksum" ]; then
fail=true
echo "TEST: [$TEST_NAME] checksum mismatch on database $DB${i}"
echo "Original checksum: $original_checksum, New checksum: $new_checksum"
else
echo "Database $DB${i} checksum match: $new_checksum"
fi
done

if $fail; then
echo "TEST: [$TEST_NAME] data consistency check failed!"
return 1
fi
echo "TEST: [$TEST_NAME] data consistency check passed."
return 0
}

verify_dbs_empty() {
echo "Verifying databases are empty"
for i in $(seq $DB_COUNT); do
db_name="$DB${i}"
table_count=$(run_sql "USE $db_name; SHOW TABLES;" | wc -l)
if [ "$table_count" -ne 0 ]; then
echo "ERROR: Database $db_name is not empty"
return 1
fi
done
echo "All databases are empty"
return 0
}

run_backup_restore_test() {
local encryption_mode=$1
local full_encryption_args=$2
local log_encryption_args=$3

echo "===== run_backup_restore_test $encryption_mode $full_encryption_args $log_encryption_args ====="

restart_services || { echo "Failed to restart services"; exit 1; }

# Drop existing databases before starting the test
drop_db || { echo "Failed to drop databases"; exit 1; }

# Start log backup
start_log_backup "local://$TEST_DIR/$PREFIX/log" "$log_encryption_args" || { echo "Failed to start log backup"; exit 1; }

# Create test databases and insert initial data
create_db_with_table || { echo "Failed to create databases and tables"; exit 1; }

# Calculate and store original checksums
for i in $(seq $DB_COUNT); do
checksum_ori[${i}]=$(calculate_checksum "$DB${i}") || { echo "Failed to calculate initial checksum"; exit 1; }
done

# Full backup
echo "run full backup with $encryption_mode"
run_br --pd "$PD_ADDR" backup full -s "local://$TEST_DIR/$PREFIX/full" $full_encryption_args || { echo "Full backup failed"; exit 1; }

# Insert additional test data
insert_additional_data "${encryption_mode}_after_full_backup" || { echo "Failed to insert additional data"; exit 1; }

# Update checksums after inserting additional data
for i in $(seq $DB_COUNT); do
checksum_ori[${i}]=$(calculate_checksum "$DB${i}") || { echo "Failed to calculate checksum after insertion"; exit 1; }
done

wait_log_checkpoint_advance || { echo "Failed to wait for log checkpoint"; exit 1; }

# sanity check pause still works
run_br log pause --task-name $TASK_NAME --pd $PD_ADDR || { echo "Failed to pause log backup"; exit 1; }

# sanity check resume still works
run_br log resume --task-name $TASK_NAME --pd $PD_ADDR || { echo "Failed to resume log backup"; exit 1; }

# sanity check stop still works
run_br log stop --task-name $TASK_NAME --pd $PD_ADDR || { echo "Failed to stop log backup"; exit 1; }

# restart service should clean up everything
restart_services || { echo "Failed to restart services"; exit 1; }

verify_dbs_empty || { echo "Failed to verify databases are empty"; exit 1; }

# Run pitr restore and measure the performance
echo "restore log backup with $full_encryption_args and $log_encryption_args"
local start_time=$(date +%s.%N)
timeout 300 run_br --pd "$PD_ADDR" restore point -s "local://$TEST_DIR/$PREFIX/log" --full-backup-storage "local://$TEST_DIR/$PREFIX/full" $full_encryption_args $log_encryption_args || {
echo "Log backup restore failed or timed out after 5 minutes"
exit 1
}
local end_time=$(date +%s.%N)
local duration=$(echo "$end_time - $start_time" | bc | awk '{printf "%.3f", $0}')
echo "${encryption_mode} took ${duration} seconds"
echo "${encryption_mode},${duration}" >> "$TEST_DIR/performance_results.csv"

# Check data consistency after restore
echo "check data consistency after restore"
check_db_consistency || { echo "TEST: [$TEST_NAME] $encryption_mode backup and restore (including log) failed"; exit 1; }

# sanity check truncate still works
# make sure some files exist in log dir
log_dir="$TEST_DIR/$PREFIX/log"
if [ -z "$(ls -A $log_dir)" ]; then
echo "Error: No files found in the log directory $log_dir"
exit 1
else
echo "Files exist in the log directory $log_dir"
fi
current_time=$(date -u +"%Y-%m-%d %H:%M:%S+0000")
run_br log truncate -s "local://$TEST_DIR/$PREFIX/log" --until "$current_time" -y || { echo "Failed to truncate log backup"; exit 1; }
# make sure no files remain in log dir after truncation
if [ -n "$(ls -A $log_dir)" ]; then
echo "Error: Files still exist in the log directory $log_dir"
exit 1
else
echo "No files exist in the log directory $log_dir"
fi

# Clean up after the test
drop_db || { echo "Failed to drop databases after test"; exit 1; }
rm -rf "$TEST_DIR/$PREFIX"

echo "TEST: [$TEST_NAME] $encryption_mode backup and restore (including log) passed"
}

start_and_wait_for_localstack() {
# Start LocalStack in the background with only the required services
SERVICES=s3,ec2,kms localstack start -d

echo "Waiting for LocalStack services to be ready..."
max_attempts=30
attempt=0
while [ $attempt -lt $max_attempts ]; do
response=$(curl -s "http://localhost:4566/_localstack/health")
if echo "$response" | jq -e '.services.s3 == "running" or .services.s3 == "available"' > /dev/null && \
echo "$response" | jq -e '.services.ec2 == "running" or .services.ec2 == "available"' > /dev/null && \
echo "$response" | jq -e '.services.kms == "running" or .services.kms == "available"' > /dev/null; then
echo "LocalStack services are ready"
return 0
fi
attempt=$((attempt+1))
echo "Waiting for LocalStack services... Attempt $attempt of $max_attempts"
sleep 2
done

echo "LocalStack services did not become ready in time"
localstack stop
return 1
}

test_backup_encrypted_restore_unencrypted() {
echo "===== Testing backup with encryption, restore without encryption ====="

restart_services || { echo "Failed to restart services"; exit 1; }

# Start log backup
start_log_backup "local://$TEST_DIR/$PREFIX/log" "--log.crypter.method AES256-CTR --log.crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" || { echo "Failed to start log backup"; exit 1; }

# Create test databases and insert initial data
create_db_with_table || { echo "Failed to create databases and tables"; exit 1; }

# Backup with encryption
run_br --pd $PD_ADDR backup full -s "local://$TEST_DIR/$PREFIX/full" --crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef

# Insert additional test data
insert_additional_data "insert_after_full_backup" || { echo "Failed to insert additional data"; exit 1; }

wait_log_checkpoint_advance || { echo "Failed to wait for log checkpoint"; exit 1; }

# Stop and clean the cluster
restart_services || { echo "Failed to restart services"; exit 1; }

# Try to restore without the log encryption key (this should fail)
if run_br --pd "$PD_ADDR" restore point -s "local://$TEST_DIR/$PREFIX/log" --full-backup-storage "local://$TEST_DIR/$PREFIX/full" --crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef; then
echo "Error: Restore without encryption should have failed, but it succeeded"
exit 1
else
echo "Restore without encryption failed as expected"
fi

# Clean up after the test
drop_db || { echo "Failed to drop databases after test"; exit 1; }
rm -rf "$TEST_DIR/$PREFIX"

echo "TEST: test_backup_encrypted_restore_unencrypted passed"
}

test_plaintext() {
run_backup_restore_test "plaintext" "" ""
}

test_plaintext_data_key() {
run_backup_restore_test "plaintext-data-key" "--crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" "--log.crypter.method AES256-CTR --log.crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
}

test_local_master_key() {
_MASTER_KEY_DIR="$TEST_DIR/$PREFIX/master_key"
mkdir -p "$_MASTER_KEY_DIR"
openssl rand -hex 32 > "$_MASTER_KEY_DIR/master.key"

_MASTER_KEY_PATH="local:///$_MASTER_KEY_DIR/master.key"

run_backup_restore_test "local_master_key" "" "--master-key-crypter-method AES256-CTR --master-key $_MASTER_KEY_PATH"

rm -rf "$_MASTER_KEY_DIR"
}

test_aws_kms() {
# Start LocalStack and wait for services to be ready
if ! start_and_wait_for_localstack; then
echo "Failed to start LocalStack services"
exit 1
fi

# localstack listening port
ENDPOINT="http://localhost:4566"

# Create KMS key using curl
KMS_RESPONSE=$(curl -X POST "$ENDPOINT/kms/" \
-H "Content-Type: application/x-amz-json-1.1" \
-H "X-Amz-Target: TrentService.CreateKey" \
-d '{
"Description": "My test key",
"KeyUsage": "ENCRYPT_DECRYPT",
"Origin": "AWS_KMS"
}')

echo "KMS CreateKey response: $KMS_RESPONSE"

AWS_KMS_KEY_ID=$(echo "$KMS_RESPONSE" | jq -r '.KeyMetadata.KeyId')
AWS_ACCESS_KEY_ID="TEST"
AWS_SECRET_ACCESS_KEY="TEST"
REGION="us-east-1"

AWS_KMS_URI="aws-kms:///${AWS_KMS_KEY_ID}?AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}&AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}&REGION=${REGION}&ENDPOINT=${ENDPOINT}"

run_backup_restore_test "aws_kms" "--crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" "--master-key-crypter-method AES256-CTR --master-key $AWS_KMS_URI"

# Stop LocalStack
localstack stop
}
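
The master-key URI assembled above carries the KMS key ID in the path and the credentials, region, and endpoint as query parameters. A small Go helper in the same shape (the function name is mine, mirroring the script, not an API from this repo; values are assumed to already be URL-safe, as they are in this test):

package kmsuri // hypothetical

import "fmt"

// awsKMSURI assembles the aws-kms master-key URI format used in this test.
func awsKMSURI(keyID, accessKey, secretKey, region, endpoint string) string {
	return fmt.Sprintf(
		"aws-kms:///%s?AWS_ACCESS_KEY_ID=%s&AWS_SECRET_ACCESS_KEY=%s&REGION=%s&ENDPOINT=%s",
		keyID, accessKey, secretKey, region, endpoint)
}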

test_aws_kms_with_iam() {
# Start LocalStack and wait for services to be ready
if ! start_and_wait_for_localstack; then
echo "Failed to start LocalStack services"
exit 1
fi

# localstack listening port
ENDPOINT="http://localhost:4566"

# Create KMS key using curl
KMS_RESPONSE=$(curl -X POST "$ENDPOINT/kms/" \
-H "Content-Type: application/x-amz-json-1.1" \
-H "X-Amz-Target: TrentService.CreateKey" \
-d '{
"Description": "My test key",
"KeyUsage": "ENCRYPT_DECRYPT",
"Origin": "AWS_KMS"
}')

echo "KMS CreateKey response: $KMS_RESPONSE"

AWS_KMS_KEY_ID=$(echo "$KMS_RESPONSE" | jq -r '.KeyMetadata.KeyId')
# export these two as required by aws-kms backend
export AWS_ACCESS_KEY_ID="TEST"
export AWS_SECRET_ACCESS_KEY="TEST"
REGION="us-east-1"

AWS_KMS_URI="aws-kms:///${AWS_KMS_KEY_ID}?REGION=${REGION}&ENDPOINT=${ENDPOINT}"

run_backup_restore_test "aws_kms" "--crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" "--master-key-crypter-method AES256-CTR --master-key $AWS_KMS_URI"

# Stop LocalStack
localstack stop
}

test_gcp_kms() {
# Ensure GCP credentials are set
if [ -z "$GOOGLE_APPLICATION_CREDENTIALS" ]; then
echo "GCP credentials not set. Skipping GCP KMS test."
return
fi

# Replace these with your actual GCP KMS details
GCP_PROJECT_ID="carbide-network-435219-q3"
GCP_LOCATION="us-west1"
GCP_KEY_RING="local-kms-testing"
GCP_KEY_NAME="kms-testing-key"
GCP_CREDENTIALS="$GOOGLE_APPLICATION_CREDENTIALS"

GCP_KMS_URI="gcp-kms:///projects/$GCP_PROJECT_ID/locations/$GCP_LOCATION/keyRings/$GCP_KEY_RING/cryptoKeys/$GCP_KEY_NAME?AUTH=specified&CREDENTIALS=$GCP_CREDENTIALS"

run_backup_restore_test "gcp_kms" "--crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" "--master-key-crypter-method AES256-CTR --master-key $GCP_KMS_URI"
}

test_mixed_full_encrypted_log_plain() {
local full_encryption_args="--crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
local log_encryption_args=""

run_backup_restore_test "mixed_full_encrypted_log_plain" "$full_encryption_args" "$log_encryption_args"
}

test_mixed_full_plain_log_encrypted() {
local full_encryption_args=""
local log_encryption_args="--log.crypter.method AES256-CTR --log.crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"

run_backup_restore_test "mixed_full_plain_log_encrypted" "$full_encryption_args" "$log_encryption_args"
}

# Initialize performance results file
echo "Operation,Encryption Mode,Duration (seconds)" > "$TEST_DIR/performance_results.csv"

# Run tests
test_backup_encrypted_restore_unencrypted
test_plaintext
test_plaintext_data_key
test_local_master_key
# some issue running in CI, will fix later
#test_aws_kms
#test_aws_kms_with_iam
test_mixed_full_encrypted_log_plain
test_mixed_full_plain_log_encrypted

# uncomment for manual GCP KMS testing
#test_gcp_kms

echo "All encryption tests passed successfully"

# Display performance results
echo "Performance Results:"
cat "$TEST_DIR/performance_results.csv"

12
br/tests/br_encryption/workload
Normal file
@ -0,0 +1,12 @@
recordcount=1000
operationcount=0
workload=core

readallfields=true

readproportion=0
updateproportion=0
scanproportion=0
insertproportion=0

requestdistribution=uniform

@ -36,3 +36,6 @@ path = "/tmp/backup_restore_test/master-key-file"

[log-backup]
max-flush-interval = "50s"
[gc]
ratio-threshold = 1.1

@ -43,6 +43,8 @@ minio_cli_url="${file_server_url}/download/builds/minio/minio/RELEASE.2020-02-27
kes_url="${file_server_url}/download/kes"
fake_gcs_server_url="${file_server_url}/download/builds/fake-gcs-server"
brv_url="${file_server_url}/download/builds/brv4.0.8"
# already manually uploaded to file server for localstack
localstack_url="${file_server_url}/download/localstack-cli.tar.gz"

set -o nounset

@ -103,6 +105,13 @@ function main() {
download "$fake_gcs_server_url" "fake-gcs-server" "third_bin/fake-gcs-server"
download "$brv_url" "brv4.0.8" "third_bin/brv4.0.8"

# Download and set up LocalStack
download "$localstack_url" "localstack-cli.tar.gz" "tmp/localstack-cli.tar.gz"
mkdir -p tmp/localstack_extract
tar -xzf tmp/localstack-cli.tar.gz -C tmp/localstack_extract
mv tmp/localstack_extract/localstack/* third_bin/
rm -rf tmp/localstack_extract

chmod +x third_bin/*
rm -rf tmp
rm -rf third_bin/bin

@ -26,7 +26,7 @@ groups=(
["G03"]='br_incompatible_tidb_config br_incremental br_incremental_index br_incremental_only_ddl br_incremental_same_table br_insert_after_restore br_key_locked br_log_test br_move_backup br_mv_index br_other br_partition_add_index br_tidb_placement_policy br_tiflash br_tiflash_conflict'
["G04"]='br_range br_replica_read br_restore_TDE_enable br_restore_log_task_enable br_s3 br_shuffle_leader br_shuffle_region br_single_table'
["G05"]='br_skip_checksum br_split_region_fail br_systables br_table_filter br_txn br_stats br_clustered_index br_crypter'
["G06"]='br_tikv_outage br_tikv_outage3 br_restore_checkpoint'
["G06"]='br_tikv_outage br_tikv_outage3 br_restore_checkpoint br_encryption'
["G07"]='br_pitr'
["G08"]='br_tikv_outage2 br_ttl br_views_and_sequences br_z_gc_safepoint br_autorandom br_file_corruption'
)

@ -1,9 +1,9 @@
package(default_visibility = ["//visibility:public"])

load("@io_bazel_rules_go//go:def.bzl", "go_library", "nogo")
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
load("@io_bazel_rules_go//go:def.bzl", "go_library", "nogo")
load("//build/linter/staticcheck:def.bzl", "staticcheck_analyzers")

package(default_visibility = ["//visibility:public"])

bool_flag(
name = "with_nogo_flag",
build_setting_default = False,

3
go.mod
@ -3,6 +3,7 @@ module github.com/pingcap/tidb
go 1.21

require (
cloud.google.com/go/kms v1.15.7
cloud.google.com/go/storage v1.38.0
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.12.0
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0
@ -305,7 +306,7 @@ require (
google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 // indirect
google.golang.org/protobuf v1.34.2 // indirect
google.golang.org/protobuf v1.34.2
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

@ -47,8 +47,10 @@ stop() {
}

restart_services() {
echo "Restarting services"
stop_services
start_services
echo "Services restarted"
}

stop_services() {