From 7c88876a7fd00a76b60d276d8b750b749d04bb68 Mon Sep 17 00:00:00 2001 From: Wenqi Mou Date: Wed, 25 Sep 2024 22:31:09 -0400 Subject: [PATCH] br: add log backup/restore encryption support (#55757) close pingcap/tidb#55834 --- br/cmd/br/restore.go | 4 +- br/pkg/checkpoint/checkpoint.go | 3 +- br/pkg/conn/conn.go | 4 +- br/pkg/encryption/BUILD.bazel | 15 + br/pkg/encryption/manager.go | 93 ++++ br/pkg/encryption/master_key/BUILD.bazel | 44 ++ br/pkg/encryption/master_key/common.go | 60 +++ br/pkg/encryption/master_key/file_backend.go | 68 +++ .../master_key/file_backend_test.go | 105 +++++ br/pkg/encryption/master_key/kms_backend.go | 88 ++++ .../encryption/master_key/kms_backend_test.go | 113 +++++ br/pkg/encryption/master_key/master_key.go | 77 ++++ br/pkg/encryption/master_key/mem_backend.go | 103 +++++ .../encryption/master_key/mem_backend_test.go | 115 +++++ .../master_key/multi_master_key_backend.go | 64 +++ .../multi_master_key_backend_test.go | 105 +++++ br/pkg/kms/BUILD.bazel | 27 ++ br/pkg/kms/aws.go | 92 ++++ br/pkg/kms/common.go | 65 +++ br/pkg/kms/gcp.go | 104 +++++ br/pkg/kms/kms.go | 13 + br/pkg/metautil/BUILD.bazel | 2 + br/pkg/metautil/metafile.go | 26 +- br/pkg/metautil/metafile_test.go | 7 +- br/pkg/metautil/statsfile.go | 3 +- br/pkg/restore/log_client/BUILD.bazel | 3 + br/pkg/restore/log_client/client.go | 34 +- br/pkg/restore/log_client/client_test.go | 10 +- br/pkg/restore/log_client/export_test.go | 2 + br/pkg/restore/log_client/import.go | 34 +- br/pkg/restore/log_client/import_test.go | 5 +- br/pkg/restore/log_client/log_file_manager.go | 9 +- br/pkg/restore/snap_client/client.go | 2 +- br/pkg/stream/BUILD.bazel | 2 + br/pkg/stream/stream_mgr.go | 72 ++- br/pkg/stream/stream_misc_test.go | 6 +- br/pkg/task/BUILD.bazel | 5 +- br/pkg/task/backup.go | 1 - br/pkg/task/common.go | 159 ++++++- br/pkg/task/common_test.go | 130 ++++++ br/pkg/task/encryption.go | 166 +++++++ br/pkg/task/encryption_test.go | 192 ++++++++ br/pkg/task/restore.go | 14 +- br/pkg/task/stream.go | 114 +++-- br/pkg/task/stream_test.go | 47 +- br/pkg/utils/BUILD.bazel | 3 + br/pkg/utils/encryption.go | 30 ++ br/tests/README.md | 4 +- br/tests/br_encryption/run.sh | 435 ++++++++++++++++++ br/tests/br_encryption/workload | 12 + br/tests/config/tikv.toml | 3 + .../download_integration_test_binaries.sh | 9 + br/tests/run_group_br_tests.sh | 2 +- build/BUILD.bazel | 6 +- go.mod | 3 +- tests/_utils/run_services | 2 + 56 files changed, 2750 insertions(+), 166 deletions(-) create mode 100644 br/pkg/encryption/BUILD.bazel create mode 100644 br/pkg/encryption/manager.go create mode 100644 br/pkg/encryption/master_key/BUILD.bazel create mode 100644 br/pkg/encryption/master_key/common.go create mode 100644 br/pkg/encryption/master_key/file_backend.go create mode 100644 br/pkg/encryption/master_key/file_backend_test.go create mode 100644 br/pkg/encryption/master_key/kms_backend.go create mode 100644 br/pkg/encryption/master_key/kms_backend_test.go create mode 100644 br/pkg/encryption/master_key/master_key.go create mode 100644 br/pkg/encryption/master_key/mem_backend.go create mode 100644 br/pkg/encryption/master_key/mem_backend_test.go create mode 100644 br/pkg/encryption/master_key/multi_master_key_backend.go create mode 100644 br/pkg/encryption/master_key/multi_master_key_backend_test.go create mode 100644 br/pkg/kms/BUILD.bazel create mode 100644 br/pkg/kms/aws.go create mode 100644 br/pkg/kms/common.go create mode 100644 br/pkg/kms/gcp.go create mode 100644 br/pkg/kms/kms.go create mode 100644 
br/pkg/task/encryption.go create mode 100644 br/pkg/task/encryption_test.go create mode 100644 br/pkg/utils/encryption.go create mode 100755 br/tests/br_encryption/run.sh create mode 100644 br/tests/br_encryption/workload diff --git a/br/cmd/br/restore.go b/br/cmd/br/restore.go index 916ed3b703..f991163813 100644 --- a/br/cmd/br/restore.go +++ b/br/cmd/br/restore.go @@ -74,14 +74,14 @@ func runRestoreCommand(command *cobra.Command, cmdName string) error { if err := task.RunRestore(GetDefaultContext(), tidbGlue, cmdName, &cfg); err != nil { log.Error("failed to restore", zap.Error(err)) - printWorkaroundOnFullRestoreError(command, err) + printWorkaroundOnFullRestoreError(err) return errors.Trace(err) } return nil } // print workaround when we met not fresh or incompatible cluster error on full cluster restore -func printWorkaroundOnFullRestoreError(command *cobra.Command, err error) { +func printWorkaroundOnFullRestoreError(err error) { if !errors.ErrorEqual(err, berrors.ErrRestoreNotFreshCluster) && !errors.ErrorEqual(err, berrors.ErrRestoreIncompatibleSys) { return diff --git a/br/pkg/checkpoint/checkpoint.go b/br/pkg/checkpoint/checkpoint.go index 493334d748..d5352380e8 100644 --- a/br/pkg/checkpoint/checkpoint.go +++ b/br/pkg/checkpoint/checkpoint.go @@ -33,6 +33,7 @@ import ( "github.com/pingcap/tidb/br/pkg/rtree" "github.com/pingcap/tidb/br/pkg/storage" "github.com/pingcap/tidb/br/pkg/summary" + "github.com/pingcap/tidb/br/pkg/utils" "github.com/pingcap/tidb/pkg/util" "go.uber.org/zap" "golang.org/x/sync/errgroup" @@ -561,7 +562,7 @@ func parseCheckpointData[K KeyType, V ValueType]( *pastDureTime = checkpointData.DureTime } for _, meta := range checkpointData.RangeGroupMetas { - decryptContent, err := metautil.Decrypt(meta.RangeGroupsEncriptedData, cipher, meta.CipherIv) + decryptContent, err := utils.Decrypt(meta.RangeGroupsEncriptedData, cipher, meta.CipherIv) if err != nil { return errors.Trace(err) } diff --git a/br/pkg/conn/conn.go b/br/pkg/conn/conn.go index f34a6c276c..1f7da470b6 100644 --- a/br/pkg/conn/conn.go +++ b/br/pkg/conn/conn.go @@ -299,8 +299,8 @@ func (mgr *Mgr) Close() { mgr.PdController.Close() } -// GetTS gets current ts from pd. -func (mgr *Mgr) GetTS(ctx context.Context) (uint64, error) { +// GetCurrentTsFromPD gets current ts from PD. +func (mgr *Mgr) GetCurrentTsFromPD(ctx context.Context) (uint64, error) { p, l, err := mgr.GetPDClient().GetTS(ctx) if err != nil { return 0, errors.Trace(err) diff --git a/br/pkg/encryption/BUILD.bazel b/br/pkg/encryption/BUILD.bazel new file mode 100644 index 0000000000..1f45ff8104 --- /dev/null +++ b/br/pkg/encryption/BUILD.bazel @@ -0,0 +1,15 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "encryption", + srcs = ["manager.go"], + importpath = "github.com/pingcap/tidb/br/pkg/encryption", + visibility = ["//visibility:public"], + deps = [ + "//br/pkg/encryption/master_key", + "//br/pkg/utils", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_kvproto//pkg/brpb", + "@com_github_pingcap_kvproto//pkg/encryptionpb", + ], +) diff --git a/br/pkg/encryption/manager.go b/br/pkg/encryption/manager.go new file mode 100644 index 0000000000..7edaba0b21 --- /dev/null +++ b/br/pkg/encryption/manager.go @@ -0,0 +1,93 @@ +// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0. 
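The new manager.go below is the single decryption entry point for restore: it holds either a plaintext data key (CipherInfo) or a set of master-key backends, choosing whichever config carries an effective encryption method. A minimal caller-side sketch, assuming only the kvproto types referenced in this patch; the key path and method here are hypothetical:

package example

import (
	"context"

	backuppb "github.com/pingcap/kvproto/pkg/brpb"
	"github.com/pingcap/kvproto/pkg/encryptionpb"
	"github.com/pingcap/tidb/br/pkg/encryption"
)

// decryptOne sketches master-key-based decryption of one log file's bytes.
func decryptOne(ctx context.Context, content []byte, info *encryptionpb.FileEncryptionInfo) ([]byte, error) {
	masterKeyConfig := &backuppb.MasterKeyConfig{
		EncryptionType: encryptionpb.EncryptionMethod_AES256_CTR,
		MasterKeys: []*encryptionpb.MasterKey{{
			Backend: &encryptionpb.MasterKey_File{
				File: &encryptionpb.MasterKeyFile{Path: "/etc/tikv/master.key"}, // hypothetical path
			},
		}},
	}
	// CipherInfo is left at its plaintext default, so NewManager takes the
	// master-key branch.
	mgr, err := encryption.NewManager(&backuppb.CipherInfo{}, masterKeyConfig)
	if err != nil {
		return nil, err
	}
	defer mgr.Close()
	return mgr.Decrypt(ctx, content, info)
}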
+ +package encryption + +import ( + "context" + + "github.com/pingcap/errors" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/encryptionpb" + encryption "github.com/pingcap/tidb/br/pkg/encryption/master_key" + "github.com/pingcap/tidb/br/pkg/utils" +) + +type Manager struct { + cipherInfo *backuppb.CipherInfo + masterKeyBackends *encryption.MultiMasterKeyBackend + encryptionMethod *encryptionpb.EncryptionMethod +} + +func NewManager(cipherInfo *backuppb.CipherInfo, masterKeyConfigs *backuppb.MasterKeyConfig) (*Manager, error) { + // should never happen since config has default + if cipherInfo == nil || masterKeyConfigs == nil { + return nil, errors.New("cipherInfo or masterKeyConfigs is nil") + } + + if utils.IsEffectiveEncryptionMethod(cipherInfo.CipherType) { + return &Manager{ + cipherInfo: cipherInfo, + masterKeyBackends: nil, + encryptionMethod: nil, + }, nil + } + if utils.IsEffectiveEncryptionMethod(masterKeyConfigs.EncryptionType) { + masterKeyBackends, err := encryption.NewMultiMasterKeyBackend(masterKeyConfigs.GetMasterKeys()) + if err != nil { + return nil, errors.Trace(err) + } + return &Manager{ + cipherInfo: nil, + masterKeyBackends: masterKeyBackends, + encryptionMethod: &masterKeyConfigs.EncryptionType, + }, nil + } + return nil, nil +} + +func (m *Manager) Decrypt(ctx context.Context, content []byte, fileEncryptionInfo *encryptionpb.FileEncryptionInfo) ( + []byte, error) { + switch mode := fileEncryptionInfo.Mode.(type) { + case *encryptionpb.FileEncryptionInfo_PlainTextDataKey: + if m.cipherInfo == nil { + return nil, errors.New("plaintext data key info is required but not set") + } + decryptedContent, err := utils.Decrypt(content, m.cipherInfo, fileEncryptionInfo.FileIv) + if err != nil { + return nil, errors.Annotate(err, "failed to decrypt content using plaintext data key") + } + return decryptedContent, nil + case *encryptionpb.FileEncryptionInfo_MasterKeyBased: + encryptedContents := fileEncryptionInfo.GetMasterKeyBased().DataKeyEncryptedContent + if len(encryptedContents) == 0 { + return nil, errors.New("should contain at least one encrypted data key") + } + // pick first one, the list is for future expansion of multiple encrypted data keys by different master key backend + encryptedContent := encryptedContents[0] + decryptedDataKey, err := m.masterKeyBackends.Decrypt(ctx, encryptedContent) + if err != nil { + return nil, errors.Annotate(err, "failed to decrypt data key using master key") + } + + cipherInfo := backuppb.CipherInfo{ + CipherType: fileEncryptionInfo.EncryptionMethod, + CipherKey: decryptedDataKey, + } + decryptedContent, err := utils.Decrypt(content, &cipherInfo, fileEncryptionInfo.FileIv) + if err != nil { + return nil, errors.Annotate(err, "failed to decrypt content using decrypted data key") + } + return decryptedContent, nil + default: + return nil, errors.Errorf("internal error: unsupported encryption mode type %T", mode) + } +} + +func (m *Manager) Close() { + if m == nil { + return + } + if m.masterKeyBackends != nil { + m.masterKeyBackends.Close() + } +} diff --git a/br/pkg/encryption/master_key/BUILD.bazel b/br/pkg/encryption/master_key/BUILD.bazel new file mode 100644 index 0000000000..cc412a52c1 --- /dev/null +++ b/br/pkg/encryption/master_key/BUILD.bazel @@ -0,0 +1,44 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "master_key", + srcs = [ + "common.go", + "file_backend.go", + "kms_backend.go", + "master_key.go", + "mem_backend.go", + 
"multi_master_key_backend.go", + ], + importpath = "github.com/pingcap/tidb/br/pkg/encryption/master_key", + visibility = ["//visibility:public"], + deps = [ + "//br/pkg/kms:aws", + "//br/pkg/utils", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_kvproto//pkg/encryptionpb", + "@com_github_pingcap_log//:log", + "@org_uber_go_multierr//:multierr", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "master_key_test", + timeout = "short", + srcs = [ + "file_backend_test.go", + "kms_backend_test.go", + "mem_backend_test.go", + "multi_master_key_backend_test.go", + ], + embed = [":master_key"], + flaky = True, + shard_count = 11, + deps = [ + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_kvproto//pkg/encryptionpb", + "@com_github_stretchr_testify//mock", + "@com_github_stretchr_testify//require", + ], +) diff --git a/br/pkg/encryption/master_key/common.go b/br/pkg/encryption/master_key/common.go new file mode 100644 index 0000000000..b9e70c9905 --- /dev/null +++ b/br/pkg/encryption/master_key/common.go @@ -0,0 +1,60 @@ +// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0. + +package encryption + +import ( + "crypto/rand" + + "github.com/pingcap/errors" +) + +// must keep it same with the constants in TiKV implementation +const ( + MetadataKeyMethod string = "method" + MetadataKeyIv string = "iv" + MetadataKeyAesGcmTag string = "aes_gcm_tag" + MetadataKeyKmsVendor string = "kms_vendor" + MetadataKeyKmsCiphertextKey string = "kms_ciphertext_key" + MetadataMethodAes256Gcm string = "aes256-gcm" +) + +const ( + GcmIv12 = 12 + CtrIv16 = 16 +) + +type IvType int + +const ( + IvTypeGcm IvType = iota + IvTypeCtr +) + +type IV struct { + Type IvType + Data []byte +} + +func NewIVGcm() (IV, error) { + iv := make([]byte, GcmIv12) + _, err := rand.Read(iv) + if err != nil { + return IV{}, err + } + return IV{Type: IvTypeGcm, Data: iv}, nil +} + +func NewIVFromSlice(src []byte) (IV, error) { + switch len(src) { + case CtrIv16: + return IV{Type: IvTypeCtr, Data: append([]byte(nil), src...)}, nil + case GcmIv12: + return IV{Type: IvTypeGcm, Data: append([]byte(nil), src...)}, nil + default: + return IV{}, errors.Errorf("invalid IV length, must be 12 or 16 bytes, got %d", len(src)) + } +} + +func (iv IV) AsSlice() []byte { + return iv.Data +} diff --git a/br/pkg/encryption/master_key/file_backend.go b/br/pkg/encryption/master_key/file_backend.go new file mode 100644 index 0000000000..20c4ed0d7d --- /dev/null +++ b/br/pkg/encryption/master_key/file_backend.go @@ -0,0 +1,68 @@ +// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0. 
+ +package encryption + +import ( + "context" + "encoding/hex" + "os" + + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/encryptionpb" +) + +const AesGcmKeyLen = 32 // AES-256 key length + +// FileBackend is ported from TiKV FileBackend +type FileBackend struct { + memCache *MemAesGcmBackend +} + +func createFileBackend(keyPath string) (*FileBackend, error) { + // FileBackend uses AES-256-GCM + keyLen := AesGcmKeyLen + + content, err := os.ReadFile(keyPath) + if err != nil { + return nil, errors.Annotate(err, "failed to read master key file from disk") + } + + fileLen := len(content) + expectedLen := keyLen*2 + 1 // hex-encoded key + newline + + if fileLen != expectedLen { + return nil, errors.Errorf("mismatch master key file size, expected %d, actual %d", expectedLen, fileLen) + } + + if content[fileLen-1] != '\n' { + return nil, errors.Errorf("master key file should end with newline") + } + + key, err := hex.DecodeString(string(content[:fileLen-1])) + if err != nil { + return nil, errors.Annotate(err, "failed to decode hex format master key from file") + } + + backend, err := NewMemAesGcmBackend(key) + if err != nil { + return nil, errors.Annotate(err, "failed to create MemAesGcmBackend") + } + + return &FileBackend{memCache: backend}, nil +} + +func (f *FileBackend) Encrypt(ctx context.Context, plaintext []byte) (*encryptionpb.EncryptedContent, error) { + iv, err := NewIVGcm() + if err != nil { + return nil, err + } + return f.memCache.EncryptContent(ctx, plaintext, iv) +} + +func (f *FileBackend) Decrypt(ctx context.Context, content *encryptionpb.EncryptedContent) ([]byte, error) { + return f.memCache.DecryptContent(ctx, content) +} + +func (f *FileBackend) Close() { + // nothing to close +} diff --git a/br/pkg/encryption/master_key/file_backend_test.go b/br/pkg/encryption/master_key/file_backend_test.go new file mode 100644 index 0000000000..1ac91383d0 --- /dev/null +++ b/br/pkg/encryption/master_key/file_backend_test.go @@ -0,0 +1,105 @@ +// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0. 
+ +package encryption + +import ( + "context" + "encoding/hex" + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +// TempKeyFile represents a temporary key file for testing +type TempKeyFile struct { + Path string + file *os.File +} + +// Cleanup closes and removes the temporary file +func (tkf *TempKeyFile) Cleanup() { + if tkf.file != nil { + tkf.file.Close() + } + os.Remove(tkf.Path) +} + +// createMasterKeyFile creates a temporary master key file for testing +func createMasterKeyFile() (*TempKeyFile, error) { + tempFile, err := os.CreateTemp("", "test_key_*.txt") + if err != nil { + return nil, err + } + + _, err = tempFile.WriteString("c3d99825f2181f4808acd2068eac7441a65bd428f14d2aab43fefc0129091139\n") + if err != nil { + tempFile.Close() + os.Remove(tempFile.Name()) + return nil, err + } + + return &TempKeyFile{ + Path: tempFile.Name(), + file: tempFile, + }, nil +} + +func TestFileBackendAes256Gcm(t *testing.T) { + pt, err := hex.DecodeString("25431587e9ecffc7c37f8d6d52a9bc3310651d46fb0e3bad2726c8f2db653749") + require.NoError(t, err) + ct, err := hex.DecodeString("84e5f23f95648fa247cb28eef53abec947dbf05ac953734618111583840bd980") + require.NoError(t, err) + ivBytes, err := hex.DecodeString("cafabd9672ca6c79a2fbdc22") + require.NoError(t, err) + + tempKeyFile, err := createMasterKeyFile() + require.NoError(t, err) + defer tempKeyFile.Cleanup() + + backend, err := createFileBackend(tempKeyFile.Path) + require.NoError(t, err) + + ctx := context.Background() + iv, err := NewIVFromSlice(ivBytes) + require.NoError(t, err) + + encryptedContent, err := backend.memCache.EncryptContent(ctx, pt, iv) + require.NoError(t, err) + require.Equal(t, ct, encryptedContent.Content) + + plaintext, err := backend.Decrypt(ctx, encryptedContent) + require.NoError(t, err) + require.Equal(t, pt, plaintext) +} + +func TestFileBackendAuthenticate(t *testing.T) { + pt := []byte{1, 2, 3} + + tempKeyFile, err := createMasterKeyFile() + require.NoError(t, err) + defer tempKeyFile.Cleanup() + + backend, err := createFileBackend(tempKeyFile.Path) + require.NoError(t, err) + + ctx := context.Background() + encryptedContent, err := backend.Encrypt(ctx, pt) + require.NoError(t, err) + + plaintext, err := backend.Decrypt(ctx, encryptedContent) + require.NoError(t, err) + require.Equal(t, pt, plaintext) + + // Test checksum mismatch + encryptedContent1 := *encryptedContent + encryptedContent1.Metadata[MetadataKeyAesGcmTag][0] ^= 0xFF + _, err = backend.Decrypt(ctx, &encryptedContent1) + require.ErrorContains(t, err, wrongMasterKey) + + // Test checksum not found + encryptedContent2 := *encryptedContent + delete(encryptedContent2.Metadata, MetadataKeyAesGcmTag) + _, err = backend.Decrypt(ctx, &encryptedContent2) + require.ErrorContains(t, err, gcmTagNotFound) +} diff --git a/br/pkg/encryption/master_key/kms_backend.go b/br/pkg/encryption/master_key/kms_backend.go new file mode 100644 index 0000000000..1538a379d9 --- /dev/null +++ b/br/pkg/encryption/master_key/kms_backend.go @@ -0,0 +1,88 @@ +// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0. 
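kms_backend.go below caches the most recently unwrapped data key: repeated content carrying the same KMS ciphertext key decrypts locally without another KMS round trip. For reference, the metadata layout its Decrypt consumes looks roughly like this; an illustrative sketch whose string keys mirror the constants in common.go, with byte-slice arguments standing in for real values:

package example

import "github.com/pingcap/kvproto/pkg/encryptionpb"

// buildKmsEncryptedContent shows the metadata KmsBackend.Decrypt expects.
func buildKmsEncryptedContent(wrappedKey, iv, tag, ciphertext []byte) *encryptionpb.EncryptedContent {
	return &encryptionpb.EncryptedContent{
		Metadata: map[string][]byte{
			"kms_vendor":         []byte("AWS"), // must equal the provider's Name()
			"kms_ciphertext_key": wrappedKey,    // data key wrapped by the KMS
			"method":             []byte("aes256-gcm"),
			"iv":                 iv,  // 12-byte GCM IV
			"aes_gcm_tag":        tag, // 16-byte GCM tag
		},
		Content: ciphertext,
	}
}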
+
+package encryption
+
+import (
+	"context"
+	"sync"
+
+	"github.com/pingcap/errors"
+	"github.com/pingcap/kvproto/pkg/encryptionpb"
+	"github.com/pingcap/tidb/br/pkg/kms"
+	"github.com/pingcap/tidb/br/pkg/utils"
+)
+
+type CachedKeys struct {
+	encryptionBackend   *MemAesGcmBackend
+	cachedCiphertextKey *kms.EncryptedKey
+}
+
+type KmsBackend struct {
+	state struct {
+		sync.Mutex
+		cached *CachedKeys
+	}
+	kmsProvider kms.Provider
+}
+
+func NewKmsBackend(kmsProvider kms.Provider) (*KmsBackend, error) {
+	return &KmsBackend{
+		kmsProvider: kmsProvider,
+	}, nil
+}
+
+func (k *KmsBackend) Decrypt(ctx context.Context, content *encryptionpb.EncryptedContent) ([]byte, error) {
+	vendorName := k.kmsProvider.Name()
+	if val, ok := content.Metadata[MetadataKeyKmsVendor]; !ok {
+		return nil, errors.New("wrong master key: missing KMS vendor")
+	} else if string(val) != vendorName {
+		return nil, errors.Errorf("KMS vendor mismatch expect %s got %s", vendorName, string(val))
+	}
+
+	ciphertextKeyBytes, ok := content.Metadata[MetadataKeyKmsCiphertextKey]
+	if !ok {
+		return nil, errors.New("KMS ciphertext key not found")
+	}
+	ciphertextKey, err := kms.NewEncryptedKey(ciphertextKeyBytes)
+	if err != nil {
+		return nil, errors.Annotate(err, "failed to create encrypted key")
+	}
+
+	k.state.Lock()
+	defer k.state.Unlock()
+
+	if k.state.cached != nil && k.state.cached.cachedCiphertextKey.Equal(&ciphertextKey) {
+		return k.state.cached.encryptionBackend.DecryptContent(ctx, content)
+	}
+
+	// piggyback on NewDownloadSSTBackoffer; a refactor is ongoing to remove all the backoffers
+	// so users don't need to write a backoffer for every type
+	decryptedKey, err :=
+		utils.WithRetryV2(ctx, utils.NewDownloadSSTBackoffer(), func(ctx context.Context) ([]byte, error) {
+			return k.kmsProvider.DecryptDataKey(ctx, ciphertextKey)
+		})
+	if err != nil {
+		return nil, errors.Annotate(err, "decrypt encrypted key failed")
+	}
+
+	plaintextKey, err := kms.NewPlainKey(decryptedKey, kms.CryptographyTypeAesGcm256)
+	if err != nil {
+		return nil, errors.Annotate(err, "failed to create plaintext key")
+	}
+
+	backend, err := NewMemAesGcmBackend(plaintextKey.Key())
+	if err != nil {
+		return nil, errors.Annotate(err, "failed to create MemAesGcmBackend")
+	}
+
+	k.state.cached = &CachedKeys{
+		encryptionBackend:   backend,
+		cachedCiphertextKey: &ciphertextKey,
+	}
+
+	return k.state.cached.encryptionBackend.DecryptContent(ctx, content)
+}
+
+func (k *KmsBackend) Close() {
+	k.kmsProvider.Close()
+}
diff --git a/br/pkg/encryption/master_key/kms_backend_test.go b/br/pkg/encryption/master_key/kms_backend_test.go
new file mode 100644
index 0000000000..09e97f020c
--- /dev/null
+++ b/br/pkg/encryption/master_key/kms_backend_test.go
@@ -0,0 +1,113 @@
+// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
+ +package encryption + +import ( + "context" + "crypto/rand" + "testing" + + "github.com/pingcap/kvproto/pkg/encryptionpb" + "github.com/stretchr/testify/require" +) + +type mockKmsProvider struct { + name string + decryptCounter int +} + +func (m *mockKmsProvider) Name() string { + return m.name +} + +func (m *mockKmsProvider) DecryptDataKey(_ctx context.Context, _encryptedKey []byte) ([]byte, error) { + m.decryptCounter++ + key := make([]byte, 32) // 256 bits = 32 bytes + _, err := rand.Read(key) + if err != nil { + return nil, err + } + return key, nil +} + +func (m *mockKmsProvider) Close() { + // do nothing +} + +func TestKmsBackendDecrypt(t *testing.T) { + ctx := context.Background() + mockProvider := &mockKmsProvider{name: "mock_kms"} + backend, err := NewKmsBackend(mockProvider) + require.NoError(t, err) + + ciphertextKey := []byte("ciphertext_key") + content := &encryptionpb.EncryptedContent{ + Metadata: map[string][]byte{ + MetadataKeyKmsVendor: []byte("mock_kms"), + MetadataKeyKmsCiphertextKey: ciphertextKey, + }, + Content: []byte("encrypted_content"), + } + + // First decryption + _, _ = backend.Decrypt(ctx, content) + require.Equal(t, 1, mockProvider.decryptCounter, "KMS provider should be called once") + + // Second decryption with the same ciphertext key (should use cache) + _, _ = backend.Decrypt(ctx, content) + require.Equal(t, 1, mockProvider.decryptCounter, "KMS provider should not be called again") + + // Third decryption with a different ciphertext key + content.Metadata[MetadataKeyKmsCiphertextKey] = []byte("new_ciphertext_key") + _, _ = backend.Decrypt(ctx, content) + require.Equal(t, 2, mockProvider.decryptCounter, "KMS provider should be called again for a new key") +} + +func TestKmsBackendDecryptErrors(t *testing.T) { + ctx := context.Background() + mockProvider := &mockKmsProvider{name: "mock_kms"} + backend, err := NewKmsBackend(mockProvider) + require.NoError(t, err) + + testCases := []struct { + name string + content *encryptionpb.EncryptedContent + errMsg string + }{ + { + name: "missing KMS vendor", + content: &encryptionpb.EncryptedContent{ + Metadata: map[string][]byte{ + MetadataKeyKmsCiphertextKey: []byte("ciphertext_key"), + }, + }, + errMsg: "wrong master key: missing KMS vendor", + }, + { + name: "KMS vendor mismatch", + content: &encryptionpb.EncryptedContent{ + Metadata: map[string][]byte{ + MetadataKeyKmsVendor: []byte("wrong_kms"), + MetadataKeyKmsCiphertextKey: []byte("ciphertext_key"), + }, + }, + errMsg: "KMS vendor mismatch expect mock_kms got wrong_kms", + }, + { + name: "missing ciphertext key", + content: &encryptionpb.EncryptedContent{ + Metadata: map[string][]byte{ + MetadataKeyKmsVendor: []byte("mock_kms"), + }, + }, + errMsg: "KMS ciphertext key not found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := backend.Decrypt(ctx, tc.content) + require.ErrorContains(t, err, tc.errMsg) + }) + } +} diff --git a/br/pkg/encryption/master_key/master_key.go b/br/pkg/encryption/master_key/master_key.go new file mode 100644 index 0000000000..1ceac58c9d --- /dev/null +++ b/br/pkg/encryption/master_key/master_key.go @@ -0,0 +1,77 @@ +// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0. 
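master_key.go below dispatches on the configured backend type: plaintext is rejected, file paths go through createFileBackend, and KMS configs go through createCloudBackend (AWS and GCP are wired up, Azure is not yet). A hypothetical AWS configuration; the ARN and region are placeholders, not values from this patch:

package example

import (
	"github.com/pingcap/kvproto/pkg/encryptionpb"
	encryption "github.com/pingcap/tidb/br/pkg/encryption/master_key"
)

// newAwsBackend sketches a KMS-backed master key config for CreateBackend.
func newAwsBackend() (encryption.Backend, error) {
	return encryption.CreateBackend(&encryptionpb.MasterKey{
		Backend: &encryptionpb.MasterKey_Kms{
			Kms: &encryptionpb.MasterKeyKms{
				Vendor: "aws", // StorageVendorNameAWS
				KeyId:  "arn:aws:kms:us-west-2:123456789012:key/example", // placeholder ARN
				Region: "us-west-2",
			},
		},
	})
}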
+
+package encryption
+
+import (
+	"context"
+
+	"github.com/pingcap/errors"
+	"github.com/pingcap/kvproto/pkg/encryptionpb"
+	"github.com/pingcap/log"
+	"github.com/pingcap/tidb/br/pkg/kms"
+	"go.uber.org/zap"
+)
+
+const (
+	StorageVendorNameAWS   = "aws"
+	StorageVendorNameAzure = "azure"
+	StorageVendorNameGCP   = "gcp"
+)
+
+// Backend is an interface that defines the methods required for an encryption backend.
+type Backend interface {
+	// Decrypt takes an EncryptedContent and returns the decrypted plaintext as a byte slice or an error.
+	Decrypt(ctx context.Context, ciphertext *encryptionpb.EncryptedContent) ([]byte, error)
+	Close()
+}
+
+func CreateBackend(config *encryptionpb.MasterKey) (Backend, error) {
+	if config == nil {
+		return nil, errors.Errorf("master key config is nil")
+	}
+
+	switch backend := config.Backend.(type) {
+	case *encryptionpb.MasterKey_Plaintext:
+		// plaintext master keys must not reach this point; the caller guards against them
		return nil, errors.New("should not create plaintext master key")
+	case *encryptionpb.MasterKey_File:
+		fileBackend, err := createFileBackend(backend.File.Path)
+		if err != nil {
+			return nil, errors.Annotate(err, "failed to create file backend")
+		}
+		return fileBackend, nil
+	case *encryptionpb.MasterKey_Kms:
+		return createCloudBackend(backend.Kms)
+	default:
+		return nil, errors.New("unknown master key backend type")
+	}
+}
+
+func createCloudBackend(config *encryptionpb.MasterKeyKms) (Backend, error) {
+	log.Info("creating cloud KMS backend",
+		zap.String("region", config.GetRegion()),
+		zap.String("endpoint", config.GetEndpoint()),
+		zap.String("key_id", config.GetKeyId()),
+		zap.String("vendor", config.GetVendor()))
+
+	switch config.Vendor {
+	case StorageVendorNameAWS:
+		kmsProvider, err := kms.NewAwsKms(config)
+		if err != nil {
+			return nil, errors.Annotate(err, "new AWS KMS")
+		}
+		return NewKmsBackend(kmsProvider)
+
+	case StorageVendorNameAzure:
+		return nil, errors.Errorf("Azure KMS is not implemented yet")
+	case StorageVendorNameGCP:
+		kmsProvider, err := kms.NewGcpKms(config)
+		if err != nil {
+			return nil, errors.Annotate(err, "new GCP KMS")
+		}
+		return NewKmsBackend(kmsProvider)
+
+	default:
+		return nil, errors.Errorf("vendor not found: %s", config.Vendor)
+	}
+}
diff --git a/br/pkg/encryption/master_key/mem_backend.go b/br/pkg/encryption/master_key/mem_backend.go
new file mode 100644
index 0000000000..d65532b69d
--- /dev/null
+++ b/br/pkg/encryption/master_key/mem_backend.go
@@ -0,0 +1,103 @@
+// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
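mem_backend.go leans on Go's AEAD API: cipher.AEAD.Seal appends the 16-byte authentication tag to the ciphertext, and the backend splits that tag off into metadata so body and tag are stored separately. A standalone sketch of the same splitting idea, standard library only:

package example

import (
	"crypto/aes"
	"crypto/cipher"
)

// sealSplit encrypts with AES-256-GCM and returns the ciphertext body and
// tag separately, mirroring how MemAesGcmBackend files them into metadata.
func sealSplit(key, iv, plaintext []byte) (body, tag []byte, err error) {
	block, err := aes.NewCipher(key) // key must be 32 bytes for AES-256
	if err != nil {
		return nil, nil, err
	}
	aead, err := cipher.NewGCM(block) // iv must be aead.NonceSize() == 12 bytes
	if err != nil {
		return nil, nil, err
	}
	sealed := aead.Seal(nil, iv, plaintext, nil)
	n := len(sealed) - aead.Overhead()
	return sealed[:n], sealed[n:], nil
}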
+ +package encryption + +import ( + "context" + "crypto/aes" + "crypto/cipher" + + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/encryptionpb" + "github.com/pingcap/tidb/br/pkg/kms" +) + +const ( + gcmTagNotFound = "aes gcm tag not found" + wrongMasterKey = "wrong master key" +) + +type MemAesGcmBackend struct { + key *kms.PlainKey +} + +func NewMemAesGcmBackend(key []byte) (*MemAesGcmBackend, error) { + plainKey, err := kms.NewPlainKey(key, kms.CryptographyTypeAesGcm256) + if err != nil { + return nil, errors.Annotate(err, "failed to create new mem aes gcm backend") + } + return &MemAesGcmBackend{ + key: plainKey, + }, nil +} + +func (m *MemAesGcmBackend) EncryptContent(_ctx context.Context, plaintext []byte, iv IV) ( + *encryptionpb.EncryptedContent, error) { + content := encryptionpb.EncryptedContent{ + Metadata: make(map[string][]byte), + } + content.Metadata[MetadataKeyMethod] = []byte(MetadataMethodAes256Gcm) + content.Metadata[MetadataKeyIv] = iv.AsSlice() + + block, err := aes.NewCipher(m.key.Key()) + if err != nil { + return nil, err + } + aesgcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + // The Seal function in AES-GCM mode appends the authentication tag to the ciphertext. + // We need to separate the actual ciphertext from the tag for storage and later verification. + // Reference: https://pkg.go.dev/crypto/cipher#AEAD + ciphertext := aesgcm.Seal(nil, iv.AsSlice(), plaintext, nil) + content.Content = ciphertext[:len(ciphertext)-aesgcm.Overhead()] + content.Metadata[MetadataKeyAesGcmTag] = ciphertext[len(ciphertext)-aesgcm.Overhead():] + + return &content, nil +} + +func (m *MemAesGcmBackend) DecryptContent(_ctx context.Context, content *encryptionpb.EncryptedContent) ( + []byte, error) { + method, ok := content.Metadata[MetadataKeyMethod] + if !ok { + return nil, errors.Errorf("metadata %s not found", MetadataKeyMethod) + } + if string(method) != MetadataMethodAes256Gcm { + return nil, errors.Errorf("encryption method mismatch, expected %s vs actual %s", + MetadataMethodAes256Gcm, method) + } + + ivValue, ok := content.Metadata[MetadataKeyIv] + if !ok { + return nil, errors.Errorf("metadata %s not found", MetadataKeyIv) + } + + iv, err := NewIVFromSlice(ivValue) + if err != nil { + return nil, err + } + + tag, ok := content.Metadata[MetadataKeyAesGcmTag] + if !ok { + return nil, errors.New("aes gcm tag not found") + } + + block, err := aes.NewCipher(m.key.Key()) + if err != nil { + return nil, err + } + aesgcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + ciphertext := append(content.Content, tag...) + plaintext, err := aesgcm.Open(nil, iv.AsSlice(), ciphertext, nil) + if err != nil { + return nil, errors.Annotate(err, wrongMasterKey+" :decrypt in GCM mode failed") + } + + return plaintext, nil +} diff --git a/br/pkg/encryption/master_key/mem_backend_test.go b/br/pkg/encryption/master_key/mem_backend_test.go new file mode 100644 index 0000000000..83ef33c6e6 --- /dev/null +++ b/br/pkg/encryption/master_key/mem_backend_test.go @@ -0,0 +1,115 @@ +// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0. 
+ +package encryption + +import ( + "bytes" + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewMemAesGcmBackend(t *testing.T) { + key := make([]byte, 32) // 256-bit key + _, err := NewMemAesGcmBackend(key) + require.NoError(t, err, "Failed to create MemAesGcmBackend") + + shortKey := make([]byte, 16) + _, err = NewMemAesGcmBackend(shortKey) + require.Error(t, err, "Expected error for short key") +} + +func TestEncryptDecrypt(t *testing.T) { + key := make([]byte, 32) + backend, err := NewMemAesGcmBackend(key) + require.NoError(t, err, "Failed to create MemAesGcmBackend") + + plaintext := []byte("Hello, World!") + + iv, err := NewIVGcm() + require.NoError(t, err, "failed to create gcm iv") + + ctx := context.Background() + encrypted, err := backend.EncryptContent(ctx, plaintext, iv) + require.NoError(t, err, "Encryption failed") + + decrypted, err := backend.DecryptContent(ctx, encrypted) + require.NoError(t, err, "Decryption failed") + + require.Equal(t, plaintext, decrypted, "Decrypted text doesn't match original") +} + +func TestDecryptWithWrongKey(t *testing.T) { + key1 := make([]byte, 32) + key2 := make([]byte, 32) + for i := range key2 { + key2[i] = 1 // Different from key1 + } + + backend1, _ := NewMemAesGcmBackend(key1) + backend2, _ := NewMemAesGcmBackend(key2) + + plaintext := []byte("Hello, World!") + + iv, err := NewIVGcm() + require.NoError(t, err, "failed to create gcm iv") + + ctx := context.Background() + encrypted, _ := backend1.EncryptContent(ctx, plaintext, iv) + _, err = backend2.DecryptContent(ctx, encrypted) + require.Error(t, err, "Expected decryption with wrong key to fail") +} + +func TestDecryptWithTamperedCiphertext(t *testing.T) { + key := make([]byte, 32) + backend, _ := NewMemAesGcmBackend(key) + + plaintext := []byte("Hello, World!") + + iv, err := NewIVGcm() + require.NoError(t, err, "failed to create gcm iv") + + ctx := context.Background() + encrypted, _ := backend.EncryptContent(ctx, plaintext, iv) + encrypted.Content[0] ^= 1 // Tamper with the ciphertext + + _, err = backend.DecryptContent(ctx, encrypted) + require.Error(t, err, "Expected decryption of tampered ciphertext to fail") +} + +func TestDecryptWithMissingMetadata(t *testing.T) { + key := make([]byte, 32) + backend, _ := NewMemAesGcmBackend(key) + + plaintext := []byte("Hello, World!") + + iv, err := NewIVGcm() + require.NoError(t, err, "failed to create gcm iv") + + ctx := context.Background() + encrypted, _ := backend.EncryptContent(ctx, plaintext, iv) + delete(encrypted.Metadata, MetadataKeyMethod) + + _, err = backend.DecryptContent(ctx, encrypted) + require.Error(t, err, "Expected decryption with missing metadata to fail") +} + +func TestEncryptDecryptLargeData(t *testing.T) { + key := make([]byte, 32) + backend, _ := NewMemAesGcmBackend(key) + + plaintext := make([]byte, 1000000) // 1 MB of data + + iv, err := NewIVGcm() + require.NoError(t, err, "failed to create gcm iv") + + ctx := context.Background() + encrypted, err := backend.EncryptContent(ctx, plaintext, iv) + require.NoError(t, err, "Encryption of large data failed") + + decrypted, err := backend.DecryptContent(ctx, encrypted) + require.NoError(t, err, "Decryption of large data failed") + + require.True(t, bytes.Equal(plaintext, decrypted), "Decrypted large data doesn't match original") +} diff --git a/br/pkg/encryption/master_key/multi_master_key_backend.go b/br/pkg/encryption/master_key/multi_master_key_backend.go new file mode 100644 index 0000000000..9967c32f68 --- /dev/null +++ 
b/br/pkg/encryption/master_key/multi_master_key_backend.go
@@ -0,0 +1,64 @@
+// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
+
+package encryption
+
+import (
+	"context"
+
+	"github.com/pingcap/errors"
+	"github.com/pingcap/kvproto/pkg/encryptionpb"
+	"go.uber.org/multierr"
+)
+
+const (
+	defaultBackendCapacity = 5
+)
+
+// MultiMasterKeyBackend can hold multiple master key backends.
+// If any one of those backends successfully decrypts the data, the data is returned.
+// The main purpose of this backend is to provide high availability for master keys in the future.
+// Right now only one master key backend is used to encrypt/decrypt data.
+type MultiMasterKeyBackend struct {
+	backends []Backend
+}
+
+func NewMultiMasterKeyBackend(masterKeysProto []*encryptionpb.MasterKey) (*MultiMasterKeyBackend, error) {
+	if len(masterKeysProto) == 0 {
+		return nil, errors.New("must provide at least one master key")
+	}
+	var backends = make([]Backend, 0, defaultBackendCapacity)
+	for _, masterKeyProto := range masterKeysProto {
+		backend, err := CreateBackend(masterKeyProto)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+		backends = append(backends, backend)
+	}
+	return &MultiMasterKeyBackend{
+		backends: backends,
+	}, nil
+}
+
+func (m *MultiMasterKeyBackend) Decrypt(ctx context.Context, encryptedContent *encryptionpb.EncryptedContent) (
+	[]byte, error) {
+	if len(m.backends) == 0 {
+		return nil, errors.New("internal error: should always contain at least one backend")
+	}
+
+	var err error
+	for _, masterKeyBackend := range m.backends {
+		res, decryptErr := masterKeyBackend.Decrypt(ctx, encryptedContent)
+		if decryptErr == nil {
+			return res, nil
+		}
+		err = multierr.Append(err, decryptErr)
+	}
+
+	return nil, errors.Wrap(err, "failed to decrypt in multi master key backend")
+}
+
+func (m *MultiMasterKeyBackend) Close() {
+	for _, backend := range m.backends {
+		backend.Close()
+	}
+}
diff --git a/br/pkg/encryption/master_key/multi_master_key_backend_test.go b/br/pkg/encryption/master_key/multi_master_key_backend_test.go
new file mode 100644
index 0000000000..9c245495e9
--- /dev/null
+++ b/br/pkg/encryption/master_key/multi_master_key_backend_test.go
@@ -0,0 +1,105 @@
+// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
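The tests that follow pin down the fallback order: backends are tried in sequence and failures accumulate via multierr. For reference, wiring two file-based master keys into one backend would look like this sketch (paths are hypothetical):

package example

import (
	"github.com/pingcap/kvproto/pkg/encryptionpb"
	encryption "github.com/pingcap/tidb/br/pkg/encryption/master_key"
)

// newFallbackBackend tries the primary key first, then the standby key.
func newFallbackBackend() (*encryption.MultiMasterKeyBackend, error) {
	fileKey := func(path string) *encryptionpb.MasterKey {
		return &encryptionpb.MasterKey{
			Backend: &encryptionpb.MasterKey_File{
				File: &encryptionpb.MasterKeyFile{Path: path},
			},
		}
	}
	return encryption.NewMultiMasterKeyBackend([]*encryptionpb.MasterKey{
		fileKey("/keys/primary.key"),
		fileKey("/keys/standby.key"),
	})
}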
+ +package encryption + +import ( + "context" + "testing" + + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/encryptionpb" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// MockBackend is a mock implementation of the Backend interface +type MockBackend struct { + mock.Mock +} + +func (m *MockBackend) Decrypt(ctx context.Context, encryptedContent *encryptionpb.EncryptedContent) ([]byte, error) { + args := m.Called(ctx, encryptedContent) + // The first return value should be []byte or nil + if ret := args.Get(0); ret != nil { + return ret.([]byte), args.Error(1) + } + return nil, args.Error(1) +} + +func (m *MockBackend) Close() { + // do nothing +} + +func TestMultiMasterKeyBackendDecrypt(t *testing.T) { + ctx := context.Background() + encryptedContent := &encryptionpb.EncryptedContent{Content: []byte("encrypted")} + + t.Run("success first backend", func(t *testing.T) { + mock1 := new(MockBackend) + mock1.On("Decrypt", ctx, encryptedContent).Return([]byte("decrypted"), nil) + + mock2 := new(MockBackend) + + backend := &MultiMasterKeyBackend{ + backends: []Backend{mock1, mock2}, + } + + result, err := backend.Decrypt(ctx, encryptedContent) + require.NoError(t, err) + require.Equal(t, []byte("decrypted"), result) + + mock1.AssertExpectations(t) + mock2.AssertNotCalled(t, "Decrypt") + }) + + t.Run("success second backend", func(t *testing.T) { + mock1 := new(MockBackend) + mock1.On("Decrypt", ctx, encryptedContent).Return(nil, errors.New("failed")) + + mock2 := new(MockBackend) + mock2.On("Decrypt", ctx, encryptedContent).Return([]byte("decrypted"), nil) + + backend := &MultiMasterKeyBackend{ + backends: []Backend{mock1, mock2}, + } + + result, err := backend.Decrypt(ctx, encryptedContent) + require.NoError(t, err) + require.Equal(t, []byte("decrypted"), result) + + mock1.AssertExpectations(t) + mock2.AssertExpectations(t) + }) + + t.Run("all backends fail", func(t *testing.T) { + mock1 := new(MockBackend) + mock1.On("Decrypt", ctx, encryptedContent).Return(nil, errors.New("failed1")) + + mock2 := new(MockBackend) + mock2.On("Decrypt", ctx, encryptedContent).Return(nil, errors.New("failed2")) + + backend := &MultiMasterKeyBackend{ + backends: []Backend{mock1, mock2}, + } + + result, err := backend.Decrypt(ctx, encryptedContent) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "failed1") + require.Contains(t, err.Error(), "failed2") + + mock1.AssertExpectations(t) + mock2.AssertExpectations(t) + }) + + t.Run("no backends", func(t *testing.T) { + backend := &MultiMasterKeyBackend{ + backends: []Backend{}, + } + + result, err := backend.Decrypt(ctx, encryptedContent) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "internal error") + }) +} diff --git a/br/pkg/kms/BUILD.bazel b/br/pkg/kms/BUILD.bazel new file mode 100644 index 0000000000..f7f34aab2d --- /dev/null +++ b/br/pkg/kms/BUILD.bazel @@ -0,0 +1,27 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "aws", + srcs = [ + "aws.go", + "common.go", + "gcp.go", + "kms.go", + ], + importpath = "github.com/pingcap/tidb/br/pkg/kms", + visibility = ["//visibility:public"], + deps = [ + "@com_github_aws_aws_sdk_go//aws", + "@com_github_aws_aws_sdk_go//aws/credentials", + "@com_github_aws_aws_sdk_go//aws/session", + "@com_github_aws_aws_sdk_go//service/kms", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_kvproto//pkg/encryptionpb", + "@com_github_pingcap_log//:log", + 
"@com_google_cloud_go_kms//apiv1", + "@com_google_cloud_go_kms//apiv1/kmspb", + "@org_golang_google_api//option", + "@org_golang_google_protobuf//types/known/wrapperspb", + "@org_uber_go_zap//:zap", + ], +) diff --git a/br/pkg/kms/aws.go b/br/pkg/kms/aws.go new file mode 100644 index 0000000000..5bfaf65141 --- /dev/null +++ b/br/pkg/kms/aws.go @@ -0,0 +1,92 @@ +// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0. + +package kms + +import ( + "context" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/kms" + pErrors "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/encryptionpb" +) + +const ( + // need to keep it exact same as in TiKV ENCRYPTION_VENDOR_NAME_AWS_KMS + EncryptionVendorNameAwsKms = "AWS" +) + +type AwsKms struct { + client *kms.KMS + currentKeyID string + region string + endpoint string +} + +func NewAwsKms(masterKeyConfig *encryptionpb.MasterKeyKms) (*AwsKms, error) { + config := &aws.Config{ + Region: aws.String(masterKeyConfig.Region), + Endpoint: aws.String(masterKeyConfig.Endpoint), + } + + // Only use static credentials if both access key and secret key are provided + if masterKeyConfig.AwsKms != nil && + masterKeyConfig.AwsKms.AccessKey != "" && + masterKeyConfig.AwsKms.SecretAccessKey != "" { + config.Credentials = credentials.NewStaticCredentials( + masterKeyConfig.AwsKms.AccessKey, + masterKeyConfig.AwsKms.SecretAccessKey, + "", + ) + } + + sess, err := session.NewSession(config) + if err != nil { + return nil, pErrors.Annotate(err, "failed to create AWS session") + } + + return &AwsKms{ + client: kms.New(sess), + currentKeyID: masterKeyConfig.KeyId, + region: masterKeyConfig.Region, + endpoint: masterKeyConfig.Endpoint, + }, nil +} + +func (a *AwsKms) Name() string { + return EncryptionVendorNameAwsKms +} + +func (a *AwsKms) DecryptDataKey(ctx context.Context, dataKey []byte) ([]byte, error) { + input := &kms.DecryptInput{ + CiphertextBlob: dataKey, + KeyId: aws.String(a.currentKeyID), + } + + result, err := a.client.DecryptWithContext(ctx, input) + if err != nil { + return nil, classifyDecryptError(err) + } + + return result.Plaintext, nil +} + +func (a *AwsKms) Close() { + // don't need to do manual close +} + +// Update classifyDecryptError to use v1 SDK error types +func classifyDecryptError(err error) error { + switch err := err.(type) { + case *kms.NotFoundException, *kms.InvalidKeyUsageException: + return pErrors.Annotate(err, "wrong master key") + case *kms.DependencyTimeoutException: + return pErrors.Annotate(err, "API timeout") + case *kms.InternalException: + return pErrors.Annotate(err, "API internal error") + default: + return pErrors.Annotate(err, "KMS error") + } +} diff --git a/br/pkg/kms/common.go b/br/pkg/kms/common.go new file mode 100644 index 0000000000..13e2f63504 --- /dev/null +++ b/br/pkg/kms/common.go @@ -0,0 +1,65 @@ +// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0. 
+ +package kms + +import ( + "bytes" + + "github.com/pingcap/errors" +) + +// EncryptedKey is used to mark data as an encrypted key +type EncryptedKey []byte + +func NewEncryptedKey(key []byte) (EncryptedKey, error) { + if len(key) == 0 { + return nil, errors.New("encrypted key cannot be empty") + } + return key, nil +} + +// Equal method for EncryptedKey +func (e EncryptedKey) Equal(other *EncryptedKey) bool { + return bytes.Equal(e, *other) +} + +// CryptographyType represents different cryptography methods +type CryptographyType int + +const ( + CryptographyTypePlain CryptographyType = iota + CryptographyTypeAesGcm256 +) + +func (c CryptographyType) TargetKeySize() int { + switch c { + case CryptographyTypePlain: + return 0 // Plain text has no limitation + case CryptographyTypeAesGcm256: + return 32 + default: + return 0 + } +} + +// PlainKey is used to mark a byte slice as a plaintext key +type PlainKey struct { + tag CryptographyType + key []byte +} + +func NewPlainKey(key []byte, t CryptographyType) (*PlainKey, error) { + limitation := t.TargetKeySize() + if limitation > 0 && len(key) != limitation { + return nil, errors.Errorf("encryption method and key length mismatch, expect %d got %d", limitation, len(key)) + } + return &PlainKey{key: key, tag: t}, nil +} + +func (p *PlainKey) KeyTag() CryptographyType { + return p.tag +} + +func (p *PlainKey) Key() []byte { + return p.key +} diff --git a/br/pkg/kms/gcp.go b/br/pkg/kms/gcp.go new file mode 100644 index 0000000000..9ea5be306f --- /dev/null +++ b/br/pkg/kms/gcp.go @@ -0,0 +1,104 @@ +// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0. + +package kms + +import ( + "context" + "hash/crc32" + "strings" + + "cloud.google.com/go/kms/apiv1" + "cloud.google.com/go/kms/apiv1/kmspb" + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/encryptionpb" + "github.com/pingcap/log" + "go.uber.org/zap" + "google.golang.org/api/option" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +const ( + // need to keep it exactly same as TiKV STORAGE_VENDOR_NAME_GCP in TiKV + StorageVendorNameGcp = "gcp" +) + +type GcpKms struct { + config *encryptionpb.MasterKeyKms + // the location prefix of key id, + // format: projects/{project_name}/locations/{location} + location string + client *kms.KeyManagementClient +} + +func NewGcpKms(config *encryptionpb.MasterKeyKms) (*GcpKms, error) { + if config.GcpKms == nil { + return nil, errors.New("GCP config is missing") + } + + // config string pattern verified at parsing flag phase, we should have valid string at this stage. 
+ config.KeyId = strings.TrimSuffix(config.KeyId, "/") + + // join the first 4 parts of the key id to get the location + location := strings.Join(strings.Split(config.KeyId, "/")[:4], "/") + + ctx := context.Background() + + var clientOpt option.ClientOption + if config.GcpKms.Credential != "" { + clientOpt = option.WithCredentialsFile(config.GcpKms.Credential) + } + + client, err := kms.NewKeyManagementClient(ctx, clientOpt) + if err != nil { + return nil, errors.Errorf("failed to create GCP KMS client: %v", err) + } + + return &GcpKms{ + config: config, + location: location, + client: client, + }, nil +} + +func (g *GcpKms) Name() string { + return StorageVendorNameGcp +} + +func (g *GcpKms) DecryptDataKey(ctx context.Context, dataKey []byte) ([]byte, error) { + req := &kmspb.DecryptRequest{ + Name: g.config.KeyId, + Ciphertext: dataKey, + CiphertextCrc32C: wrapperspb.Int64(int64(g.calculateCRC32C(dataKey))), + } + + resp, err := g.client.Decrypt(ctx, req) + if err != nil { + return nil, errors.Annotate(err, "gcp kms decrypt request failed") + } + + if int64(g.calculateCRC32C(resp.Plaintext)) != resp.PlaintextCrc32C.Value { + return nil, errors.New("response corrupted in-transit") + } + + return resp.Plaintext, nil +} + +func (g *GcpKms) checkCRC32(data []byte, expected int64) error { + crc := int64(g.calculateCRC32C(data)) + if crc != expected { + return errors.Errorf("crc32c mismatch, expected: %d, got: %d", expected, crc) + } + return nil +} + +func (g *GcpKms) calculateCRC32C(data []byte) uint32 { + t := crc32.MakeTable(crc32.Castagnoli) + return crc32.Checksum(data, t) +} + +func (g *GcpKms) Close() { + err := g.client.Close() + if err != nil { + log.Error("failed to close gcp kms client", zap.Error(err)) + } +} diff --git a/br/pkg/kms/kms.go b/br/pkg/kms/kms.go new file mode 100644 index 0000000000..950749e851 --- /dev/null +++ b/br/pkg/kms/kms.go @@ -0,0 +1,13 @@ +// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0. + +package kms + +import "context" + +// Provider is an interface for key management service providers +// implement encrypt data key in future if needed +type Provider interface { + DecryptDataKey(ctx context.Context, dataKey []byte) ([]byte, error) + Name() string + Close() +} diff --git a/br/pkg/metautil/BUILD.bazel b/br/pkg/metautil/BUILD.bazel index 44a35bc2b2..a7008e6283 100644 --- a/br/pkg/metautil/BUILD.bazel +++ b/br/pkg/metautil/BUILD.bazel @@ -14,6 +14,7 @@ go_library( "//br/pkg/logutil", "//br/pkg/storage", "//br/pkg/summary", + "//br/pkg/utils", "//pkg/meta/model", "//pkg/statistics/handle", "//pkg/statistics/handle/types", @@ -47,6 +48,7 @@ go_test( shard_count = 9, deps = [ "//br/pkg/storage", + "//br/pkg/utils", "//pkg/meta/model", "//pkg/parser/model", "//pkg/statistics/handle/types", diff --git a/br/pkg/metautil/metafile.go b/br/pkg/metautil/metafile.go index 83542d2880..813f63058c 100644 --- a/br/pkg/metautil/metafile.go +++ b/br/pkg/metautil/metafile.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/tidb/br/pkg/logutil" "github.com/pingcap/tidb/br/pkg/storage" "github.com/pingcap/tidb/br/pkg/summary" + "github.com/pingcap/tidb/br/pkg/utils" "github.com/pingcap/tidb/pkg/meta/model" "github.com/pingcap/tidb/pkg/statistics/handle/util" "github.com/pingcap/tidb/pkg/tablecodec" @@ -87,22 +88,17 @@ func Encrypt(content []byte, cipher *backuppb.CipherInfo) (encryptedContent, iv } } -// Decrypt decrypts the content according to CipherInfo and IV. 
-func Decrypt(content []byte, cipher *backuppb.CipherInfo, iv []byte) ([]byte, error) { - if len(content) == 0 || cipher == nil { - return content, nil +func DecryptFullBackupMetaIfNeeded(metaData []byte, cipherInfo *backuppb.CipherInfo) ([]byte, error) { + if cipherInfo == nil || !utils.IsEffectiveEncryptionMethod(cipherInfo.CipherType) { + return metaData, nil } - - switch cipher.CipherType { - case encryptionpb.EncryptionMethod_PLAINTEXT: - return content, nil - case encryptionpb.EncryptionMethod_AES128_CTR, - encryptionpb.EncryptionMethod_AES192_CTR, - encryptionpb.EncryptionMethod_AES256_CTR: - return encrypt.AESDecryptWithCTR(content, cipher.CipherKey, iv) - default: - return content, errors.Annotate(berrors.ErrInvalidArgument, "cipher type invalid") + // the prefix of backup meta file is iv(16 bytes) for ctr mode if encryption method is valid + iv := metaData[:CrypterIvLen] + decryptBackupMeta, err := utils.Decrypt(metaData[len(iv):], cipherInfo, iv) + if err != nil { + return nil, errors.Annotate(err, "decrypt failed with wrong key") } + return decryptBackupMeta, nil } // walkLeafMetaFile walks the leaves of the given metafile, and deal with it by calling the function `output`. @@ -130,7 +126,7 @@ func walkLeafMetaFile( return errors.Trace(err) } - decryptContent, err := Decrypt(content, cipher, node.CipherIv) + decryptContent, err := utils.Decrypt(content, cipher, node.CipherIv) if err != nil { return errors.Trace(err) } diff --git a/br/pkg/metautil/metafile_test.go b/br/pkg/metautil/metafile_test.go index 4a010b4965..1cea0f86d8 100644 --- a/br/pkg/metautil/metafile_test.go +++ b/br/pkg/metautil/metafile_test.go @@ -11,6 +11,7 @@ import ( backuppb "github.com/pingcap/kvproto/pkg/brpb" "github.com/pingcap/kvproto/pkg/encryptionpb" "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/br/pkg/utils" "github.com/stretchr/testify/require" ) @@ -203,14 +204,14 @@ func TestEncryptAndDecrypt(t *testing.T) { require.NoError(t, err) require.Equal(t, originalData, encryptData) - decryptData, err := Decrypt(encryptData, &cipher, iv) + decryptData, err := utils.Decrypt(encryptData, &cipher, iv) require.NoError(t, err) require.Equal(t, decryptData, originalData) } else { require.NoError(t, err) require.NotEqual(t, originalData, encryptData) - decryptData, err := Decrypt(encryptData, &cipher, iv) + decryptData, err := utils.Decrypt(encryptData, &cipher, iv) require.NoError(t, err) require.Equal(t, decryptData, originalData) @@ -218,7 +219,7 @@ func TestEncryptAndDecrypt(t *testing.T) { CipherType: v.method, CipherKey: []byte(v.wrongKey), } - decryptData, err = Decrypt(encryptData, &wrongCipher, iv) + decryptData, err = utils.Decrypt(encryptData, &wrongCipher, iv) if len(v.rightKey) != len(v.wrongKey) { require.Error(t, err) } else { diff --git a/br/pkg/metautil/statsfile.go b/br/pkg/metautil/statsfile.go index 6621a87cb6..ddccdfe839 100644 --- a/br/pkg/metautil/statsfile.go +++ b/br/pkg/metautil/statsfile.go @@ -26,6 +26,7 @@ import ( backuppb "github.com/pingcap/kvproto/pkg/brpb" berrors "github.com/pingcap/tidb/br/pkg/errors" "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/br/pkg/utils" "github.com/pingcap/tidb/pkg/meta/model" "github.com/pingcap/tidb/pkg/statistics/handle" statstypes "github.com/pingcap/tidb/pkg/statistics/handle/types" @@ -201,7 +202,7 @@ func downloadStats( return errors.Trace(err) } - decryptContent, err := Decrypt(content, cipher, statsFile.CipherIv) + decryptContent, err := utils.Decrypt(content, cipher, statsFile.CipherIv) if err != nil 
{ return errors.Trace(err) } diff --git a/br/pkg/restore/log_client/BUILD.bazel b/br/pkg/restore/log_client/BUILD.bazel index 76a2f7de49..5975b0726a 100644 --- a/br/pkg/restore/log_client/BUILD.bazel +++ b/br/pkg/restore/log_client/BUILD.bazel @@ -16,6 +16,7 @@ go_library( "//br/pkg/checksum", "//br/pkg/conn", "//br/pkg/conn/util", + "//br/pkg/encryption", "//br/pkg/errors", "//br/pkg/glue", "//br/pkg/logutil", @@ -49,6 +50,7 @@ go_library( "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", "@com_github_pingcap_kvproto//pkg/brpb", + "@com_github_pingcap_kvproto//pkg/encryptionpb", "@com_github_pingcap_kvproto//pkg/errorpb", "@com_github_pingcap_kvproto//pkg/import_sstpb", "@com_github_pingcap_kvproto//pkg/kvrpcpb", @@ -113,6 +115,7 @@ go_test( "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", "@com_github_pingcap_kvproto//pkg/brpb", + "@com_github_pingcap_kvproto//pkg/encryptionpb", "@com_github_pingcap_kvproto//pkg/errorpb", "@com_github_pingcap_kvproto//pkg/import_sstpb", "@com_github_pingcap_kvproto//pkg/metapb", diff --git a/br/pkg/restore/log_client/client.go b/br/pkg/restore/log_client/client.go index 05aa19ace1..da021fe1b0 100644 --- a/br/pkg/restore/log_client/client.go +++ b/br/pkg/restore/log_client/client.go @@ -33,11 +33,13 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/encryptionpb" "github.com/pingcap/log" "github.com/pingcap/tidb/br/pkg/checkpoint" "github.com/pingcap/tidb/br/pkg/checksum" "github.com/pingcap/tidb/br/pkg/conn" "github.com/pingcap/tidb/br/pkg/conn/util" + "github.com/pingcap/tidb/br/pkg/encryption" "github.com/pingcap/tidb/br/pkg/glue" "github.com/pingcap/tidb/br/pkg/logutil" "github.com/pingcap/tidb/br/pkg/metautil" @@ -293,13 +295,15 @@ func (rc *LogClient) InitCheckpointMetadataForLogRestore( return gcRatio, nil } -func (rc *LogClient) InstallLogFileManager(ctx context.Context, startTS, restoreTS uint64, metadataDownloadBatchSize uint) error { +func (rc *LogClient) InstallLogFileManager(ctx context.Context, startTS, restoreTS uint64, metadataDownloadBatchSize uint, + encryptionManager *encryption.Manager) error { init := LogFileManagerInit{ StartTS: startTS, RestoreTS: restoreTS, Storage: rc.storage, MetadataDownloadBatchSize: metadataDownloadBatchSize, + EncryptionManager: encryptionManager, } var err error rc.LogFileManager, err = CreateLogFileManager(ctx, init) @@ -436,7 +440,7 @@ func ApplyKVFilesWithBatchMethod( return nil } -func ApplyKVFilesWithSingelMethod( +func ApplyKVFilesWithSingleMethod( ctx context.Context, files LogIter, applyFunc func(file []*LogDataFileInfo, kvCount int64, size uint64), @@ -477,6 +481,8 @@ func (rc *LogClient) RestoreKVFiles( pitrBatchSize uint32, updateStats func(kvCount uint64, size uint64), onProgress func(cnt int64), + cipherInfo *backuppb.CipherInfo, + masterKeys []*encryptionpb.MasterKey, ) error { var ( err error @@ -488,7 +494,7 @@ func (rc *LogClient) RestoreKVFiles( defer func() { if err == nil { elapsed := time.Since(start) - log.Info("Restore KV files", zap.Duration("take", elapsed)) + log.Info("Restored KV files", zap.Duration("take", elapsed)) summary.CollectSuccessUnit("files", fileCount, elapsed) } }() @@ -548,7 +554,8 @@ func (rc *LogClient) RestoreKVFiles( } }() - return rc.fileImporter.ImportKVFiles(ectx, files, rule, rc.shiftStartTS, rc.startTS, rc.restoreTS, supportBatch) + return rc.fileImporter.ImportKVFiles(ectx, files, rule, 
rc.shiftStartTS, rc.startTS, rc.restoreTS, + supportBatch, cipherInfo, masterKeys) }) } } @@ -557,7 +564,7 @@ func (rc *LogClient) RestoreKVFiles( if supportBatch { err = ApplyKVFilesWithBatchMethod(ectx, logIter, int(pitrBatchCount), uint64(pitrBatchSize), applyFunc, &applyWg) } else { - err = ApplyKVFilesWithSingelMethod(ectx, logIter, applyFunc, &applyWg) + err = ApplyKVFilesWithSingleMethod(ectx, logIter, applyFunc, &applyWg) } return errors.Trace(err) }) @@ -619,18 +626,25 @@ func initFullBackupTables( ctx context.Context, s storage.ExternalStorage, tableFilter filter.Filter, + cipherInfo *backuppb.CipherInfo, ) (map[int64]*metautil.Table, error) { metaData, err := s.ReadFile(ctx, metautil.MetaFile) if err != nil { return nil, errors.Trace(err) } + + backupMetaBytes, err := metautil.DecryptFullBackupMetaIfNeeded(metaData, cipherInfo) + if err != nil { + return nil, errors.Annotate(err, "decrypt failed with wrong key") + } + backupMeta := &backuppb.BackupMeta{} - if err = backupMeta.Unmarshal(metaData); err != nil { + if err = backupMeta.Unmarshal(backupMetaBytes); err != nil { return nil, errors.Trace(err) } // read full backup databases to get map[table]table.Info - reader := metautil.NewMetaReader(backupMeta, s, nil) + reader := metautil.NewMetaReader(backupMeta, s, cipherInfo) databases, err := metautil.LoadBackupTables(ctx, reader, false) if err != nil { @@ -684,6 +698,7 @@ const UnsafePITRLogRestoreStartBeforeAnyUpstreamUserDDL = "UNSAFE_PITR_LOG_RESTO func (rc *LogClient) generateDBReplacesFromFullBackupStorage( ctx context.Context, cfg *InitSchemaConfig, + cipherInfo *backuppb.CipherInfo, ) (map[stream.UpstreamID]*stream.DBReplace, error) { dbReplaces := make(map[stream.UpstreamID]*stream.DBReplace) if cfg.FullBackupStorage == nil { @@ -698,7 +713,7 @@ func (rc *LogClient) generateDBReplacesFromFullBackupStorage( if err != nil { return nil, errors.Trace(err) } - fullBackupTables, err := initFullBackupTables(ctx, s, cfg.TableFilter) + fullBackupTables, err := initFullBackupTables(ctx, s, cfg.TableFilter, cipherInfo) if err != nil { return nil, errors.Trace(err) } @@ -741,6 +756,7 @@ func (rc *LogClient) generateDBReplacesFromFullBackupStorage( func (rc *LogClient) InitSchemasReplaceForDDL( ctx context.Context, cfg *InitSchemaConfig, + cipherInfo *backuppb.CipherInfo, ) (*stream.SchemasReplace, error) { var ( err error @@ -785,7 +801,7 @@ func (rc *LogClient) InitSchemasReplaceForDDL( if len(dbMaps) <= 0 { log.Info("no id maps, build the table replaces from cluster and full backup schemas") needConstructIdMap = true - dbReplaces, err = rc.generateDBReplacesFromFullBackupStorage(ctx, cfg) + dbReplaces, err = rc.generateDBReplacesFromFullBackupStorage(ctx, cfg, cipherInfo) if err != nil { return nil, errors.Trace(err) } diff --git a/br/pkg/restore/log_client/client_test.go b/br/pkg/restore/log_client/client_test.go index 7b00e30e6e..0b1baf7af5 100644 --- a/br/pkg/restore/log_client/client_test.go +++ b/br/pkg/restore/log_client/client_test.go @@ -866,7 +866,7 @@ func TestApplyKVFilesWithSingelMethod(t *testing.T) { } } - logclient.ApplyKVFilesWithSingelMethod( + logclient.ApplyKVFilesWithSingleMethod( context.TODO(), toLogDataFileInfoIter(iter.FromSlice(ds)), applyFunc, @@ -1293,7 +1293,7 @@ func TestApplyKVFilesWithBatchMethod5(t *testing.T) { require.Equal(t, backuppb.FileType_Delete, types[len(types)-1]) types = make([]backuppb.FileType, 0) - logclient.ApplyKVFilesWithSingelMethod( + logclient.ApplyKVFilesWithSingleMethod( context.TODO(), 
toLogDataFileInfoIter(iter.FromSlice(ds)), applyFunc, @@ -1377,7 +1377,7 @@ func TestInitSchemasReplaceForDDL(t *testing.T) { { client := logclient.TEST_NewLogClient(123, 1, 2, 1, domain.NewMockDomain(), fakeSession{}) cfg := &logclient.InitSchemaConfig{IsNewTask: false} - _, err := client.InitSchemasReplaceForDDL(ctx, cfg) + _, err := client.InitSchemasReplaceForDDL(ctx, cfg, nil) require.Error(t, err) require.Regexp(t, "failed to get pitr id map from mysql.tidb_pitr_id_map.* [2, 1]", err.Error()) } @@ -1385,7 +1385,7 @@ func TestInitSchemasReplaceForDDL(t *testing.T) { { client := logclient.TEST_NewLogClient(123, 1, 2, 1, domain.NewMockDomain(), fakeSession{}) cfg := &logclient.InitSchemaConfig{IsNewTask: true} - _, err := client.InitSchemasReplaceForDDL(ctx, cfg) + _, err := client.InitSchemasReplaceForDDL(ctx, cfg, nil) require.Error(t, err) require.Regexp(t, "failed to get pitr id map from mysql.tidb_pitr_id_map.* [1, 1]", err.Error()) } @@ -1399,7 +1399,7 @@ func TestInitSchemasReplaceForDDL(t *testing.T) { require.NoError(t, err) client := logclient.TEST_NewLogClient(123, 1, 2, 1, domain.NewMockDomain(), se) cfg := &logclient.InitSchemaConfig{IsNewTask: true} - _, err = client.InitSchemasReplaceForDDL(ctx, cfg) + _, err = client.InitSchemasReplaceForDDL(ctx, cfg, nil) require.Error(t, err) require.Contains(t, err.Error(), "miss upstream table information at `start-ts`(1) but the full backup path is not specified") } diff --git a/br/pkg/restore/log_client/export_test.go b/br/pkg/restore/log_client/export_test.go index db1324d61f..9a35b35e8e 100644 --- a/br/pkg/restore/log_client/export_test.go +++ b/br/pkg/restore/log_client/export_test.go @@ -19,6 +19,7 @@ import ( "github.com/pingcap/errors" backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/encryptionpb" "github.com/pingcap/tidb/br/pkg/glue" "github.com/pingcap/tidb/br/pkg/storage" "github.com/pingcap/tidb/br/pkg/stream" @@ -90,6 +91,7 @@ func (helper *FakeStreamMetadataHelper) ReadFile( length uint64, compressionType backuppb.CompressionType, storage storage.ExternalStorage, + encryptionInfo *encryptionpb.FileEncryptionInfo, ) ([]byte, error) { return helper.Data[offset : offset+length], nil } diff --git a/br/pkg/restore/log_client/import.go b/br/pkg/restore/log_client/import.go index 41abe7fe5c..138b89d243 100644 --- a/br/pkg/restore/log_client/import.go +++ b/br/pkg/restore/log_client/import.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/errors" backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/encryptionpb" "github.com/pingcap/kvproto/pkg/import_sstpb" "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/metapb" @@ -101,6 +102,8 @@ func (importer *LogFileImporter) ImportKVFiles( startTS uint64, restoreTS uint64, supportBatch bool, + cipherInfo *backuppb.CipherInfo, + masterKeys []*encryptionpb.MasterKey, ) error { var ( startKey []byte @@ -111,7 +114,7 @@ func (importer *LogFileImporter) ImportKVFiles( if !supportBatch && len(files) > 1 { return errors.Annotatef(berrors.ErrInvalidArgument, - "do not support batch apply but files count:%v > 1", len(files)) + "do not support batch apply, file count: %v > 1", len(files)) } log.Debug("import kv files", zap.Int("batch file count", len(files))) @@ -143,7 +146,8 @@ func (importer *LogFileImporter) ImportKVFiles( if len(subfiles) == 0 { return RPCResultOK() } - return importer.importKVFileForRegion(ctx, subfiles, rule, shiftStartTS, startTS, restoreTS, r, supportBatch) + return 
importer.importKVFileForRegion(ctx, subfiles, rule, shiftStartTS, startTS, restoreTS, r, supportBatch, + cipherInfo, masterKeys) }) return errors.Trace(err) } @@ -184,9 +188,11 @@ func (importer *LogFileImporter) importKVFileForRegion( restoreTS uint64, info *split.RegionInfo, supportBatch bool, + cipherInfo *backuppb.CipherInfo, + masterKeys []*encryptionpb.MasterKey, ) RPCResult { // Try to download file. - result := importer.downloadAndApplyKVFile(ctx, files, rule, info, shiftStartTS, startTS, restoreTS, supportBatch) + result := importer.downloadAndApplyKVFile(ctx, files, rule, info, shiftStartTS, startTS, restoreTS, supportBatch, cipherInfo, masterKeys) if !result.OK() { errDownload := result.Err for _, e := range multierr.Errors(errDownload) { @@ -216,7 +222,8 @@ func (importer *LogFileImporter) downloadAndApplyKVFile( startTS uint64, restoreTS uint64, supportBatch bool, -) RPCResult { + cipherInfo *backuppb.CipherInfo, + masterKeys []*encryptionpb.MasterKey) RPCResult { leader := regionInfo.Leader if leader == nil { return RPCResultFromError(errors.Annotatef(berrors.ErrPDLeaderNotFound, @@ -251,11 +258,12 @@ func (importer *LogFileImporter) downloadAndApplyKVFile( } return startTS }(), - RestoreTs: restoreTS, - StartKey: regionInfo.Region.GetStartKey(), - EndKey: regionInfo.Region.GetEndKey(), - Sha256: file.GetSha256(), - CompressionType: file.CompressionType, + RestoreTs: restoreTS, + StartKey: regionInfo.Region.GetStartKey(), + EndKey: regionInfo.Region.GetEndKey(), + Sha256: file.GetSha256(), + CompressionType: file.CompressionType, + FileEncryptionInfo: file.FileEncryptionInfo, } metas = append(metas, meta) @@ -276,6 +284,8 @@ func (importer *LogFileImporter) downloadAndApplyKVFile( RewriteRules: rewriteRules, Context: reqCtx, StorageCacheId: importer.cacheKey, + CipherInfo: cipherInfo, + MasterKeys: masterKeys, } } else { req = &import_sstpb.ApplyRequest{ @@ -284,16 +294,18 @@ func (importer *LogFileImporter) downloadAndApplyKVFile( RewriteRule: *rewriteRules[0], Context: reqCtx, StorageCacheId: importer.cacheKey, + CipherInfo: cipherInfo, + MasterKeys: masterKeys, } } - log.Debug("apply kv file", logutil.Leader(leader)) + log.Debug("applying kv file", logutil.Leader(leader)) resp, err := importer.importClient.ApplyKVFile(ctx, leader.GetStoreId(), req) if err != nil { return RPCResultFromError(errors.Trace(err)) } if resp.GetError() != nil { - logutil.CL(ctx).Warn("import meet error", zap.Stringer("error", resp.GetError())) + logutil.CL(ctx).Warn("import has error", zap.Stringer("error", resp.GetError())) return RPCResultFromPBError(resp.GetError()) } return RPCResultOK() diff --git a/br/pkg/restore/log_client/import_test.go b/br/pkg/restore/log_client/import_test.go index af5add181d..b9852c45b7 100644 --- a/br/pkg/restore/log_client/import_test.go +++ b/br/pkg/restore/log_client/import_test.go @@ -60,6 +60,7 @@ func TestImportKVFiles(t *testing.T) { startTS, restoreTS, false, + nil, nil, ) require.True(t, berrors.ErrInvalidArgument.Equal(err)) } @@ -268,9 +269,9 @@ func TestFileImporter(t *testing.T) { require.NoError(t, err) rewriteRules, encodeKeyFiles := prepareData() - err = importer.ImportKVFiles(ctx, encodeKeyFiles, rewriteRules, 1, 1, 1, true) + err = importer.ImportKVFiles(ctx, encodeKeyFiles, rewriteRules, 1, 1, 1, true, nil, nil) require.NoError(t, err) - err = importer.ImportKVFiles(ctx, encodeKeyFiles, rewriteRules, 1, 1, 1, false) + err = importer.ImportKVFiles(ctx, encodeKeyFiles, rewriteRules, 1, 1, 1, false, nil, nil) require.NoError(t, err) } diff --git 
a/br/pkg/restore/log_client/log_file_manager.go b/br/pkg/restore/log_client/log_file_manager.go index d29b4f1f7a..81af10cf54 100644 --- a/br/pkg/restore/log_client/log_file_manager.go +++ b/br/pkg/restore/log_client/log_file_manager.go @@ -12,7 +12,9 @@ import ( "github.com/pingcap/errors" backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/encryptionpb" "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/encryption" berrors "github.com/pingcap/tidb/br/pkg/errors" "github.com/pingcap/tidb/br/pkg/storage" "github.com/pingcap/tidb/br/pkg/stream" @@ -54,6 +56,7 @@ type streamMetadataHelper interface { length uint64, compressionType backuppb.CompressionType, storage storage.ExternalStorage, + encryptionInfo *encryptionpb.FileEncryptionInfo, ) ([]byte, error) ParseToMetadata(rawMetaData []byte) (*backuppb.Metadata, error) } @@ -85,6 +88,7 @@ type LogFileManagerInit struct { Storage storage.ExternalStorage MetadataDownloadBatchSize uint + EncryptionManager *encryption.Manager } type DDLMetaGroup struct { @@ -99,7 +103,7 @@ func CreateLogFileManager(ctx context.Context, init LogFileManagerInit) (*LogFil startTS: init.StartTS, restoreTS: init.RestoreTS, storage: init.Storage, - helper: stream.NewMetadataHelper(), + helper: stream.NewMetadataHelper(stream.WithEncryptionManager(init.EncryptionManager)), metadataDownloadBatchSize: init.MetadataDownloadBatchSize, } @@ -329,7 +333,8 @@ func (rc *LogFileManager) ReadAllEntries( kvEntries := make([]*KvEntryWithTS, 0) nextKvEntries := make([]*KvEntryWithTS, 0) - buff, err := rc.helper.ReadFile(ctx, file.Path, file.RangeOffset, file.RangeLength, file.CompressionType, rc.storage) + buff, err := rc.helper.ReadFile(ctx, file.Path, file.RangeOffset, file.RangeLength, file.CompressionType, + rc.storage, file.FileEncryptionInfo) if err != nil { return nil, nil, errors.Trace(err) } diff --git a/br/pkg/restore/snap_client/client.go b/br/pkg/restore/snap_client/client.go index 50712e2308..22f47524d7 100644 --- a/br/pkg/restore/snap_client/client.go +++ b/br/pkg/restore/snap_client/client.go @@ -647,7 +647,7 @@ func (rc *SnapClient) CreateDatabases(ctx context.Context, dbs []*metautil.Datab return nil } - log.Info("create databases in db pool", zap.Int("pool size", len(rc.dbPool))) + log.Info("create databases in db pool", zap.Int("pool size", len(rc.dbPool)), zap.Int("number of db", len(dbs))) eg, ectx := errgroup.WithContext(ctx) workers := tidbutil.NewWorkerPool(uint(len(rc.dbPool)), "DB DDL workers") for _, db_ := range dbs { diff --git a/br/pkg/stream/BUILD.bazel b/br/pkg/stream/BUILD.bazel index aff6b7ee2a..f2a38a769c 100644 --- a/br/pkg/stream/BUILD.bazel +++ b/br/pkg/stream/BUILD.bazel @@ -15,6 +15,7 @@ go_library( importpath = "github.com/pingcap/tidb/br/pkg/stream", visibility = ["//visibility:public"], deps = [ + "//br/pkg/encryption", "//br/pkg/errors", "//br/pkg/glue", "//br/pkg/httputil", @@ -36,6 +37,7 @@ go_library( "@com_github_klauspost_compress//zstd", "@com_github_pingcap_errors//:errors", "@com_github_pingcap_kvproto//pkg/brpb", + "@com_github_pingcap_kvproto//pkg/encryptionpb", "@com_github_pingcap_kvproto//pkg/metapb", "@com_github_pingcap_log//:log", "@com_github_tikv_client_go_v2//oracle", diff --git a/br/pkg/stream/stream_mgr.go b/br/pkg/stream/stream_mgr.go index d53a66b7f1..7987112846 100644 --- a/br/pkg/stream/stream_mgr.go +++ b/br/pkg/stream/stream_mgr.go @@ -16,12 +16,16 @@ package stream import ( "context" + "crypto/sha256" + "encoding/hex" "strings" "github.com/klauspost/compress/zstd" 
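+	// NOTE (editorial): the crypto/sha256 and encoding/hex imports added above
+	// back the checksum verification that now runs before metadata decryption;
+	// see verifyChecksumAndDecryptIfNeeded below.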
"github.com/pingcap/errors" backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/encryptionpb" "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/encryption" "github.com/pingcap/tidb/br/pkg/storage" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/meta" @@ -157,16 +161,31 @@ type ContentRef struct { // MetadataHelper make restore/truncate compatible with metadataV1 and metadataV2. type MetadataHelper struct { - cache map[string]*ContentRef - decoder *zstd.Decoder + cache map[string]*ContentRef + decoder *zstd.Decoder + encryptionManager *encryption.Manager } -func NewMetadataHelper() *MetadataHelper { +type MetadataHelperOption func(*MetadataHelper) + +func WithEncryptionManager(manager *encryption.Manager) MetadataHelperOption { + return func(mh *MetadataHelper) { + mh.encryptionManager = manager + } +} + +func NewMetadataHelper(opts ...MetadataHelperOption) *MetadataHelper { decoder, _ := zstd.NewReader(nil) - return &MetadataHelper{ + helper := &MetadataHelper{ cache: make(map[string]*ContentRef), decoder: decoder, } + + for _, opt := range opts { + opt(helper) + } + + return helper } func (m *MetadataHelper) InitCacheEntry(path string, ref int) { @@ -191,6 +210,35 @@ func (m *MetadataHelper) decodeCompressedData(data []byte, compressionType backu "failed to decode compressed data: compression type is unimplemented. type id is %d", compressionType) } +func (m *MetadataHelper) verifyChecksumAndDecryptIfNeeded(ctx context.Context, data []byte, + encryptionInfo *encryptionpb.FileEncryptionInfo) ([]byte, error) { + // no need to decrypt + if encryptionInfo == nil { + return data, nil + } + + if m.encryptionManager == nil { + return nil, errors.New("need to decrypt data but encryption manager not set") + } + + // Verify checksum before decryption + if encryptionInfo.Checksum != nil { + actualChecksum := sha256.Sum256(data) + expectedChecksumHex := hex.EncodeToString(encryptionInfo.Checksum) + actualChecksumHex := hex.EncodeToString(actualChecksum[:]) + if expectedChecksumHex != actualChecksumHex { + return nil, errors.Errorf("checksum mismatch before decryption, expected %s, actual %s", + expectedChecksumHex, actualChecksumHex) + } + } + + decryptedContent, err := m.encryptionManager.Decrypt(ctx, data, encryptionInfo) + if err != nil { + return nil, errors.Trace(err) + } + return decryptedContent, nil +} + func (m *MetadataHelper) ReadFile( ctx context.Context, path string, @@ -198,6 +246,7 @@ func (m *MetadataHelper) ReadFile( length uint64, compressionType backuppb.CompressionType, storage storage.ExternalStorage, + encryptionInfo *encryptionpb.FileEncryptionInfo, ) ([]byte, error) { var err error cref, exist := m.cache[path] @@ -212,7 +261,12 @@ func (m *MetadataHelper) ReadFile( if err != nil { return nil, errors.Trace(err) } - return m.decodeCompressedData(data, compressionType) + // decrypt if needed + decryptedData, err := m.verifyChecksumAndDecryptIfNeeded(ctx, data, encryptionInfo) + if err != nil { + return nil, errors.Trace(err) + } + return m.decodeCompressedData(decryptedData, compressionType) } cref.ref -= 1 @@ -223,8 +277,12 @@ func (m *MetadataHelper) ReadFile( return nil, errors.Trace(err) } } - - buf, err := m.decodeCompressedData(cref.data[offset:offset+length], compressionType) + // decrypt if needed + decryptedData, err := m.verifyChecksumAndDecryptIfNeeded(ctx, cref.data[offset:offset+length], encryptionInfo) + if err != nil { + return nil, errors.Trace(err) + } + buf, err := m.decodeCompressedData(decryptedData, 
compressionType) if cref.ref <= 0 { // need reset reference information. diff --git a/br/pkg/stream/stream_misc_test.go b/br/pkg/stream/stream_misc_test.go index 2de4784a5f..3f4efb3fc5 100644 --- a/br/pkg/stream/stream_misc_test.go +++ b/br/pkg/stream/stream_misc_test.go @@ -66,14 +66,14 @@ func TestMetadataHelperReadFile(t *testing.T) { require.NoError(t, err) helper.InitCacheEntry(filename2, 2) - get_data, err := helper.ReadFile(ctx, filename1, 0, 0, backuppb.CompressionType_UNKNOWN, s) + get_data, err := helper.ReadFile(ctx, filename1, 0, 0, backuppb.CompressionType_UNKNOWN, s, nil) require.NoError(t, err) require.Equal(t, data1, get_data) - get_data, err = helper.ReadFile(ctx, filename2, 0, uint64(len(data1)), backuppb.CompressionType_UNKNOWN, s) + get_data, err = helper.ReadFile(ctx, filename2, 0, uint64(len(data1)), backuppb.CompressionType_UNKNOWN, s, nil) require.NoError(t, err) require.Equal(t, data1, get_data) get_data, err = helper.ReadFile(ctx, filename2, uint64(len(data1)), uint64(len(data2)), - backuppb.CompressionType_ZSTD, s) + backuppb.CompressionType_ZSTD, s, nil) require.NoError(t, err) require.Equal(t, data1, get_data) } diff --git a/br/pkg/task/BUILD.bazel b/br/pkg/task/BUILD.bazel index e62938e7fa..0dabd8646f 100644 --- a/br/pkg/task/BUILD.bazel +++ b/br/pkg/task/BUILD.bazel @@ -8,6 +8,7 @@ go_library( "backup_raw.go", "backup_txn.go", "common.go", + "encryption.go", "restore.go", "restore_data.go", "restore_ebs_meta.go", @@ -27,6 +28,7 @@ go_library( "//br/pkg/config", "//br/pkg/conn", "//br/pkg/conn/util", + "//br/pkg/encryption", "//br/pkg/errors", "//br/pkg/glue", "//br/pkg/httputil", @@ -106,12 +108,13 @@ go_test( "backup_test.go", "common_test.go", "config_test.go", + "encryption_test.go", "restore_test.go", "stream_test.go", ], embed = [":task"], flaky = True, - shard_count = 33, + shard_count = 38, deps = [ "//br/pkg/backup", "//br/pkg/config", diff --git a/br/pkg/task/backup.go b/br/pkg/task/backup.go index a04f14f8b5..ab77b59bdc 100644 --- a/br/pkg/task/backup.go +++ b/br/pkg/task/backup.go @@ -72,7 +72,6 @@ const ( TableBackupCmd = "Table Backup" RawBackupCmd = "Raw Backup" TxnBackupCmd = "Txn Backup" - EBSBackupCmd = "EBS Backup" ) // CompressionConfig is the configuration for sst file compression. diff --git a/br/pkg/task/common.go b/br/pkg/task/common.go index 94ef67364c..93f9b3a08c 100644 --- a/br/pkg/task/common.go +++ b/br/pkg/task/common.go @@ -91,9 +91,13 @@ const ( defaultGRPCKeepaliveTimeout = 3 * time.Second defaultCloudAPIConcurrency = 8 - flagCipherType = "crypter.method" - flagCipherKey = "crypter.key" - flagCipherKeyFile = "crypter.key-file" + flagFullBackupCipherType = "crypter.method" + flagFullBackupCipherKey = "crypter.key" + flagFullBackupCipherKeyFile = "crypter.key-file" + + flagLogBackupCipherType = "log.crypter.method" + flagLogBackupCipherKey = "log.crypter.key" + flagLogBackupCipherKeyFile = "log.crypter.key-file" flagMetadataDownloadBatchSize = "metadata-download-batch-size" defaultMetadataDownloadBatchSize = 128 @@ -104,6 +108,10 @@ const ( crypterAES256KeyLen = 32 flagFullBackupType = "type" + + masterKeysDelimiter = "," + flagMasterKeyConfig = "master-key" + flagMasterKeyCipherType = "master-key-crypter-method" ) const ( @@ -260,8 +268,17 @@ type Config struct { // GrpcKeepaliveTimeout is the max time a grpc conn can keep idel before killed. GRPCKeepaliveTimeout time.Duration `json:"grpc-keepalive-timeout" toml:"grpc-keepalive-timeout"` + // Plaintext data key mainly used for full/snapshot backup and restore. 
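+	// Editorial example (hypothetical values, not part of this patch): the key is
+	// supplied through the crypter.* flags defined below, e.g. for AES-256-CTR a
+	// 32-byte hex key:
+	//   br backup full -s "s3://bucket/prefix" --crypter.method aes256-ctr \
+	//     --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef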
CipherInfo backuppb.CipherInfo `json:"-" toml:"-"` + // Could be used in log backup and restore but not recommended in a serious environment since data key is stored + // in PD in plaintext. + LogBackupCipherInfo backuppb.CipherInfo `json:"-" toml:"-"` + + // Master key based encryption for log restore. + // More than one can be specified for log restore if master key rotated during log backup. + MasterKeyConfig backuppb.MasterKeyConfig `json:"master-key-config" toml:"master-key-config"` + // whether there's explicit filter ExplicitFilter bool `json:"-" toml:"-"` @@ -310,17 +327,34 @@ func DefineCommonFlags(flags *pflag.FlagSet) { flags.BoolP(flagSkipCheckPath, "", false, "Skip path verification") _ = flags.MarkHidden(flagSkipCheckPath) - flags.String(flagCipherType, "plaintext", "Encrypt/decrypt method, "+ + flags.String(flagFullBackupCipherType, "plaintext", "Encrypt/decrypt method, "+ "be one of plaintext|aes128-ctr|aes192-ctr|aes256-ctr case-insensitively, "+ "\"plaintext\" represents no encrypt/decrypt") - flags.String(flagCipherKey, "", + flags.String(flagFullBackupCipherKey, "", "aes-crypter key, used to encrypt/decrypt the data "+ "by the hexadecimal string, eg: \"0123456789abcdef0123456789abcdef\"") - flags.String(flagCipherKeyFile, "", "FilePath, its content is used as the cipher-key") + flags.String(flagFullBackupCipherKeyFile, "", "FilePath, its content is used as the cipher-key") flags.Uint(flagMetadataDownloadBatchSize, defaultMetadataDownloadBatchSize, "the batch size of downloading metadata, such as log restore metadata for truncate or restore") + // log backup plaintext key flags + flags.String(flagLogBackupCipherType, "plaintext", "Encrypt/decrypt method, "+ + "be one of plaintext|aes128-ctr|aes192-ctr|aes256-ctr case-insensitively, "+ + "\"plaintext\" represents no encrypt/decrypt") + flags.String(flagLogBackupCipherKey, "", + "aes-crypter key, used to encrypt/decrypt the data "+ + "by the hexadecimal string, eg: \"0123456789abcdef0123456789abcdef\"") + flags.String(flagLogBackupCipherKeyFile, "", "FilePath, its content is used as the cipher-key") + + // master key config + flags.String(flagMasterKeyCipherType, "plaintext", "Encrypt/decrypt method, "+ + "be one of plaintext|aes128-ctr|aes192-ctr|aes256-ctr case-insensitively, "+ + "\"plaintext\" represents no encrypt/decrypt") + flags.String(flagMasterKeyConfig, "", "Master key config for point in time restore "+ + "examples: \"local:///path/to/master/key/file,"+ + "aws-kms:///{key-id}?AWS_ACCESS_KEY_ID={access-key}&AWS_SECRET_ACCESS_KEY={secret-key}®ION={region},"+ + "gcp-kms:///projects/{project-id}/locations/{location}/keyRings/{keyring}/cryptoKeys/{key-name}?AUTH=specified&CREDENTIALS={credentials}\"") _ = flags.MarkHidden(flagMetadataDownloadBatchSize) storage.DefineFlags(flags) @@ -334,10 +368,15 @@ func HiddenFlagsForStream(flags *pflag.FlagSet) { _ = flags.MarkHidden(flagRateLimit) _ = flags.MarkHidden(flagRateLimitUnit) _ = flags.MarkHidden(flagRemoveTiFlash) - _ = flags.MarkHidden(flagCipherType) - _ = flags.MarkHidden(flagCipherKey) - _ = flags.MarkHidden(flagCipherKeyFile) + _ = flags.MarkHidden(flagFullBackupCipherType) + _ = flags.MarkHidden(flagFullBackupCipherKey) + _ = flags.MarkHidden(flagFullBackupCipherKeyFile) + _ = flags.MarkHidden(flagLogBackupCipherType) + _ = flags.MarkHidden(flagLogBackupCipherKey) + _ = flags.MarkHidden(flagLogBackupCipherKeyFile) _ = flags.MarkHidden(flagSwitchModeInterval) + _ = flags.MarkHidden(flagMasterKeyConfig) + _ = flags.MarkHidden(flagMasterKeyCipherType) 
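+	// Editorial sketch (hypothetical paths): hidden flags remain settable, so a
+	// master-key encrypted log backup task can still be started with:
+	//   br log start --task-name pitr -s "s3://bucket/prefix" \
+	//     --master-key "local:///mnt/keys/master.key" \
+	//     --master-key-crypter-method aes256-ctr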
 	storage.HiddenFlagsForStream(flags)
 }
 
@@ -456,7 +495,7 @@ func checkCipherKeyMatch(cipher *backuppb.CipherInfo) bool {
 }
 
 func (cfg *Config) parseCipherInfo(flags *pflag.FlagSet) error {
-	crypterStr, err := flags.GetString(flagCipherType)
+	crypterStr, err := flags.GetString(flagFullBackupCipherType)
 	if err != nil {
 		return errors.Trace(err)
 	}
@@ -470,12 +509,12 @@ func (cfg *Config) parseCipherInfo(flags *pflag.FlagSet) error {
 		return nil
 	}
 
-	key, err := flags.GetString(flagCipherKey)
+	key, err := flags.GetString(flagFullBackupCipherKey)
 	if err != nil {
 		return errors.Trace(err)
 	}
 
-	keyFilePath, err := flags.GetString(flagCipherKeyFile)
+	keyFilePath, err := flags.GetString(flagFullBackupCipherKeyFile)
 	if err != nil {
 		return errors.Trace(err)
 	}
@@ -492,6 +531,43 @@ func (cfg *Config) parseCipherInfo(flags *pflag.FlagSet) error {
 	return nil
 }
 
+func (cfg *Config) parseLogBackupCipherInfo(flags *pflag.FlagSet) (bool, error) {
+	crypterStr, err := flags.GetString(flagLogBackupCipherType)
+	if err != nil {
+		return false, errors.Trace(err)
+	}
+
+	cfg.LogBackupCipherInfo.CipherType, err = parseCipherType(crypterStr)
+	if err != nil {
+		return false, errors.Trace(err)
+	}
+
+	if !utils.IsEffectiveEncryptionMethod(cfg.LogBackupCipherInfo.CipherType) {
+		return false, nil
+	}
+
+	key, err := flags.GetString(flagLogBackupCipherKey)
+	if err != nil {
+		return false, errors.Trace(err)
+	}
+
+	keyFilePath, err := flags.GetString(flagLogBackupCipherKeyFile)
+	if err != nil {
+		return false, errors.Trace(err)
+	}
+
+	cfg.LogBackupCipherInfo.CipherKey, err = GetCipherKeyContent(key, keyFilePath)
+	if err != nil {
+		return false, errors.Trace(err)
+	}
+
+	if !checkCipherKeyMatch(&cfg.LogBackupCipherInfo) {
+		return false, errors.Annotate(berrors.ErrInvalidArgument, "log backup encryption method and key length do not match")
+	}
+
+	return true, nil
+}
+
 func (cfg *Config) normalizePDURLs() error {
 	for i := range cfg.PD {
 		var err error
@@ -618,7 +694,17 @@ func (cfg *Config) ParseFromFlags(flags *pflag.FlagSet) error {
 		log.L().Info("--skip-check-path is deprecated, need explicitly set it anymore")
 	}
 
-	if err = cfg.parseCipherInfo(flags); err != nil {
+	err = cfg.parseCipherInfo(flags)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	hasLogBackupPlaintextKey, err := cfg.parseLogBackupCipherInfo(flags)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if err = cfg.parseAndValidateMasterKeyInfo(hasLogBackupPlaintextKey, flags); err != nil {
 		return errors.Trace(err)
 	}
 
@@ -629,6 +715,51 @@ func (cfg *Config) ParseFromFlags(flags *pflag.FlagSet) error {
 	return cfg.normalizePDURLs()
 }
 
+func (cfg *Config) parseAndValidateMasterKeyInfo(hasPlaintextKey bool, flags *pflag.FlagSet) error {
+	masterKeyString, err := flags.GetString(flagMasterKeyConfig)
+	if err != nil {
+		return errors.Errorf("master key flag '%s' is not defined: %v", flagMasterKeyConfig, err)
+	}
+
+	if masterKeyString == "" {
+		return nil
+	}
+
+	if hasPlaintextKey {
+		return errors.Errorf("invalid argument: both plaintext data key encryption and master key based encryption are set at the same time")
+	}
+
+	encryptionMethodString, err := flags.GetString(flagMasterKeyCipherType)
+	if err != nil {
+		return errors.Errorf("encryption method flag '%s' is not defined: %v", flagMasterKeyCipherType, err)
+	}
+
+	encryptionMethod, err := parseCipherType(encryptionMethodString)
+	if err != nil {
+		return errors.Errorf("failed to parse encryption method: %v", err)
+	}
+
+	if !utils.IsEffectiveEncryptionMethod(encryptionMethod) {
+		return errors.Errorf("invalid
encryption method: %s", encryptionMethodString) + } + + masterKeyStrings := strings.Split(masterKeyString, masterKeysDelimiter) + cfg.MasterKeyConfig = backuppb.MasterKeyConfig{ + EncryptionType: encryptionMethod, + MasterKeys: make([]*encryptionpb.MasterKey, 0, len(masterKeyStrings)), + } + + for _, keyString := range masterKeyStrings { + masterKey, err := validateAndParseMasterKeyString(strings.TrimSpace(keyString)) + if err != nil { + return errors.Wrapf(err, "invalid master key configuration: %s", keyString) + } + cfg.MasterKeyConfig.MasterKeys = append(cfg.MasterKeyConfig.MasterKeys, &masterKey) + } + + return nil +} + // NewMgr creates a new mgr at the given PD address. func NewMgr(ctx context.Context, g glue.Glue, pds []string, @@ -726,7 +857,7 @@ func ReadBackupMeta( if cfg.CipherInfo.CipherType != encryptionpb.EncryptionMethod_PLAINTEXT { iv = metaData[:metautil.CrypterIvLen] } - decryptBackupMeta, err := metautil.Decrypt(metaData[len(iv):], &cfg.CipherInfo, iv) + decryptBackupMeta, err := utils.Decrypt(metaData[len(iv):], &cfg.CipherInfo, iv) if err != nil { return nil, nil, nil, errors.Annotate(err, "decrypt failed with wrong key") } diff --git a/br/pkg/task/common_test.go b/br/pkg/task/common_test.go index c942b96bc5..c4433da574 100644 --- a/br/pkg/task/common_test.go +++ b/br/pkg/task/common_test.go @@ -185,6 +185,7 @@ func expectedDefaultConfig() Config { GRPCKeepaliveTime: 10000000000, GRPCKeepaliveTimeout: 3000000000, CipherInfo: backup.CipherInfo{CipherType: 1}, + LogBackupCipherInfo: backup.CipherInfo{CipherType: 1}, MetadataDownloadBatchSize: 0x80, } } @@ -241,3 +242,132 @@ func TestDefaultRestore(t *testing.T) { defaultConfig := expectedDefaultRestoreConfig() require.Equal(t, defaultConfig, def) } + +func TestParseAndValidateMasterKeyInfo(t *testing.T) { + tests := []struct { + name string + input string + expectedKeys []*encryptionpb.MasterKey + expectError bool + }{ + { + name: "Empty input", + input: "", + expectedKeys: nil, + expectError: false, + }, + { + name: "Single local config", + input: "local:///path/to/key", + expectedKeys: []*encryptionpb.MasterKey{ + { + Backend: &encryptionpb.MasterKey_File{ + File: &encryptionpb.MasterKeyFile{Path: "/path/to/key"}, + }, + }, + }, + expectError: false, + }, + { + name: "Single AWS config", + input: "aws-kms:///key-id?AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE&AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY®ION=us-west-2", + expectedKeys: []*encryptionpb.MasterKey{ + { + Backend: &encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: "aws", + KeyId: "key-id", + Region: "us-west-2", + }, + }, + }, + }, + expectError: false, + }, + { + name: "Single Azure config", + input: "azure-kms:///key-name/key-version?AZURE_TENANT_ID=tenant-id&AZURE_CLIENT_ID=client-id&AZURE_CLIENT_SECRET=client-secret&AZURE_VAULT_NAME=vault-name", + expectedKeys: []*encryptionpb.MasterKey{ + { + Backend: &encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: "azure", + KeyId: "key-name/key-version", + AzureKms: &encryptionpb.AzureKms{ + TenantId: "tenant-id", + ClientId: "client-id", + ClientSecret: "client-secret", + KeyVaultUrl: "vault-name", + }, + }, + }, + }, + }, + expectError: false, + }, + { + name: "Single GCP config", + input: "gcp-kms:///projects/project-id/locations/global/keyRings/ring-name/cryptoKeys/key-name?CREDENTIALS=credentials", + expectedKeys: []*encryptionpb.MasterKey{ + { + Backend: &encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: "gcp", + KeyId: 
"projects/project-id/locations/global/keyRings/ring-name/cryptoKeys/key-name", + GcpKms: &encryptionpb.GcpKms{ + Credential: "credentials", + }, + }, + }, + }, + }, + expectError: false, + }, + { + name: "Multiple configs", + input: "local:///path/to/key," + + "aws-kms:///key-id?AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE&AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY®ION=us-west-2", + expectedKeys: []*encryptionpb.MasterKey{ + { + Backend: &encryptionpb.MasterKey_File{ + File: &encryptionpb.MasterKeyFile{Path: "/path/to/key"}, + }, + }, + { + Backend: &encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: "aws", + KeyId: "key-id", + Region: "us-west-2", + }, + }, + }, + }, + expectError: false, + }, + { + name: "Invalid config", + input: "invalid:///config", + expectedKeys: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &Config{} + flags := pflag.NewFlagSet("test", pflag.ContinueOnError) + flags.String(flagMasterKeyConfig, tt.input, "") + flags.String(flagMasterKeyCipherType, "aes256-ctr", "") + + err := cfg.parseAndValidateMasterKeyInfo(false, flags) + + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectedKeys, cfg.MasterKeyConfig.MasterKeys) + } + }) + } +} diff --git a/br/pkg/task/encryption.go b/br/pkg/task/encryption.go new file mode 100644 index 0000000000..385bedff5f --- /dev/null +++ b/br/pkg/task/encryption.go @@ -0,0 +1,166 @@ +// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0. + +package task + +import ( + "fmt" + "net/url" + "path" + "regexp" + + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/encryptionpb" +) + +const ( + SchemeLocal = "local" + SchemeAWS = "aws-kms" + SchemeAzure = "azure-kms" + SchemeGCP = "gcp-kms" + + AWSVendor = "aws" + AWSRegion = "REGION" + AWSEndpoint = "ENDPOINT" + AWSAccessKeyId = "AWS_ACCESS_KEY_ID" + AWSSecretKey = "AWS_SECRET_ACCESS_KEY" + + AzureVendor = "azure" + AzureTenantID = "AZURE_TENANT_ID" + AzureClientID = "AZURE_CLIENT_ID" + AzureClientSecret = "AZURE_CLIENT_SECRET" + AzureVaultName = "AZURE_VAULT_NAME" + + GCPVendor = "gcp" + GCPCredentials = "CREDENTIALS" +) + +var ( + awsRegex = regexp.MustCompile(`^/([^/]+)$`) + azureRegex = regexp.MustCompile(`^/(.+)$`) + gcpRegex = regexp.MustCompile(`^/projects/([^/]+)/locations/([^/]+)/keyRings/([^/]+)/cryptoKeys/([^/]+)/?$`) +) + +func validateAndParseMasterKeyString(keyString string) (encryptionpb.MasterKey, error) { + u, err := url.Parse(keyString) + if err != nil { + return encryptionpb.MasterKey{}, errors.Trace(err) + } + + switch u.Scheme { + case SchemeLocal: + return parseLocalDiskConfig(u) + case SchemeAWS: + return parseAwsKmsConfig(u) + case SchemeAzure: + return parseAzureKmsConfig(u) + case SchemeGCP: + return parseGcpKmsConfig(u) + default: + return encryptionpb.MasterKey{}, errors.Errorf("unsupported master key type: %s", u.Scheme) + } +} + +func parseLocalDiskConfig(u *url.URL) (encryptionpb.MasterKey, error) { + if !path.IsAbs(u.Path) { + return encryptionpb.MasterKey{}, errors.New("local master key path must be absolute") + } + return encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_File{ + File: &encryptionpb.MasterKeyFile{ + Path: u.Path, + }, + }, + }, nil +} + +func parseAwsKmsConfig(u *url.URL) (encryptionpb.MasterKey, error) { + matches := awsRegex.FindStringSubmatch(u.Path) + if matches == nil { + return encryptionpb.MasterKey{}, errors.New("invalid AWS KMS key ID format") + } + keyID 
:= matches[1] + + q := u.Query() + region := q.Get(AWSRegion) + accessKey := q.Get(AWSAccessKeyId) + secretAccessKey := q.Get(AWSSecretKey) + + if region == "" { + return encryptionpb.MasterKey{}, errors.New("missing AWS KMS region info") + } + + var awsKms *encryptionpb.AwsKms + if accessKey != "" && secretAccessKey != "" { + awsKms = &encryptionpb.AwsKms{ + AccessKey: accessKey, + SecretAccessKey: secretAccessKey, + } + } + + return encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: AWSVendor, + KeyId: keyID, + Region: region, + Endpoint: q.Get(AWSEndpoint), // Optional + AwsKms: awsKms, // Optional, can read from env + }, + }, + }, nil +} + +func parseAzureKmsConfig(u *url.URL) (encryptionpb.MasterKey, error) { + matches := azureRegex.FindStringSubmatch(u.Path) + if matches == nil { + return encryptionpb.MasterKey{}, errors.New("invalid Azure KMS path format") + } + + keyID := matches[1] // This now captures the entire path as the key ID + q := u.Query() + + azureKms := &encryptionpb.AzureKms{ + TenantId: q.Get(AzureTenantID), + ClientId: q.Get(AzureClientID), + ClientSecret: q.Get(AzureClientSecret), + KeyVaultUrl: q.Get(AzureVaultName), + } + + if azureKms.TenantId == "" || azureKms.ClientId == "" || azureKms.ClientSecret == "" || azureKms.KeyVaultUrl == "" { + return encryptionpb.MasterKey{}, errors.New("missing required Azure KMS parameters") + } + + return encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: AzureVendor, + KeyId: keyID, + AzureKms: azureKms, + }, + }, + }, nil +} + +func parseGcpKmsConfig(u *url.URL) (encryptionpb.MasterKey, error) { + matches := gcpRegex.FindStringSubmatch(u.Path) + if matches == nil { + return encryptionpb.MasterKey{}, errors.New("invalid GCP KMS path format") + } + + projectID, location, keyRing, keyName := matches[1], matches[2], matches[3], matches[4] + q := u.Query() + + gcpKms := &encryptionpb.GcpKms{ + Credential: q.Get(GCPCredentials), + } + + return encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: GCPVendor, + KeyId: fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s", projectID, location, keyRing, keyName), + GcpKms: gcpKms, + }, + }, + }, nil +} diff --git a/br/pkg/task/encryption_test.go b/br/pkg/task/encryption_test.go new file mode 100644 index 0000000000..033062511f --- /dev/null +++ b/br/pkg/task/encryption_test.go @@ -0,0 +1,192 @@ +// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0. 
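+
+// Editorial sketch (hypothetical path): validateAndParseMasterKeyString above
+// turns a single URI into an encryptionpb.MasterKey, e.g.
+//
+//	key, err := validateAndParseMasterKeyString("local:///mnt/keys/master.key")
+//	// on success, key.GetFile().Path == "/mnt/keys/master.key"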
+ +package task + +import ( + "net/url" + "testing" + + "github.com/pingcap/kvproto/pkg/encryptionpb" + "github.com/stretchr/testify/assert" +) + +func TestParseLocalDiskConfig(t *testing.T) { + tests := []struct { + name string + input string + expected encryptionpb.MasterKey + expectError bool + }{ + { + name: "Valid local path", + input: "local:///path/to/key", + expected: encryptionpb.MasterKey{Backend: &encryptionpb.MasterKey_File{File: &encryptionpb.MasterKeyFile{Path: "/path/to/key"}}}, + expectError: false, + }, + { + name: "Invalid local path", + input: "local://relative/path", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, _ := url.Parse(tt.input) + result, err := parseLocalDiskConfig(u) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestParseAwsKmsConfig(t *testing.T) { + tests := []struct { + name string + input string + expected encryptionpb.MasterKey + expectError bool + }{ + { + name: "Valid AWS config", + input: "aws-kms:///key-id?AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE&AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY®ION=us-west-2", + expected: encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: "aws", + KeyId: "key-id", + Region: "us-west-2", + }, + }, + }, + expectError: false, + }, + { + name: "Missing key ID", + input: "aws-kms:///?AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE&AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY®ION=us-west-2", + expectError: true, + }, + { + name: "Missing required parameter", + input: "aws-kms:///key-id?AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE®ION=us-west-2", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, _ := url.Parse(tt.input) + result, err := parseAwsKmsConfig(u) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestParseAzureKmsConfig(t *testing.T) { + tests := []struct { + name string + input string + expected encryptionpb.MasterKey + expectError bool + }{ + { + name: "Valid Azure config", + input: "azure-kms:///key-name/key-version?AZURE_TENANT_ID=tenant-id&AZURE_CLIENT_ID=client-id&AZURE_CLIENT_SECRET=client-secret&AZURE_VAULT_NAME=vault-name", + expected: encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: "azure", + KeyId: "key-name/key-version", + AzureKms: &encryptionpb.AzureKms{ + TenantId: "tenant-id", + ClientId: "client-id", + ClientSecret: "client-secret", + KeyVaultUrl: "vault-name", + }, + }, + }, + }, + expectError: false, + }, + { + name: "Missing required parameter", + input: "azure-kms:///key-name/key-version?AZURE_TENANT_ID=tenant-id&AZURE_CLIENT_ID=client-id&AZURE_VAULT_NAME=vault-name", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, _ := url.Parse(tt.input) + result, err := parseAzureKmsConfig(u) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestParseGcpKmsConfig(t *testing.T) { + tests := []struct { + name string + input string + expected encryptionpb.MasterKey + expectError bool + }{ + { + name: "Valid GCP config", + input: 
"gcp-kms:///projects/project-id/locations/global/keyRings/ring-name/cryptoKeys/key-name?CREDENTIALS=credentials", + expected: encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: "gcp", + KeyId: "projects/project-id/locations/global/keyRings/ring-name/cryptoKeys/key-name", + GcpKms: &encryptionpb.GcpKms{ + Credential: "credentials", + }, + }, + }, + }, + expectError: false, + }, + { + name: "Invalid path format", + input: "gcp-kms:///invalid/path?CREDENTIALS=credentials", + expectError: true, + }, + { + name: "Missing credentials", + input: "gcp-kms:///projects/project-id/locations/global/keyRings/ring-name/cryptoKeys/key-name", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, _ := url.Parse(tt.input) + result, err := parseGcpKmsConfig(u) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} diff --git a/br/pkg/task/restore.go b/br/pkg/task/restore.go index be6b655c08..21f7686f12 100644 --- a/br/pkg/task/restore.go +++ b/br/pkg/task/restore.go @@ -247,7 +247,8 @@ type RestoreConfig struct { AllowPITRFromIncremental bool `json:"allow-pitr-from-incremental" toml:"allow-pitr-from-incremental"` // [startTs, RestoreTS] is used to `restore log` from StartTS to RestoreTS. - StartTS uint64 `json:"start-ts" toml:"start-ts"` + StartTS uint64 `json:"start-ts" toml:"start-ts"` + // if not specified system will restore to the max TS available RestoreTS uint64 `json:"restore-ts" toml:"restore-ts"` tiflashRecorder *tiflashrec.TiFlashRecorder `json:"-" toml:"-"` PitrBatchCount uint32 `json:"pitr-batch-count" toml:"pitr-batch-count"` @@ -695,12 +696,12 @@ func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf if err := version.CheckClusterVersion(c, mgr.GetPDClient(), version.CheckVersionForBRPiTR); err != nil { return errors.Trace(err) } - restoreError = RunStreamRestore(c, mgr, g, cmdName, cfg) + restoreError = RunStreamRestore(c, mgr, g, cfg) } else { if err := version.CheckClusterVersion(c, mgr.GetPDClient(), version.CheckVersionForBR); err != nil { return errors.Trace(err) } - restoreError = runRestore(c, mgr, g, cmdName, cfg, nil) + restoreError = runSnapshotRestore(c, mgr, g, cmdName, cfg, nil) } if restoreError != nil { return errors.Trace(restoreError) @@ -733,12 +734,13 @@ func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf return nil } -func runRestore(c context.Context, mgr *conn.Mgr, g glue.Glue, cmdName string, cfg *RestoreConfig, checkInfo *PiTRTaskInfo) error { +func runSnapshotRestore(c context.Context, mgr *conn.Mgr, g glue.Glue, cmdName string, cfg *RestoreConfig, checkInfo *PiTRTaskInfo) error { cfg.Adjust() defer summary.Summary(cmdName) ctx, cancel := context.WithCancel(c) defer cancel() + log.Info("starting snapshot restore") if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("task.RunRestore", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -836,7 +838,7 @@ func runRestore(c context.Context, mgr *conn.Mgr, g glue.Glue, cmdName string, c if client.IsIncremental() { // don't support checkpoint for the ddl restore - log.Info("the incremental snapshot restore doesn't support checkpoint mode, so unuse checkpoint.") + log.Info("the incremental snapshot restore doesn't support checkpoint mode, disable checkpoint.") cfg.UseCheckpoint = false } @@ -860,7 +862,7 @@ func 
runRestore(c context.Context, mgr *conn.Mgr, g glue.Glue, cmdName string, c log.Info("finish removing pd scheduler") }() - var checkpointFirstRun bool = true + var checkpointFirstRun = true if cfg.UseCheckpoint { // if the checkpoint metadata exists in the checkpoint storage, the restore is not // for the first time. diff --git a/br/pkg/task/stream.go b/br/pkg/task/stream.go index 7b167a474f..c6935fcaaf 100644 --- a/br/pkg/task/stream.go +++ b/br/pkg/task/stream.go @@ -36,6 +36,7 @@ import ( "github.com/pingcap/tidb/br/pkg/backup" "github.com/pingcap/tidb/br/pkg/checkpoint" "github.com/pingcap/tidb/br/pkg/conn" + "github.com/pingcap/tidb/br/pkg/encryption" berrors "github.com/pingcap/tidb/br/pkg/errors" "github.com/pingcap/tidb/br/pkg/glue" "github.com/pingcap/tidb/br/pkg/httputil" @@ -302,8 +303,8 @@ func NewStreamMgr(ctx context.Context, cfg *StreamConfig, g glue.Glue, isStreamS } }() - // just stream start need Storage - s := &streamMgr{ + // only stream start command needs Storage + streamManager := &streamMgr{ cfg: cfg, mgr: mgr, } @@ -323,12 +324,12 @@ func NewStreamMgr(ctx context.Context, cfg *StreamConfig, g glue.Glue, isStreamS if err = client.SetStorage(ctx, backend, &opts); err != nil { return nil, errors.Trace(err) } - s.bc = client + streamManager.bc = client // create http client to do some requirements check. - s.httpCli = httputil.NewClient(mgr.GetTLSConfig()) + streamManager.httpCli = httputil.NewClient(mgr.GetTLSConfig()) } - return s, nil + return streamManager, nil } func (s *streamMgr) close() { @@ -346,7 +347,7 @@ func (s *streamMgr) setLock(ctx context.Context) error { // adjustAndCheckStartTS checks that startTS should be smaller than currentTS, // and endTS is larger than currentTS. func (s *streamMgr) adjustAndCheckStartTS(ctx context.Context) error { - currentTS, err := s.mgr.GetTS(ctx) + currentTS, err := s.mgr.GetCurrentTsFromPD(ctx) if err != nil { return errors.Trace(err) } @@ -499,7 +500,7 @@ func RunStreamCommand( } if err := commandFn(ctx, g, cmdName, cfg); err != nil { - log.Error("failed to stream", zap.String("command", cmdName), zap.Error(err)) + log.Error("failed to run stream command", zap.String("command", cmdName), zap.Error(err)) summary.SetSuccessStatus(false) summary.CollectFailureUnit(cmdName, err) return err @@ -547,22 +548,28 @@ func RunStreamStart( log.Warn("failed to close etcd client", zap.Error(closeErr)) } }() + + // check if any import/restore task is running, it's not allowed to start log backup + // while restore is ongoing. if err = streamMgr.checkImportTaskRunning(ctx, cli.Client); err != nil { return errors.Trace(err) } + // It supports single stream log task currently. if count, err := cli.GetTaskCount(ctx); err != nil { return errors.Trace(err) } else if count > 0 { - return errors.Annotate(berrors.ErrStreamLogTaskExist, "It supports single stream log task currently") + return errors.Annotate(berrors.ErrStreamLogTaskExist, "failed to start the log backup, allow only one running task") } - exist, err := streamMgr.checkLock(ctx) + // make sure external file lock is available + locked, err := streamMgr.checkLock(ctx) if err != nil { return errors.Trace(err) } - // exist is true, which represents restart a stream task. Or create a new stream task. - if exist { + + // locked means this is a stream task restart. Or create a new stream task. 
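+	// Otherwise no external lock exists yet, so fall through and register a brand-new task.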
+ if locked { logInfo, err := getLogRange(ctx, &cfg.Config) if err != nil { return errors.Trace(err) @@ -614,6 +621,7 @@ func RunStreamStart( return errors.Annotate(berrors.ErrInvalidArgument, "nothing need to observe") } + securityConfig := generateSecurityConfig(cfg) ti := streamhelper.TaskInfo{ PBInfo: backuppb.StreamBackupTaskInfo{ Storage: streamMgr.bc.GetStorageBackend(), @@ -622,6 +630,7 @@ func RunStreamStart( Name: cfg.TaskName, TableFilter: cfg.FilterStr, CompressionType: backuppb.CompressionType_ZSTD, + SecurityConfig: &securityConfig, }, Ranges: ranges, Pausing: false, @@ -633,6 +642,30 @@ func RunStreamStart( return nil } +func generateSecurityConfig(cfg *StreamConfig) backuppb.StreamBackupTaskSecurityConfig { + if len(cfg.LogBackupCipherInfo.CipherKey) > 0 && utils.IsEffectiveEncryptionMethod(cfg.LogBackupCipherInfo.CipherType) { + return backuppb.StreamBackupTaskSecurityConfig{ + Encryption: &backuppb.StreamBackupTaskSecurityConfig_PlaintextDataKey{ + PlaintextDataKey: &backuppb.CipherInfo{ + CipherType: cfg.LogBackupCipherInfo.CipherType, + CipherKey: cfg.LogBackupCipherInfo.CipherKey, + }, + }, + } + } + if len(cfg.MasterKeyConfig.MasterKeys) > 0 && utils.IsEffectiveEncryptionMethod(cfg.MasterKeyConfig.EncryptionType) { + return backuppb.StreamBackupTaskSecurityConfig{ + Encryption: &backuppb.StreamBackupTaskSecurityConfig_MasterKeyConfig{ + MasterKeyConfig: &backuppb.MasterKeyConfig{ + EncryptionType: cfg.MasterKeyConfig.EncryptionType, + MasterKeys: cfg.MasterKeyConfig.MasterKeys, + }, + }, + } + } + return backuppb.StreamBackupTaskSecurityConfig{} +} + func RunStreamMetadata( c context.Context, g glue.Glue, @@ -995,13 +1028,13 @@ func RunStreamTruncate(c context.Context, g glue.Glue, cmdName string, cfg *Stre } if cfg.Until < sp { - console.Println("According to the log, you have truncated backup data before", em(formatTS(sp))) + console.Println("According to the log, you have truncated log backup data before", em(formatTS(sp))) if !cfg.SkipPrompt && !console.PromptBool("Continue? ") { return nil } } - readMetaDone := console.ShowTask("Reading Metadata... ", glue.WithTimeCost()) + readMetaDone := console.ShowTask("Reading log backup metadata... ", glue.WithTimeCost()) metas := stream.StreamMetadataSet{ MetadataDownloadBatchSize: cfg.MetadataDownloadBatchSize, Helper: stream.NewMetadataHelper(), @@ -1025,11 +1058,11 @@ func RunStreamTruncate(c context.Context, g glue.Glue, cmdName string, cfg *Stre kvCount += d.KVCount return }) - console.Printf("We are going to remove %s files, until %s.\n", + console.Printf("We are going to truncate %s files, up to TS %s.\n", em(fileCount), em(formatTS(cfg.Until)), ) - if !cfg.SkipPrompt && !console.PromptBool(warn("Sure? 
")) { + if !cfg.SkipPrompt && !console.PromptBool(warn("Are you sure?")) { return nil } @@ -1042,7 +1075,7 @@ func RunStreamTruncate(c context.Context, g glue.Glue, cmdName string, cfg *Stre // begin to remove p := console.StartProgressBar( - "Clearing Data Files and Metadata", fileCount, + "Truncating Data Files and Metadata", fileCount, glue.WithTimeCost(), glue.WithConstExtraField("kv-count", kvCount), glue.WithConstExtraField("kv-size", fmt.Sprintf("%d(%s)", totalSize, units.HumanSize(float64(totalSize)))), @@ -1113,7 +1146,6 @@ func RunStreamRestore( c context.Context, mgr *conn.Mgr, g glue.Glue, - cmdName string, cfg *RestoreConfig, ) (err error) { ctx, cancelFn := context.WithCancel(c) @@ -1132,6 +1164,8 @@ func RunStreamRestore( if err != nil { return errors.Trace(err) } + + // if not set by user, restore to the max TS available if cfg.RestoreTS == 0 { cfg.RestoreTS = logInfo.logMaxTS } @@ -1157,7 +1191,7 @@ func RunStreamRestore( } } - log.Info("start restore on point", + log.Info("start point in time restore", zap.Uint64("restore-from", cfg.StartTS), zap.Uint64("restore-to", cfg.RestoreTS), zap.Uint64("log-min-ts", logInfo.logMinTS), zap.Uint64("log-max-ts", logInfo.logMaxTS)) if err := checkLogRange(cfg.StartTS, cfg.RestoreTS, logInfo.logMinTS, logInfo.logMaxTS); err != nil { @@ -1180,7 +1214,7 @@ func RunStreamRestore( logStorage := cfg.Config.Storage cfg.Config.Storage = cfg.FullBackupStorage // TiFlash replica is restored to down-stream on 'pitr' currently. - if err = runRestore(ctx, mgr, g, FullRestoreCmd, cfg, checkInfo); err != nil { + if err = runSnapshotRestore(ctx, mgr, g, FullRestoreCmd, cfg, checkInfo); err != nil { return errors.Trace(err) } cfg.Config.Storage = logStorage @@ -1191,9 +1225,6 @@ func RunStreamRestore( } if checkInfo.CheckpointInfo != nil && checkInfo.CheckpointInfo.Metadata != nil && checkInfo.CheckpointInfo.Metadata.TiFlashItems != nil { log.Info("load tiflash records of snapshot restore from checkpoint") - if err != nil { - return errors.Trace(err) - } cfg.tiflashRecorder.Load(checkInfo.CheckpointInfo.Metadata.TiFlashItems) } } @@ -1329,7 +1360,12 @@ func restoreStream( }() } - err = client.InstallLogFileManager(ctx, cfg.StartTS, cfg.RestoreTS, cfg.MetadataDownloadBatchSize) + encryptionManager, err := encryption.NewManager(&cfg.LogBackupCipherInfo, &cfg.MasterKeyConfig) + if err != nil { + return errors.Annotate(err, "failed to create encryption manager for log restore") + } + defer encryptionManager.Close() + err = client.InstallLogFileManager(ctx, cfg.StartTS, cfg.RestoreTS, cfg.MetadataDownloadBatchSize, encryptionManager) if err != nil { return err } @@ -1345,12 +1381,13 @@ func restoreStream( newTask = false } // get the schemas ID replace information. 
+ // since targeted full backup storage, need to use the full backup cipher schemasReplace, err := client.InitSchemasReplaceForDDL(ctx, &logclient.InitSchemaConfig{ IsNewTask: newTask, TableFilter: cfg.TableFilter, TiFlashRecorder: cfg.tiflashRecorder, FullBackupStorage: fullBackupStorage, - }) + }, &cfg.Config.CipherInfo) if err != nil { return errors.Trace(err) } @@ -1431,7 +1468,8 @@ func restoreStream( return errors.Trace(err) } - return client.RestoreKVFiles(ctx, rewriteRules, idrules, logFilesIterWithSplit, checkpointRunner, cfg.PitrBatchCount, cfg.PitrBatchSize, updateStats, p.IncBy) + return client.RestoreKVFiles(ctx, rewriteRules, idrules, logFilesIterWithSplit, checkpointRunner, + cfg.PitrBatchCount, cfg.PitrBatchSize, updateStats, p.IncBy, &cfg.LogBackupCipherInfo, cfg.MasterKeyConfig.MasterKeys) }) if err != nil { return errors.Annotate(err, "failed to restore kv files") @@ -1547,15 +1585,15 @@ func getExternalStorageOptions(cfg *Config, u *backuppb.StorageBackend) storage. } } -func checkLogRange(restoreFrom, restoreTo, logMinTS, logMaxTS uint64) error { - // serveral ts constraint: - // logMinTS <= restoreFrom <= restoreTo <= logMaxTS - if logMinTS > restoreFrom || restoreFrom > restoreTo || restoreTo > logMaxTS { +func checkLogRange(restoreFromTS, restoreToTS, logMinTS, logMaxTS uint64) error { + // several ts constraint: + // logMinTS <= restoreFromTS <= restoreToTS <= logMaxTS + if logMinTS > restoreFromTS || restoreFromTS > restoreToTS || restoreToTS > logMaxTS { return errors.Annotatef(berrors.ErrInvalidArgument, "restore log from %d(%s) to %d(%s), "+ " but the current existed log from %d(%s) to %d(%s)", - restoreFrom, oracle.GetTimeFromTS(restoreFrom), - restoreTo, oracle.GetTimeFromTS(restoreTo), + restoreFromTS, oracle.GetTimeFromTS(restoreFromTS), + restoreToTS, oracle.GetTimeFromTS(restoreToTS), logMinTS, oracle.GetTimeFromTS(logMinTS), logMaxTS, oracle.GetTimeFromTS(logMaxTS), ) @@ -1648,7 +1686,7 @@ func getGlobalCheckpointFromStorage(ctx context.Context, s storage.ExternalStora return globalCheckPointTS, errors.Trace(err) } -// getFullBackupTS gets the snapshot-ts of full bakcup +// getFullBackupTS gets the snapshot-ts of full backup func getFullBackupTS( ctx context.Context, cfg *RestoreConfig, @@ -1663,11 +1701,17 @@ func getFullBackupTS( return 0, 0, errors.Trace(err) } - backupmeta := &backuppb.BackupMeta{} - if err = backupmeta.Unmarshal(metaData); err != nil { + decryptedMetaData, err := metautil.DecryptFullBackupMetaIfNeeded(metaData, &cfg.CipherInfo) + if err != nil { return 0, 0, errors.Trace(err) } + backupmeta := &backuppb.BackupMeta{} + if err = backupmeta.Unmarshal(decryptedMetaData); err != nil { + return 0, 0, errors.Trace(err) + } + + // start and end are identical in full backup, pick random one return backupmeta.GetEndVersion(), backupmeta.GetClusterId(), nil } @@ -1757,7 +1801,7 @@ func checkPiTRTaskInfo( cfg *RestoreConfig, ) (*PiTRTaskInfo, error) { var ( - doFullRestore = (len(cfg.FullBackupStorage) > 0) + doFullRestore = len(cfg.FullBackupStorage) > 0 curTaskInfo *checkpoint.CheckpointTaskInfoForLogRestore ) checkInfo := &PiTRTaskInfo{} diff --git a/br/pkg/task/stream_test.go b/br/pkg/task/stream_test.go index 627924e823..847699bc15 100644 --- a/br/pkg/task/stream_test.go +++ b/br/pkg/task/stream_test.go @@ -110,35 +110,6 @@ func TestCheckLogRange(t *testing.T) { } } -type fakeResolvedInfo struct { - storeID int64 - resolvedTS uint64 -} - -func fakeMetaFiles(ctx context.Context, tempDir string, infos []fakeResolvedInfo) error { - 
backupMetaDir := filepath.Join(tempDir, stream.GetStreamBackupMetaPrefix()) - s, err := storage.NewLocalStorage(backupMetaDir) - if err != nil { - return errors.Trace(err) - } - - for _, info := range infos { - meta := &backuppb.Metadata{ - StoreId: info.storeID, - ResolvedTs: info.resolvedTS, - } - buff, err := meta.Marshal() - if err != nil { - return errors.Trace(err) - } - filename := fmt.Sprintf("%d_%d.meta", info.storeID, info.resolvedTS) - if err = s.WriteFile(ctx, filename, buff); err != nil { - return errors.Trace(err) - } - } - return nil -} - func fakeCheckpointFiles( ctx context.Context, tmpDir string, @@ -154,7 +125,7 @@ func fakeCheckpointFiles( for _, info := range infos { filename := fmt.Sprintf("%v.ts", info.storeID) buff := make([]byte, 8) - binary.LittleEndian.PutUint64(buff, info.global_checkpoint) + binary.LittleEndian.PutUint64(buff, info.globalCheckpoint) if _, err := s.Create(ctx, filename, nil); err != nil { return errors.Trace(err) } @@ -170,8 +141,8 @@ func fakeCheckpointFiles( } type fakeGlobalCheckPoint struct { - storeID int64 - global_checkpoint uint64 + storeID int64 + globalCheckpoint uint64 } func TestGetGlobalCheckpointFromStorage(t *testing.T) { @@ -182,16 +153,16 @@ func TestGetGlobalCheckpointFromStorage(t *testing.T) { infos := []fakeGlobalCheckPoint{ { - storeID: 1, - global_checkpoint: 98, + storeID: 1, + globalCheckpoint: 98, }, { - storeID: 2, - global_checkpoint: 90, + storeID: 2, + globalCheckpoint: 90, }, { - storeID: 2, - global_checkpoint: 99, + storeID: 2, + globalCheckpoint: 99, }, } diff --git a/br/pkg/utils/BUILD.bazel b/br/pkg/utils/BUILD.bazel index d185e66916..fa18a8317b 100644 --- a/br/pkg/utils/BUILD.bazel +++ b/br/pkg/utils/BUILD.bazel @@ -7,6 +7,7 @@ go_library( "db.go", "dyn_pprof_other.go", "dyn_pprof_unix.go", + "encryption.go", "error_handling.go", "json.go", "key.go", @@ -35,6 +36,7 @@ go_library( "//pkg/parser/types", "//pkg/sessionctx", "//pkg/util", + "//pkg/util/encrypt", "//pkg/util/logutil", "//pkg/util/sqlexec", "@com_github_cheggaaa_pb_v3//:pb", @@ -43,6 +45,7 @@ go_library( "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", "@com_github_pingcap_kvproto//pkg/brpb", + "@com_github_pingcap_kvproto//pkg/encryptionpb", "@com_github_pingcap_kvproto//pkg/metapb", "@com_github_pingcap_log//:log", "@com_github_tikv_client_go_v2//oracle", diff --git a/br/pkg/utils/encryption.go b/br/pkg/utils/encryption.go new file mode 100644 index 0000000000..9a81f01b2e --- /dev/null +++ b/br/pkg/utils/encryption.go @@ -0,0 +1,30 @@ +package utils + +import ( + "github.com/pingcap/errors" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/encryptionpb" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/pkg/util/encrypt" +) + +func Decrypt(content []byte, cipher *backuppb.CipherInfo, iv []byte) ([]byte, error) { + if len(content) == 0 || cipher == nil { + return content, nil + } + + switch cipher.CipherType { + case encryptionpb.EncryptionMethod_PLAINTEXT: + return content, nil + case encryptionpb.EncryptionMethod_AES128_CTR, + encryptionpb.EncryptionMethod_AES192_CTR, + encryptionpb.EncryptionMethod_AES256_CTR: + return encrypt.AESDecryptWithCTR(content, cipher.CipherKey, iv) + default: + return content, errors.Annotatef(berrors.ErrInvalidArgument, "cipher type invalid %s", cipher.CipherType) + } +} + +func IsEffectiveEncryptionMethod(method encryptionpb.EncryptionMethod) bool { + return method != encryptionpb.EncryptionMethod_UNKNOWN && method != 
diff --git a/br/tests/README.md b/br/tests/README.md
index 6338d3ebcd..009a0230e4 100644
--- a/br/tests/README.md
+++ b/br/tests/README.md
@@ -33,7 +33,7 @@ This folder contains all tests which relies on external processes such as TiDB.
 
 ## Preparations
 
-1. The following 9 executables must be copied or linked into these locations:
+1. The following 9 executables must be copied or linked into the `bin` folder under the TiDB root dir:
 
     * `bin/tidb-server`
     * `bin/tikv-server`
@@ -80,7 +80,7 @@ If you have docker installed, you can skip step 1 and step 2 by running
 
 1. Build `br.test` using `make build_for_br_integration_test`
 2. Check that all 9 required executables and `br` executable exist
 3. Select the tests to run using `export TEST_NAME=" ..."`
-3. Execute `tests/run.sh`
+4. Execute `br/tests/run.sh`
 
 If the first two steps are done before, you could also run `tests/run.sh` directly.
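(For example, to run only the new encryption suite added below, one would set `export TEST_NAME="br_encryption"` per step 3 before executing `br/tests/run.sh`.)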
diff --git a/br/tests/br_encryption/run.sh b/br/tests/br_encryption/run.sh
new file mode 100755
index 0000000000..84075aff98
--- /dev/null
+++ b/br/tests/br_encryption/run.sh
@@ -0,0 +1,435 @@
+#!/bin/sh
+#
+# Copyright 2024 PingCAP, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -eu
+. run_services
+CUR=$(cd "$(dirname "$0")" && pwd)
+
+# const value
+PREFIX="encryption_backup"
+res_file="$TEST_DIR/sql_res.$TEST_NAME.txt"
+DB="$TEST_NAME"
+TABLE="usertable"
+DB_COUNT=3
+TASK_NAME="encryption_test"
+
+create_db_with_table() {
+    for i in $(seq $DB_COUNT); do
+        run_sql "CREATE DATABASE $DB${i};"
+        go-ycsb load mysql -P $CUR/workload -p mysql.host=$TIDB_IP -p mysql.port=$TIDB_PORT -p mysql.user=root -p mysql.db=$DB${i} -p recordcount=1000
+    done
+}
+
+start_log_backup() {
+    _storage=$1
+    _encryption_args=$2
+    echo "start log backup task"
+    run_br --pd "$PD_ADDR" log start --task-name $TASK_NAME -s "$_storage" $_encryption_args
+}
+
+drop_db() {
+    for i in $(seq $DB_COUNT); do
+        run_sql "DROP DATABASE IF EXISTS $DB${i};"
+    done
+}
+
+insert_additional_data() {
+    local prefix=$1
+    for i in $(seq $DB_COUNT); do
+        go-ycsb load mysql -P $CUR/workload -p mysql.host=$TIDB_IP -p mysql.port=$TIDB_PORT -p mysql.user=root -p mysql.db=$DB${i} -p insertcount=1000 -p insertstart=1000000 -p recordcount=1001000 -p workload=core
+    done
+}
+
+wait_log_checkpoint_advance() {
+    echo "wait for log checkpoint to advance"
+    sleep 10
+    local current_ts=$(python3 -c "import time; print(int(time.time() * 1000) << 18)")
+    echo "current ts: $current_ts"
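+    # note: a TiDB TSO keeps the physical timestamp in its high bits
+    # (milliseconds << 18, with a logical counter in the low 18 bits), so the
+    # python3 one-liner above turns wall-clock "now" into a TSO-comparable
+    # value to compare against the task checkpoint below.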
+    i=0
+    while true; do
+        # extract the checkpoint ts of the log backup task. If there is some error, the checkpoint ts should be empty
+        log_backup_status=$(unset BR_LOG_TO_TERM && run_br --skip-goleak --pd $PD_ADDR log status --task-name $TASK_NAME --json 2>br.log)
+        echo "log backup status: $log_backup_status"
+        local checkpoint_ts=$(echo "$log_backup_status" | head -n 1 | jq 'if .[0].last_errors | length == 0 then .[0].checkpoint else empty end')
+        echo "checkpoint ts: $checkpoint_ts"
+
+        # check whether the checkpoint ts is a number
+        if [ $checkpoint_ts -gt 0 ] 2>/dev/null; then
+            if [ $checkpoint_ts -gt $current_ts ]; then
+                echo "the checkpoint has advanced"
+                break
+            fi
+            echo "the checkpoint hasn't advanced"
+            i=$((i+1))
+            if [ "$i" -gt 50 ]; then
+                echo 'the checkpoint lag is too large'
+                exit 1
+            fi
+            sleep 10
+        else
+            echo "TEST: [$TEST_NAME] failed to wait checkpoint advance!"
+            exit 1
+        fi
+    done
+}
+
+calculate_checksum() {
+    local db=$1
+    local checksum=$(run_sql "USE $db; ADMIN CHECKSUM TABLE $TABLE;" | awk '/CHECKSUM/{print $2}')
+    echo $checksum
+}
+
+check_db_consistency() {
+    fail=false
+    for i in $(seq $DB_COUNT); do
+        local original_checksum=${checksum_ori[i]}
+        local new_checksum=$(calculate_checksum "$DB${i}")
+
+        if [ "$original_checksum" != "$new_checksum" ]; then
+            fail=true
+            echo "TEST: [$TEST_NAME] checksum mismatch on database $DB${i}"
+            echo "Original checksum: $original_checksum, New checksum: $new_checksum"
+        else
+            echo "Database $DB${i} checksum match: $new_checksum"
+        fi
+    done
+
+    if $fail; then
+        echo "TEST: [$TEST_NAME] data consistency check failed!"
+        return 1
+    fi
+    echo "TEST: [$TEST_NAME] data consistency check passed."
+    return 0
+}
+
+verify_dbs_empty() {
+    echo "Verifying databases are empty"
+    for i in $(seq $DB_COUNT); do
+        db_name="$DB${i}"
+        table_count=$(run_sql "USE $db_name; SHOW TABLES;" | wc -l)
+        if [ "$table_count" -ne 0 ]; then
+            echo "ERROR: Database $db_name is not empty"
+            return 1
+        fi
+    done
+    echo "All databases are empty"
+    return 0
+}
+
+run_backup_restore_test() {
+    local encryption_mode=$1
+    local full_encryption_args=$2
+    local log_encryption_args=$3
+
+    echo "===== run_backup_restore_test $encryption_mode $full_encryption_args $log_encryption_args ====="
+
+    restart_services || { echo "Failed to restart services"; exit 1; }
+
+    # Drop existing databases before starting the test
+    drop_db || { echo "Failed to drop databases"; exit 1; }
+
+    # Start log backup
+    start_log_backup "local://$TEST_DIR/$PREFIX/log" "$log_encryption_args" || { echo "Failed to start log backup"; exit 1; }
+
+    # Create test databases and insert initial data
+    create_db_with_table || { echo "Failed to create databases and tables"; exit 1; }
+
+    # Calculate and store original checksums
+    for i in $(seq $DB_COUNT); do
+        checksum_ori[${i}]=$(calculate_checksum "$DB${i}") || { echo "Failed to calculate initial checksum"; exit 1; }
+    done
+
+    # Full backup
+    echo "run full backup with $encryption_mode"
+    run_br --pd "$PD_ADDR" backup full -s "local://$TEST_DIR/$PREFIX/full" $full_encryption_args || { echo "Full backup failed"; exit 1; }
+
+    # Insert additional test data
+    insert_additional_data "${encryption_mode}_after_full_backup" || { echo "Failed to insert additional data"; exit 1; }
+
+    # Update checksums after inserting additional data
+    for i in $(seq $DB_COUNT); do
+        checksum_ori[${i}]=$(calculate_checksum "$DB${i}") || { echo "Failed to calculate checksum after insertion"; exit 1; }
+    done
+
+    wait_log_checkpoint_advance || { echo "Failed to wait for log checkpoint"; exit 1; }
+
+    # sanity check pause still works
+    run_br log pause --task-name $TASK_NAME --pd $PD_ADDR || { echo "Failed to pause log backup"; exit 1; }
+
+    # sanity check resume still works
+    run_br log resume --task-name $TASK_NAME --pd $PD_ADDR || { echo "Failed to resume log backup"; exit 1; }
+
+    # sanity check stop still works
+    run_br log stop --task-name $TASK_NAME --pd $PD_ADDR || { echo "Failed to stop log backup"; exit 1; }
+
+    # restart service should clean up everything
+    restart_services || { echo "Failed to restart services"; exit 1; }
+
+    verify_dbs_empty || { echo "Failed to verify databases are empty"; exit 1; }
+
+    # Run pitr restore and measure the performance
+    echo "restore log backup with $full_encryption_args and $log_encryption_args"
+    local start_time=$(date +%s.%N)
+    timeout 300 run_br --pd "$PD_ADDR" restore point -s "local://$TEST_DIR/$PREFIX/log" --full-backup-storage "local://$TEST_DIR/$PREFIX/full" $full_encryption_args $log_encryption_args || {
+        echo "Log backup restore failed or timed out after 5 minutes"
+        exit 1
+    }
+    local end_time=$(date +%s.%N)
+    local duration=$(echo "$end_time - $start_time" | bc | awk '{printf "%.3f", $0}')
+    echo "${encryption_mode} took ${duration} seconds"
+    echo "${encryption_mode},${duration}" >> "$TEST_DIR/performance_results.csv"
+
+    # Check data consistency after restore
+    echo "check data consistency after restore"
+    check_db_consistency || { echo "TEST: [$TEST_NAME] $encryption_mode backup and restore (including log) failed"; exit 1; }
+
+    # sanity check truncate still works
+    # make sure some files exist in log dir
+    log_dir="$TEST_DIR/$PREFIX/log"
+    if [ -z "$(ls -A $log_dir)" ]; then
+        echo "Error: No files found in the log directory $log_dir"
+        exit 1
+    else
+        echo "Files exist in the log directory $log_dir"
+    fi
+    current_time=$(date -u +"%Y-%m-%d %H:%M:%S+0000")
+    run_br log truncate -s "local://$TEST_DIR/$PREFIX/log" --until "$current_time" -y || { echo "Failed to truncate log backup"; exit 1; }
+    # make sure no files remain in log dir after truncation
+    if [ -n "$(ls -A $log_dir)" ]; then
+        echo "Error: Files still exist in the log directory $log_dir"
+        exit 1
+    else
+        echo "No files exist in the log directory $log_dir"
+    fi
+
+    # Clean up after the test
+    drop_db || { echo "Failed to drop databases after test"; exit 1; }
+    rm -rf "$TEST_DIR/$PREFIX"
+
+    echo "TEST: [$TEST_NAME] $encryption_mode backup and restore (including log) passed"
+}
+
+start_and_wait_for_localstack() {
+    # Start LocalStack in the background with only the required services
+    SERVICES=s3,ec2,kms localstack start -d
+
+    echo "Waiting for LocalStack services to be ready..."
+    max_attempts=30
+    attempt=0
+    while [ $attempt -lt $max_attempts ]; do
+        response=$(curl -s "http://localhost:4566/_localstack/health")
+        if echo "$response" | jq -e '.services.s3 == "running" or .services.s3 == "available"' > /dev/null && \
+           echo "$response" | jq -e '.services.ec2 == "running" or .services.ec2 == "available"' > /dev/null && \
+           echo "$response" | jq -e '.services.kms == "running" or .services.kms == "available"' > /dev/null; then
+            echo "LocalStack services are ready"
+            return 0
+        fi
+        attempt=$((attempt+1))
+        echo "Waiting for LocalStack services... Attempt $attempt of $max_attempts"
+        sleep 2
+    done
+
+    echo "LocalStack services did not become ready in time"
+    localstack stop
+    return 1
+}
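+
+# Negative path: restoring an encrypted log backup without the matching
+# --log.crypter flags must fail, rather than silently restoring data that
+# cannot be decrypted.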
+test_backup_encrypted_restore_unencrypted() {
+    echo "===== Testing backup with encryption, restore without encryption ====="
+
+    restart_services || { echo "Failed to restart services"; exit 1; }
+
+    # Start log backup
+    start_log_backup "local://$TEST_DIR/$PREFIX/log" "--log.crypter.method AES256-CTR --log.crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" || { echo "Failed to start log backup"; exit 1; }
+
+    # Create test databases and insert initial data
+    create_db_with_table || { echo "Failed to create databases and tables"; exit 1; }
+
+    # Backup with encryption
+    run_br --pd $PD_ADDR backup full -s "local://$TEST_DIR/$PREFIX/full" --crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef
+
+    # Insert additional test data
+    insert_additional_data "insert_after_full_backup" || { echo "Failed to insert additional data"; exit 1; }
+
+    wait_log_checkpoint_advance || { echo "Failed to wait for log checkpoint"; exit 1; }
+
+    # Stop and clean the cluster
+    restart_services || { echo "Failed to restart services"; exit 1; }
+
+    # Try to restore without the log encryption flags (this should fail)
+    if run_br --pd "$PD_ADDR" restore point -s "local://$TEST_DIR/$PREFIX/log" --full-backup-storage "local://$TEST_DIR/$PREFIX/full" --crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef; then
+        echo "Error: Restore without encryption should have failed, but it succeeded"
+        exit 1
+    else
+        echo "Restore without encryption failed as expected"
+    fi
+
+    # Clean up after the test
+    drop_db || { echo "Failed to drop databases after test"; exit 1; }
+    rm -rf "$TEST_DIR/$PREFIX"
+
+    echo "TEST: test_backup_encrypted_restore_unencrypted passed"
+}
+
+test_plaintext() {
+    run_backup_restore_test "plaintext" "" ""
+}
+
+test_plaintext_data_key() {
+    run_backup_restore_test "plaintext-data-key" "--crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" "--log.crypter.method AES256-CTR --log.crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
+}
+
+test_local_master_key() {
+    _MASTER_KEY_DIR="$TEST_DIR/$PREFIX/master_key"
+    mkdir -p "$_MASTER_KEY_DIR"
+    openssl rand -hex 32 > "$_MASTER_KEY_DIR/master.key"
+
+    _MASTER_KEY_PATH="local:///$_MASTER_KEY_DIR/master.key"
+
+    run_backup_restore_test "local_master_key" "" "--master-key-crypter-method AES256-CTR --master-key $_MASTER_KEY_PATH"
+
+    rm -rf "$_MASTER_KEY_DIR"
+}
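+
+# The --master-key URI schemes exercised by the tests in this file
+# (summarized from the invocations below, not an exhaustive reference):
+#   local:///path/to/master.key
+#   aws-kms:///KEY_ID?AWS_ACCESS_KEY_ID=..&AWS_SECRET_ACCESS_KEY=..&REGION=..&ENDPOINT=..
+#   gcp-kms:///projects/P/locations/L/keyRings/R/cryptoKeys/K?AUTH=specified&CREDENTIALS=..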
+
+test_aws_kms() {
+    # Start LocalStack and wait for services to be ready
+    if ! start_and_wait_for_localstack; then
+        echo "Failed to start LocalStack services"
+        exit 1
+    fi
+
+    # LocalStack endpoint
+    ENDPOINT="http://localhost:4566"
+
+    # Create KMS key using curl
+    KMS_RESPONSE=$(curl -X POST "$ENDPOINT/kms/" \
+        -H "Content-Type: application/x-amz-json-1.1" \
+        -H "X-Amz-Target: TrentService.CreateKey" \
+        -d '{
+            "Description": "My test key",
+            "KeyUsage": "ENCRYPT_DECRYPT",
+            "Origin": "AWS_KMS"
+        }')
+
+    echo "KMS CreateKey response: $KMS_RESPONSE"
+
+    AWS_KMS_KEY_ID=$(echo "$KMS_RESPONSE" | jq -r '.KeyMetadata.KeyId')
+    AWS_ACCESS_KEY_ID="TEST"
+    AWS_SECRET_ACCESS_KEY="TEST"
+    REGION="us-east-1"
+
+    AWS_KMS_URI="aws-kms:///${AWS_KMS_KEY_ID}?AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}&AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}&REGION=${REGION}&ENDPOINT=${ENDPOINT}"
+
+    run_backup_restore_test "aws_kms" "--crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" "--master-key-crypter-method AES256-CTR --master-key $AWS_KMS_URI"
+
+    # Stop LocalStack
+    localstack stop
+}
+
+test_aws_kms_with_iam() {
+    # Start LocalStack and wait for services to be ready
+    if ! start_and_wait_for_localstack; then
+        echo "Failed to start LocalStack services"
+        exit 1
+    fi
+
+    # LocalStack endpoint
+    ENDPOINT="http://localhost:4566"
+
+    # Create KMS key using curl
+    KMS_RESPONSE=$(curl -X POST "$ENDPOINT/kms/" \
+        -H "Content-Type: application/x-amz-json-1.1" \
+        -H "X-Amz-Target: TrentService.CreateKey" \
+        -d '{
+            "Description": "My test key",
+            "KeyUsage": "ENCRYPT_DECRYPT",
+            "Origin": "AWS_KMS"
+        }')
+
+    echo "KMS CreateKey response: $KMS_RESPONSE"
+
+    AWS_KMS_KEY_ID=$(echo "$KMS_RESPONSE" | jq -r '.KeyMetadata.KeyId')
+    # export these two as required by the aws-kms backend
+    export AWS_ACCESS_KEY_ID="TEST"
+    export AWS_SECRET_ACCESS_KEY="TEST"
+    REGION="us-east-1"
+
+    AWS_KMS_URI="aws-kms:///${AWS_KMS_KEY_ID}?REGION=${REGION}&ENDPOINT=${ENDPOINT}"
+
+    run_backup_restore_test "aws_kms_with_iam" "--crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" "--master-key-crypter-method AES256-CTR --master-key $AWS_KMS_URI"
+
+    # Stop LocalStack
+    localstack stop
+}
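+
+# Note: unlike test_aws_kms above, the IAM variant leaves the credentials out
+# of the URI and relies on the exported AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY
+# environment variables, approximating how credentials from an instance
+# profile would be picked up.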
+
+test_gcp_kms() {
+    # Ensure GCP credentials are set
+    if [ -z "${GOOGLE_APPLICATION_CREDENTIALS:-}" ]; then
+        echo "GCP credentials not set. Skipping GCP KMS test."
+        return
+    fi
+
+    # Replace these with your actual GCP KMS details
+    GCP_PROJECT_ID="carbide-network-435219-q3"
+    GCP_LOCATION="us-west1"
+    GCP_KEY_RING="local-kms-testing"
+    GCP_KEY_NAME="kms-testing-key"
+    GCP_CREDENTIALS="$GOOGLE_APPLICATION_CREDENTIALS"
+
+    GCP_KMS_URI="gcp-kms:///projects/$GCP_PROJECT_ID/locations/$GCP_LOCATION/keyRings/$GCP_KEY_RING/cryptoKeys/$GCP_KEY_NAME?AUTH=specified&CREDENTIALS=$GCP_CREDENTIALS"
+
+    run_backup_restore_test "gcp_kms" "--crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" "--master-key-crypter-method AES256-CTR --master-key $GCP_KMS_URI"
+}
+
+test_mixed_full_encrypted_log_plain() {
+    local full_encryption_args="--crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
+    local log_encryption_args=""
+
+    run_backup_restore_test "mixed_full_encrypted_log_plain" "$full_encryption_args" "$log_encryption_args"
+}
+
+test_mixed_full_plain_log_encrypted() {
+    local full_encryption_args=""
+    local log_encryption_args="--log.crypter.method AES256-CTR --log.crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
+
+    run_backup_restore_test "mixed_full_plain_log_encrypted" "$full_encryption_args" "$log_encryption_args"
+}
+
+# Initialize performance results file
+echo "Encryption Mode,Duration (seconds)" > "$TEST_DIR/performance_results.csv"
+
+# Run tests
+test_backup_encrypted_restore_unencrypted
+test_plaintext
+test_plaintext_data_key
+test_local_master_key
+# some issue running in CI, will fix later
+#test_aws_kms
+#test_aws_kms_with_iam
+test_mixed_full_encrypted_log_plain
+test_mixed_full_plain_log_encrypted
+
+# uncomment for manual GCP KMS testing
+#test_gcp_kms
+
+echo "All encryption tests passed successfully"
+
+# Display performance results
+echo "Performance Results:"
+cat "$TEST_DIR/performance_results.csv"
+
diff --git a/br/tests/br_encryption/workload b/br/tests/br_encryption/workload
new file mode 100644
index 0000000000..448ca3c1a4
--- /dev/null
+++ b/br/tests/br_encryption/workload
@@ -0,0 +1,12 @@
+recordcount=1000
+operationcount=0
+workload=core
+
+readallfields=true
+
+readproportion=0
+updateproportion=0
+scanproportion=0
+insertproportion=0
+
+requestdistribution=uniform
\ No newline at end of file
diff --git a/br/tests/config/tikv.toml b/br/tests/config/tikv.toml
index a469b38998..22126549ab 100644
--- a/br/tests/config/tikv.toml
+++ b/br/tests/config/tikv.toml
@@ -36,3 +36,6 @@ path = "/tmp/backup_restore_test/master-key-file"
 [log-backup]
 max-flush-interval = "50s"
+[gc]
+ratio-threshold = 1.1
+
diff --git a/br/tests/download_integration_test_binaries.sh b/br/tests/download_integration_test_binaries.sh
index ef04bf04a5..36bf829467 100755
--- a/br/tests/download_integration_test_binaries.sh
+++ b/br/tests/download_integration_test_binaries.sh
@@ -43,6 +43,8 @@ minio_cli_url="${file_server_url}/download/builds/minio/minio/RELEASE.2020-02-27
 kes_url="${file_server_url}/download/kes"
 fake_gcs_server_url="${file_server_url}/download/builds/fake-gcs-server"
 brv_url="${file_server_url}/download/builds/brv4.0.8"
+# already manually uploaded to file server for localstack
+localstack_url="${file_server_url}/download/localstack-cli.tar.gz"
 
 set -o nounset
@@ -103,6 +105,13 @@ function main() {
     download "$fake_gcs_server_url" "fake-gcs-server" "third_bin/fake-gcs-server"
     download "$brv_url" "brv4.0.8" "third_bin/brv4.0.8"
 
+    # Download and set up LocalStack
+    download "$localstack_url" "localstack-cli.tar.gz" "tmp/localstack-cli.tar.gz"
"tmp/localstack-cli.tar.gz" + mkdir -p tmp/localstack_extract + tar -xzf tmp/localstack-cli.tar.gz -C tmp/localstack_extract + mv tmp/localstack_extract/localstack/* third_bin/ + rm -rf tmp/localstack_extract + chmod +x third_bin/* rm -rf tmp rm -rf third_bin/bin diff --git a/br/tests/run_group_br_tests.sh b/br/tests/run_group_br_tests.sh index 8b49716388..a678adba80 100755 --- a/br/tests/run_group_br_tests.sh +++ b/br/tests/run_group_br_tests.sh @@ -26,7 +26,7 @@ groups=( ["G03"]='br_incompatible_tidb_config br_incremental br_incremental_index br_incremental_only_ddl br_incremental_same_table br_insert_after_restore br_key_locked br_log_test br_move_backup br_mv_index br_other br_partition_add_index br_tidb_placement_policy br_tiflash br_tiflash_conflict' ["G04"]='br_range br_replica_read br_restore_TDE_enable br_restore_log_task_enable br_s3 br_shuffle_leader br_shuffle_region br_single_table' ["G05"]='br_skip_checksum br_split_region_fail br_systables br_table_filter br_txn br_stats br_clustered_index br_crypter' - ["G06"]='br_tikv_outage br_tikv_outage3 br_restore_checkpoint' + ["G06"]='br_tikv_outage br_tikv_outage3 br_restore_checkpoint br_encryption' ["G07"]='br_pitr' ["G08"]='br_tikv_outage2 br_ttl br_views_and_sequences br_z_gc_safepoint br_autorandom br_file_corruption' ) diff --git a/build/BUILD.bazel b/build/BUILD.bazel index 2069bf10e2..e2e4619533 100644 --- a/build/BUILD.bazel +++ b/build/BUILD.bazel @@ -1,9 +1,9 @@ -package(default_visibility = ["//visibility:public"]) - -load("@io_bazel_rules_go//go:def.bzl", "go_library", "nogo") load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "nogo") load("//build/linter/staticcheck:def.bzl", "staticcheck_analyzers") +package(default_visibility = ["//visibility:public"]) + bool_flag( name = "with_nogo_flag", build_setting_default = False, diff --git a/go.mod b/go.mod index 280263b36f..6829dca916 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/pingcap/tidb go 1.21 require ( + cloud.google.com/go/kms v1.15.7 cloud.google.com/go/storage v1.38.0 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.12.0 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0 @@ -305,7 +306,7 @@ require ( google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 // indirect - google.golang.org/protobuf v1.34.2 // indirect + google.golang.org/protobuf v1.34.2 gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/tests/_utils/run_services b/tests/_utils/run_services index 5963ef125d..8f8a31caf7 100644 --- a/tests/_utils/run_services +++ b/tests/_utils/run_services @@ -47,8 +47,10 @@ stop() { } restart_services() { + echo "Restarting services" stop_services start_services + echo "Services restarted" } stop_services() {