br: add log backup/restore encryption support (#55757)

close pingcap/tidb#55834
This commit is contained in:
Wenqi Mou
2024-09-25 22:31:09 -04:00
committed by GitHub
parent 7b6209df99
commit 7c88876a7f
56 changed files with 2750 additions and 166 deletions

View File

@ -74,14 +74,14 @@ func runRestoreCommand(command *cobra.Command, cmdName string) error {
if err := task.RunRestore(GetDefaultContext(), tidbGlue, cmdName, &cfg); err != nil {
log.Error("failed to restore", zap.Error(err))
printWorkaroundOnFullRestoreError(command, err)
printWorkaroundOnFullRestoreError(err)
return errors.Trace(err)
}
return nil
}
// print a workaround when we hit a not-fresh or incompatible cluster error on a full cluster restore
func printWorkaroundOnFullRestoreError(command *cobra.Command, err error) {
func printWorkaroundOnFullRestoreError(err error) {
if !errors.ErrorEqual(err, berrors.ErrRestoreNotFreshCluster) &&
!errors.ErrorEqual(err, berrors.ErrRestoreIncompatibleSys) {
return

View File

@ -33,6 +33,7 @@ import (
"github.com/pingcap/tidb/br/pkg/rtree"
"github.com/pingcap/tidb/br/pkg/storage"
"github.com/pingcap/tidb/br/pkg/summary"
"github.com/pingcap/tidb/br/pkg/utils"
"github.com/pingcap/tidb/pkg/util"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
@ -561,7 +562,7 @@ func parseCheckpointData[K KeyType, V ValueType](
*pastDureTime = checkpointData.DureTime
}
for _, meta := range checkpointData.RangeGroupMetas {
decryptContent, err := metautil.Decrypt(meta.RangeGroupsEncriptedData, cipher, meta.CipherIv)
decryptContent, err := utils.Decrypt(meta.RangeGroupsEncriptedData, cipher, meta.CipherIv)
if err != nil {
return errors.Trace(err)
}

View File

@ -299,8 +299,8 @@ func (mgr *Mgr) Close() {
mgr.PdController.Close()
}
// GetTS gets current ts from pd.
func (mgr *Mgr) GetTS(ctx context.Context) (uint64, error) {
// GetCurrentTsFromPD gets current ts from PD.
func (mgr *Mgr) GetCurrentTsFromPD(ctx context.Context) (uint64, error) {
p, l, err := mgr.GetPDClient().GetTS(ctx)
if err != nil {
return 0, errors.Trace(err)

View File

@ -0,0 +1,15 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "encryption",
srcs = ["manager.go"],
importpath = "github.com/pingcap/tidb/br/pkg/encryption",
visibility = ["//visibility:public"],
deps = [
"//br/pkg/encryption/master_key",
"//br/pkg/utils",
"@com_github_pingcap_errors//:errors",
"@com_github_pingcap_kvproto//pkg/brpb",
"@com_github_pingcap_kvproto//pkg/encryptionpb",
],
)

View File

@ -0,0 +1,93 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
package encryption
import (
"context"
"github.com/pingcap/errors"
backuppb "github.com/pingcap/kvproto/pkg/brpb"
"github.com/pingcap/kvproto/pkg/encryptionpb"
encryption "github.com/pingcap/tidb/br/pkg/encryption/master_key"
"github.com/pingcap/tidb/br/pkg/utils"
)
type Manager struct {
cipherInfo *backuppb.CipherInfo
masterKeyBackends *encryption.MultiMasterKeyBackend
encryptionMethod *encryptionpb.EncryptionMethod
}
func NewManager(cipherInfo *backuppb.CipherInfo, masterKeyConfigs *backuppb.MasterKeyConfig) (*Manager, error) {
// should never happen since the config has defaults
if cipherInfo == nil || masterKeyConfigs == nil {
return nil, errors.New("cipherInfo or masterKeyConfigs is nil")
}
if utils.IsEffectiveEncryptionMethod(cipherInfo.CipherType) {
return &Manager{
cipherInfo: cipherInfo,
masterKeyBackends: nil,
encryptionMethod: nil,
}, nil
}
if utils.IsEffectiveEncryptionMethod(masterKeyConfigs.EncryptionType) {
masterKeyBackends, err := encryption.NewMultiMasterKeyBackend(masterKeyConfigs.GetMasterKeys())
if err != nil {
return nil, errors.Trace(err)
}
return &Manager{
cipherInfo: nil,
masterKeyBackends: masterKeyBackends,
encryptionMethod: &masterKeyConfigs.EncryptionType,
}, nil
}
// neither plaintext data key nor master key encryption is configured
return nil, nil
}
func (m *Manager) Decrypt(ctx context.Context, content []byte, fileEncryptionInfo *encryptionpb.FileEncryptionInfo) (
[]byte, error) {
switch mode := fileEncryptionInfo.Mode.(type) {
case *encryptionpb.FileEncryptionInfo_PlainTextDataKey:
if m.cipherInfo == nil {
return nil, errors.New("plaintext data key info is required but not set")
}
decryptedContent, err := utils.Decrypt(content, m.cipherInfo, fileEncryptionInfo.FileIv)
if err != nil {
return nil, errors.Annotate(err, "failed to decrypt content using plaintext data key")
}
return decryptedContent, nil
case *encryptionpb.FileEncryptionInfo_MasterKeyBased:
encryptedContents := fileEncryptionInfo.GetMasterKeyBased().DataKeyEncryptedContent
if len(encryptedContents) == 0 {
return nil, errors.New("should contain at least one encrypted data key")
}
// pick the first one; the list allows future expansion to multiple encrypted data keys from different master key backends
encryptedContent := encryptedContents[0]
decryptedDataKey, err := m.masterKeyBackends.Decrypt(ctx, encryptedContent)
if err != nil {
return nil, errors.Annotate(err, "failed to decrypt data key using master key")
}
cipherInfo := backuppb.CipherInfo{
CipherType: fileEncryptionInfo.EncryptionMethod,
CipherKey: decryptedDataKey,
}
decryptedContent, err := utils.Decrypt(content, &cipherInfo, fileEncryptionInfo.FileIv)
if err != nil {
return nil, errors.Annotate(err, "failed to decrypt content using decrypted data key")
}
return decryptedContent, nil
default:
return nil, errors.Errorf("internal error: unsupported encryption mode type %T", mode)
}
}
func (m *Manager) Close() {
if m == nil {
return
}
if m.masterKeyBackends != nil {
m.masterKeyBackends.Close()
}
}
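
For orientation, a minimal usage sketch (not part of this change) of the plaintext-data-key path through the manager; the function name and key bytes are illustrative assumptions:

package main

import (
	"context"

	backuppb "github.com/pingcap/kvproto/pkg/brpb"
	"github.com/pingcap/kvproto/pkg/encryptionpb"

	"github.com/pingcap/tidb/br/pkg/encryption"
)

// decryptLogFile shows the plaintext-data-key path: NewManager keeps the
// CipherInfo, and Manager.Decrypt combines it with the per-file IV carried
// in FileEncryptionInfo.
func decryptLogFile(ctx context.Context, raw []byte, info *encryptionpb.FileEncryptionInfo) ([]byte, error) {
	cipher := &backuppb.CipherInfo{
		CipherType: encryptionpb.EncryptionMethod_AES256_CTR,
		CipherKey:  make([]byte, 32), // placeholder key, for illustration only
	}
	mgr, err := encryption.NewManager(cipher, &backuppb.MasterKeyConfig{})
	if err != nil {
		return nil, err
	}
	defer mgr.Close()
	return mgr.Decrypt(ctx, raw, info)
}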

View File

@ -0,0 +1,44 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "master_key",
srcs = [
"common.go",
"file_backend.go",
"kms_backend.go",
"master_key.go",
"mem_backend.go",
"multi_master_key_backend.go",
],
importpath = "github.com/pingcap/tidb/br/pkg/encryption/master_key",
visibility = ["//visibility:public"],
deps = [
"//br/pkg/kms:aws",
"//br/pkg/utils",
"@com_github_pingcap_errors//:errors",
"@com_github_pingcap_kvproto//pkg/encryptionpb",
"@com_github_pingcap_log//:log",
"@org_uber_go_multierr//:multierr",
"@org_uber_go_zap//:zap",
],
)
go_test(
name = "master_key_test",
timeout = "short",
srcs = [
"file_backend_test.go",
"kms_backend_test.go",
"mem_backend_test.go",
"multi_master_key_backend_test.go",
],
embed = [":master_key"],
flaky = True,
shard_count = 11,
deps = [
"@com_github_pingcap_errors//:errors",
"@com_github_pingcap_kvproto//pkg/encryptionpb",
"@com_github_stretchr_testify//mock",
"@com_github_stretchr_testify//require",
],
)

View File

@ -0,0 +1,60 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
package encryption
import (
"crypto/rand"
"github.com/pingcap/errors"
)
// must be kept identical to the constants in the TiKV implementation
const (
MetadataKeyMethod string = "method"
MetadataKeyIv string = "iv"
MetadataKeyAesGcmTag string = "aes_gcm_tag"
MetadataKeyKmsVendor string = "kms_vendor"
MetadataKeyKmsCiphertextKey string = "kms_ciphertext_key"
MetadataMethodAes256Gcm string = "aes256-gcm"
)
const (
GcmIv12 = 12
CtrIv16 = 16
)
type IvType int
const (
IvTypeGcm IvType = iota
IvTypeCtr
)
type IV struct {
Type IvType
Data []byte
}
func NewIVGcm() (IV, error) {
iv := make([]byte, GcmIv12)
_, err := rand.Read(iv)
if err != nil {
return IV{}, err
}
return IV{Type: IvTypeGcm, Data: iv}, nil
}
func NewIVFromSlice(src []byte) (IV, error) {
switch len(src) {
case CtrIv16:
return IV{Type: IvTypeCtr, Data: append([]byte(nil), src...)}, nil
case GcmIv12:
return IV{Type: IvTypeGcm, Data: append([]byte(nil), src...)}, nil
default:
return IV{}, errors.Errorf("invalid IV length, must be 12 or 16 bytes, got %d", len(src))
}
}
func (iv IV) AsSlice() []byte {
return iv.Data
}

View File

@ -0,0 +1,68 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
package encryption
import (
"context"
"encoding/hex"
"os"
"github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/encryptionpb"
)
const AesGcmKeyLen = 32 // AES-256 key length
// FileBackend is ported from TiKV's FileBackend
type FileBackend struct {
memCache *MemAesGcmBackend
}
func createFileBackend(keyPath string) (*FileBackend, error) {
// FileBackend uses AES-256-GCM
keyLen := AesGcmKeyLen
content, err := os.ReadFile(keyPath)
if err != nil {
return nil, errors.Annotate(err, "failed to read master key file from disk")
}
fileLen := len(content)
expectedLen := keyLen*2 + 1 // hex-encoded key + newline
if fileLen != expectedLen {
return nil, errors.Errorf("mismatch master key file size, expected %d, actual %d", expectedLen, fileLen)
}
if content[fileLen-1] != '\n' {
return nil, errors.Errorf("master key file should end with newline")
}
key, err := hex.DecodeString(string(content[:fileLen-1]))
if err != nil {
return nil, errors.Annotate(err, "failed to decode hex format master key from file")
}
backend, err := NewMemAesGcmBackend(key)
if err != nil {
return nil, errors.Annotate(err, "failed to create MemAesGcmBackend")
}
return &FileBackend{memCache: backend}, nil
}
func (f *FileBackend) Encrypt(ctx context.Context, plaintext []byte) (*encryptionpb.EncryptedContent, error) {
iv, err := NewIVGcm()
if err != nil {
return nil, err
}
return f.memCache.EncryptContent(ctx, plaintext, iv)
}
func (f *FileBackend) Decrypt(ctx context.Context, content *encryptionpb.EncryptedContent) ([]byte, error) {
return f.memCache.DecryptContent(ctx, content)
}
func (f *FileBackend) Close() {
// nothing to close
}
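
For reference, a small sketch (not part of this change) of producing a key file in the format createFileBackend expects: 64 hex characters encoding a 32-byte AES-256 key, terminated by a single newline. The helper name and file path handling are illustrative:

package main

import (
	"crypto/rand"
	"encoding/hex"
	"os"
)

// writeMasterKeyFile writes a random 32-byte key as hex plus '\n',
// matching the keyLen*2+1 size check in createFileBackend.
func writeMasterKeyFile(path string) error {
	key := make([]byte, 32) // AesGcmKeyLen
	if _, err := rand.Read(key); err != nil {
		return err
	}
	return os.WriteFile(path, append([]byte(hex.EncodeToString(key)), '\n'), 0o600)
}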

View File

@ -0,0 +1,105 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
package encryption
import (
"context"
"encoding/hex"
"os"
"testing"
"github.com/stretchr/testify/require"
)
// TempKeyFile represents a temporary key file for testing
type TempKeyFile struct {
Path string
file *os.File
}
// Cleanup closes and removes the temporary file
func (tkf *TempKeyFile) Cleanup() {
if tkf.file != nil {
tkf.file.Close()
}
os.Remove(tkf.Path)
}
// createMasterKeyFile creates a temporary master key file for testing
func createMasterKeyFile() (*TempKeyFile, error) {
tempFile, err := os.CreateTemp("", "test_key_*.txt")
if err != nil {
return nil, err
}
_, err = tempFile.WriteString("c3d99825f2181f4808acd2068eac7441a65bd428f14d2aab43fefc0129091139\n")
if err != nil {
tempFile.Close()
os.Remove(tempFile.Name())
return nil, err
}
return &TempKeyFile{
Path: tempFile.Name(),
file: tempFile,
}, nil
}
func TestFileBackendAes256Gcm(t *testing.T) {
pt, err := hex.DecodeString("25431587e9ecffc7c37f8d6d52a9bc3310651d46fb0e3bad2726c8f2db653749")
require.NoError(t, err)
ct, err := hex.DecodeString("84e5f23f95648fa247cb28eef53abec947dbf05ac953734618111583840bd980")
require.NoError(t, err)
ivBytes, err := hex.DecodeString("cafabd9672ca6c79a2fbdc22")
require.NoError(t, err)
tempKeyFile, err := createMasterKeyFile()
require.NoError(t, err)
defer tempKeyFile.Cleanup()
backend, err := createFileBackend(tempKeyFile.Path)
require.NoError(t, err)
ctx := context.Background()
iv, err := NewIVFromSlice(ivBytes)
require.NoError(t, err)
encryptedContent, err := backend.memCache.EncryptContent(ctx, pt, iv)
require.NoError(t, err)
require.Equal(t, ct, encryptedContent.Content)
plaintext, err := backend.Decrypt(ctx, encryptedContent)
require.NoError(t, err)
require.Equal(t, pt, plaintext)
}
func TestFileBackendAuthenticate(t *testing.T) {
pt := []byte{1, 2, 3}
tempKeyFile, err := createMasterKeyFile()
require.NoError(t, err)
defer tempKeyFile.Cleanup()
backend, err := createFileBackend(tempKeyFile.Path)
require.NoError(t, err)
ctx := context.Background()
encryptedContent, err := backend.Encrypt(ctx, pt)
require.NoError(t, err)
plaintext, err := backend.Decrypt(ctx, encryptedContent)
require.NoError(t, err)
require.Equal(t, pt, plaintext)
// Test checksum mismatch
encryptedContent1 := *encryptedContent
encryptedContent1.Metadata[MetadataKeyAesGcmTag][0] ^= 0xFF
_, err = backend.Decrypt(ctx, &encryptedContent1)
require.ErrorContains(t, err, wrongMasterKey)
// Test checksum not found
encryptedContent2 := *encryptedContent
delete(encryptedContent2.Metadata, MetadataKeyAesGcmTag)
_, err = backend.Decrypt(ctx, &encryptedContent2)
require.ErrorContains(t, err, gcmTagNotFound)
}

View File

@ -0,0 +1,88 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
package encryption
import (
"context"
"sync"
"github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/encryptionpb"
"github.com/pingcap/tidb/br/pkg/kms"
"github.com/pingcap/tidb/br/pkg/utils"
)
type CachedKeys struct {
encryptionBackend *MemAesGcmBackend
cachedCiphertextKey *kms.EncryptedKey
}
type KmsBackend struct {
state struct {
sync.Mutex
cached *CachedKeys
}
kmsProvider kms.Provider
}
func NewKmsBackend(kmsProvider kms.Provider) (*KmsBackend, error) {
return &KmsBackend{
kmsProvider: kmsProvider,
}, nil
}
func (k *KmsBackend) Decrypt(ctx context.Context, content *encryptionpb.EncryptedContent) ([]byte, error) {
vendorName := k.kmsProvider.Name()
if val, ok := content.Metadata[MetadataKeyKmsVendor]; !ok {
return nil, errors.New("wrong master key: missing KMS vendor")
} else if string(val) != vendorName {
return nil, errors.Errorf("KMS vendor mismatch expect %s got %s", vendorName, string(val))
}
ciphertextKeyBytes, ok := content.Metadata[MetadataKeyKmsCiphertextKey]
if !ok {
return nil, errors.New("KMS ciphertext key not found")
}
ciphertextKey, err := kms.NewEncryptedKey(ciphertextKeyBytes)
if err != nil {
return nil, errors.Annotate(err, "failed to create encrypted key")
}
k.state.Lock()
defer k.state.Unlock()
if k.state.cached != nil && k.state.cached.cachedCiphertextKey.Equal(&ciphertextKey) {
return k.state.cached.encryptionBackend.DecryptContent(ctx, content)
}
// piggyback on NewDownloadSSTBackoffer; a refactor is ongoing to remove all the backoffers
// so users don't need to write a backoffer for every type
decryptedKey, err :=
utils.WithRetryV2(ctx, utils.NewDownloadSSTBackoffer(), func(ctx context.Context) ([]byte, error) {
return k.kmsProvider.DecryptDataKey(ctx, ciphertextKey)
})
if err != nil {
return nil, errors.Annotate(err, "decrypt encrypted key failed")
}
plaintextKey, err := kms.NewPlainKey(decryptedKey, kms.CryptographyTypeAesGcm256)
if err != nil {
return nil, errors.Annotate(err, "decrypt encrypted key failed")
}
backend, err := NewMemAesGcmBackend(plaintextKey.Key())
if err != nil {
return nil, errors.Annotate(err, "failed to create MemAesGcmBackend")
}
k.state.cached = &CachedKeys{
encryptionBackend: backend,
cachedCiphertextKey: &ciphertextKey,
}
return k.state.cached.encryptionBackend.DecryptContent(ctx, content)
}
func (k *KmsBackend) Close() {
k.kmsProvider.Close()
}

View File

@ -0,0 +1,113 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
package encryption
import (
"context"
"crypto/rand"
"testing"
"github.com/pingcap/kvproto/pkg/encryptionpb"
"github.com/stretchr/testify/require"
)
type mockKmsProvider struct {
name string
decryptCounter int
}
func (m *mockKmsProvider) Name() string {
return m.name
}
func (m *mockKmsProvider) DecryptDataKey(_ctx context.Context, _encryptedKey []byte) ([]byte, error) {
m.decryptCounter++
key := make([]byte, 32) // 256 bits = 32 bytes
_, err := rand.Read(key)
if err != nil {
return nil, err
}
return key, nil
}
func (m *mockKmsProvider) Close() {
// do nothing
}
func TestKmsBackendDecrypt(t *testing.T) {
ctx := context.Background()
mockProvider := &mockKmsProvider{name: "mock_kms"}
backend, err := NewKmsBackend(mockProvider)
require.NoError(t, err)
ciphertextKey := []byte("ciphertext_key")
content := &encryptionpb.EncryptedContent{
Metadata: map[string][]byte{
MetadataKeyKmsVendor: []byte("mock_kms"),
MetadataKeyKmsCiphertextKey: ciphertextKey,
},
Content: []byte("encrypted_content"),
}
// First decryption
_, _ = backend.Decrypt(ctx, content)
require.Equal(t, 1, mockProvider.decryptCounter, "KMS provider should be called once")
// Second decryption with the same ciphertext key (should use cache)
_, _ = backend.Decrypt(ctx, content)
require.Equal(t, 1, mockProvider.decryptCounter, "KMS provider should not be called again")
// Third decryption with a different ciphertext key
content.Metadata[MetadataKeyKmsCiphertextKey] = []byte("new_ciphertext_key")
_, _ = backend.Decrypt(ctx, content)
require.Equal(t, 2, mockProvider.decryptCounter, "KMS provider should be called again for a new key")
}
func TestKmsBackendDecryptErrors(t *testing.T) {
ctx := context.Background()
mockProvider := &mockKmsProvider{name: "mock_kms"}
backend, err := NewKmsBackend(mockProvider)
require.NoError(t, err)
testCases := []struct {
name string
content *encryptionpb.EncryptedContent
errMsg string
}{
{
name: "missing KMS vendor",
content: &encryptionpb.EncryptedContent{
Metadata: map[string][]byte{
MetadataKeyKmsCiphertextKey: []byte("ciphertext_key"),
},
},
errMsg: "wrong master key: missing KMS vendor",
},
{
name: "KMS vendor mismatch",
content: &encryptionpb.EncryptedContent{
Metadata: map[string][]byte{
MetadataKeyKmsVendor: []byte("wrong_kms"),
MetadataKeyKmsCiphertextKey: []byte("ciphertext_key"),
},
},
errMsg: "KMS vendor mismatch expect mock_kms got wrong_kms",
},
{
name: "missing ciphertext key",
content: &encryptionpb.EncryptedContent{
Metadata: map[string][]byte{
MetadataKeyKmsVendor: []byte("mock_kms"),
},
},
errMsg: "KMS ciphertext key not found",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := backend.Decrypt(ctx, tc.content)
require.ErrorContains(t, err, tc.errMsg)
})
}
}

View File

@ -0,0 +1,77 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
package encryption
import (
"context"
"github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/encryptionpb"
"github.com/pingcap/log"
"github.com/pingcap/tidb/br/pkg/kms"
"go.uber.org/zap"
)
const (
StorageVendorNameAWS = "aws"
StorageVendorNameAzure = "azure"
StorageVendorNameGCP = "gcp"
)
// Backend is an interface that defines the methods required for an encryption backend.
type Backend interface {
// Decrypt takes an EncryptedContent and returns the decrypted plaintext as a byte slice or an error.
Decrypt(ctx context.Context, ciphertext *encryptionpb.EncryptedContent) ([]byte, error)
Close()
}
func CreateBackend(config *encryptionpb.MasterKey) (Backend, error) {
if config == nil {
return nil, errors.Errorf("master key config is nil")
}
switch backend := config.Backend.(type) {
case *encryptionpb.MasterKey_Plaintext:
// the plaintext type should never reach here; it is guarded by the caller
return nil, errors.New("should not create plaintext master key")
case *encryptionpb.MasterKey_File:
fileBackend, err := createFileBackend(backend.File.Path)
if err != nil {
return nil, errors.Annotate(err, "master key config is nil")
}
return fileBackend, nil
case *encryptionpb.MasterKey_Kms:
return createCloudBackend(backend.Kms)
default:
return nil, errors.New("unknown master key backend type")
}
}
func createCloudBackend(config *encryptionpb.MasterKeyKms) (Backend, error) {
log.Info("creating cloud KMS backend",
zap.String("region", config.GetRegion()),
zap.String("endpoint", config.GetEndpoint()),
zap.String("key_id", config.GetKeyId()),
zap.String("Vendor", config.GetVendor()))
switch config.Vendor {
case StorageVendorNameAWS:
kmsProvider, err := kms.NewAwsKms(config)
if err != nil {
return nil, errors.Annotate(err, "new AWS KMS")
}
return NewKmsBackend(kmsProvider)
case StorageVendorNameAzure:
return nil, errors.Errorf("not implemented Azure KMS")
case StorageVendorNameGCP:
kmsProvider, err := kms.NewGcpKms(config)
if err != nil {
return nil, errors.Annotate(err, "new GCP KMS")
}
return NewKmsBackend(kmsProvider)
default:
return nil, errors.Errorf("vendor not found: %s", config.Vendor)
}
}
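
For orientation, a hedged sketch (not part of this change) of the MasterKey protos CreateBackend accepts; the path, key id, and region are placeholder values, and the plaintext variant is rejected by design as noted above:

package encryption

import "github.com/pingcap/kvproto/pkg/encryptionpb"

// buildExampleMasterKeys is an illustrative sketch of the two supported
// backend configs: a local key file and a cloud KMS key.
func buildExampleMasterKeys() []*encryptionpb.MasterKey {
	return []*encryptionpb.MasterKey{
		{
			// file-based master key, pointing at a key file in the format createFileBackend expects
			Backend: &encryptionpb.MasterKey_File{
				File: &encryptionpb.MasterKeyFile{Path: "/path/to/master.key"},
			},
		},
		{
			// cloud KMS master key, routed by vendor in createCloudBackend
			Backend: &encryptionpb.MasterKey_Kms{
				Kms: &encryptionpb.MasterKeyKms{
					Vendor: StorageVendorNameAWS,
					KeyId:  "example-key-id",
					Region: "us-west-2",
				},
			},
		},
	}
}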

View File

@ -0,0 +1,103 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
package encryption
import (
"context"
"crypto/aes"
"crypto/cipher"
"github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/encryptionpb"
"github.com/pingcap/tidb/br/pkg/kms"
)
const (
gcmTagNotFound = "aes gcm tag not found"
wrongMasterKey = "wrong master key"
)
type MemAesGcmBackend struct {
key *kms.PlainKey
}
func NewMemAesGcmBackend(key []byte) (*MemAesGcmBackend, error) {
plainKey, err := kms.NewPlainKey(key, kms.CryptographyTypeAesGcm256)
if err != nil {
return nil, errors.Annotate(err, "failed to create new mem aes gcm backend")
}
return &MemAesGcmBackend{
key: plainKey,
}, nil
}
func (m *MemAesGcmBackend) EncryptContent(_ctx context.Context, plaintext []byte, iv IV) (
*encryptionpb.EncryptedContent, error) {
content := encryptionpb.EncryptedContent{
Metadata: make(map[string][]byte),
}
content.Metadata[MetadataKeyMethod] = []byte(MetadataMethodAes256Gcm)
content.Metadata[MetadataKeyIv] = iv.AsSlice()
block, err := aes.NewCipher(m.key.Key())
if err != nil {
return nil, err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
// The Seal function in AES-GCM mode appends the authentication tag to the ciphertext.
// We need to separate the actual ciphertext from the tag for storage and later verification.
// Reference: https://pkg.go.dev/crypto/cipher#AEAD
ciphertext := aesgcm.Seal(nil, iv.AsSlice(), plaintext, nil)
content.Content = ciphertext[:len(ciphertext)-aesgcm.Overhead()]
content.Metadata[MetadataKeyAesGcmTag] = ciphertext[len(ciphertext)-aesgcm.Overhead():]
return &content, nil
}
func (m *MemAesGcmBackend) DecryptContent(_ctx context.Context, content *encryptionpb.EncryptedContent) (
[]byte, error) {
method, ok := content.Metadata[MetadataKeyMethod]
if !ok {
return nil, errors.Errorf("metadata %s not found", MetadataKeyMethod)
}
if string(method) != MetadataMethodAes256Gcm {
return nil, errors.Errorf("encryption method mismatch, expected %s vs actual %s",
MetadataMethodAes256Gcm, method)
}
ivValue, ok := content.Metadata[MetadataKeyIv]
if !ok {
return nil, errors.Errorf("metadata %s not found", MetadataKeyIv)
}
iv, err := NewIVFromSlice(ivValue)
if err != nil {
return nil, err
}
tag, ok := content.Metadata[MetadataKeyAesGcmTag]
if !ok {
return nil, errors.New("aes gcm tag not found")
}
block, err := aes.NewCipher(m.key.Key())
if err != nil {
return nil, err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
ciphertext := append(content.Content, tag...)
plaintext, err := aesgcm.Open(nil, iv.AsSlice(), ciphertext, nil)
if err != nil {
return nil, errors.Annotate(err, wrongMasterKey+": decrypt in GCM mode failed")
}
return plaintext, nil
}
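
To make the tag handling above concrete, a tiny standalone sketch (not part of this change): Seal returns ciphertext||tag, the 16-byte GCM tag is split off into metadata, and it must be re-appended before Open, mirroring EncryptContent and DecryptContent. The all-zero key and nonce are for illustration only:

package main

import (
	"crypto/aes"
	"crypto/cipher"
	"fmt"
)

func main() {
	key := make([]byte, 32) // all-zero key, illustration only
	iv := make([]byte, 12)  // all-zero 12-byte GCM nonce, illustration only
	block, err := aes.NewCipher(key)
	if err != nil {
		panic(err)
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		panic(err)
	}
	sealed := aead.Seal(nil, iv, []byte("hello"), nil)
	// Seal returns ciphertext||tag; split the tag off, as EncryptContent does.
	body := sealed[:len(sealed)-aead.Overhead()]
	tag := sealed[len(sealed)-aead.Overhead():]
	// Re-append the tag before opening, as DecryptContent does.
	plain, err := aead.Open(nil, iv, append(body, tag...), nil)
	fmt.Println(string(plain), err) // hello <nil>
}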

View File

@ -0,0 +1,115 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
package encryption
import (
"bytes"
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestNewMemAesGcmBackend(t *testing.T) {
key := make([]byte, 32) // 256-bit key
_, err := NewMemAesGcmBackend(key)
require.NoError(t, err, "Failed to create MemAesGcmBackend")
shortKey := make([]byte, 16)
_, err = NewMemAesGcmBackend(shortKey)
require.Error(t, err, "Expected error for short key")
}
func TestEncryptDecrypt(t *testing.T) {
key := make([]byte, 32)
backend, err := NewMemAesGcmBackend(key)
require.NoError(t, err, "Failed to create MemAesGcmBackend")
plaintext := []byte("Hello, World!")
iv, err := NewIVGcm()
require.NoError(t, err, "failed to create gcm iv")
ctx := context.Background()
encrypted, err := backend.EncryptContent(ctx, plaintext, iv)
require.NoError(t, err, "Encryption failed")
decrypted, err := backend.DecryptContent(ctx, encrypted)
require.NoError(t, err, "Decryption failed")
require.Equal(t, plaintext, decrypted, "Decrypted text doesn't match original")
}
func TestDecryptWithWrongKey(t *testing.T) {
key1 := make([]byte, 32)
key2 := make([]byte, 32)
for i := range key2 {
key2[i] = 1 // Different from key1
}
backend1, _ := NewMemAesGcmBackend(key1)
backend2, _ := NewMemAesGcmBackend(key2)
plaintext := []byte("Hello, World!")
iv, err := NewIVGcm()
require.NoError(t, err, "failed to create gcm iv")
ctx := context.Background()
encrypted, _ := backend1.EncryptContent(ctx, plaintext, iv)
_, err = backend2.DecryptContent(ctx, encrypted)
require.Error(t, err, "Expected decryption with wrong key to fail")
}
func TestDecryptWithTamperedCiphertext(t *testing.T) {
key := make([]byte, 32)
backend, _ := NewMemAesGcmBackend(key)
plaintext := []byte("Hello, World!")
iv, err := NewIVGcm()
require.NoError(t, err, "failed to create gcm iv")
ctx := context.Background()
encrypted, _ := backend.EncryptContent(ctx, plaintext, iv)
encrypted.Content[0] ^= 1 // Tamper with the ciphertext
_, err = backend.DecryptContent(ctx, encrypted)
require.Error(t, err, "Expected decryption of tampered ciphertext to fail")
}
func TestDecryptWithMissingMetadata(t *testing.T) {
key := make([]byte, 32)
backend, _ := NewMemAesGcmBackend(key)
plaintext := []byte("Hello, World!")
iv, err := NewIVGcm()
require.NoError(t, err, "failed to create gcm iv")
ctx := context.Background()
encrypted, _ := backend.EncryptContent(ctx, plaintext, iv)
delete(encrypted.Metadata, MetadataKeyMethod)
_, err = backend.DecryptContent(ctx, encrypted)
require.Error(t, err, "Expected decryption with missing metadata to fail")
}
func TestEncryptDecryptLargeData(t *testing.T) {
key := make([]byte, 32)
backend, _ := NewMemAesGcmBackend(key)
plaintext := make([]byte, 1000000) // 1 MB of data
iv, err := NewIVGcm()
require.NoError(t, err, "failed to create gcm iv")
ctx := context.Background()
encrypted, err := backend.EncryptContent(ctx, plaintext, iv)
require.NoError(t, err, "Encryption of large data failed")
decrypted, err := backend.DecryptContent(ctx, encrypted)
require.NoError(t, err, "Decryption of large data failed")
require.True(t, bytes.Equal(plaintext, decrypted), "Decrypted large data doesn't match original")
}

View File

@ -0,0 +1,64 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
package encryption
import (
"context"
"github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/encryptionpb"
"go.uber.org/multierr"
)
const (
defaultBackendCapacity = 5
)
// MultiMasterKeyBackend can hold multiple master key backends.
// If any one of those backends successfully decrypts the data, the data will be returned.
// The main purpose of this backend is to provide high availability for master keys in the future.
// Right now only one master key backend is used to encrypt/decrypt data.
type MultiMasterKeyBackend struct {
backends []Backend
}
func NewMultiMasterKeyBackend(masterKeysProto []*encryptionpb.MasterKey) (*MultiMasterKeyBackend, error) {
if len(masterKeysProto) == 0 {
return nil, errors.New("must provide at least one master key")
}
var backends = make([]Backend, 0, defaultBackendCapacity)
for _, masterKeyProto := range masterKeysProto {
backend, err := CreateBackend(masterKeyProto)
if err != nil {
return nil, errors.Trace(err)
}
backends = append(backends, backend)
}
return &MultiMasterKeyBackend{
backends: backends,
}, nil
}
func (m *MultiMasterKeyBackend) Decrypt(ctx context.Context, encryptedContent *encryptionpb.EncryptedContent) (
[]byte, error) {
if len(m.backends) == 0 {
return nil, errors.New("internal error: should always contain at least one backend")
}
var err error
for _, masterKeyBackend := range m.backends {
res, decryptErr := masterKeyBackend.Decrypt(ctx, encryptedContent)
if decryptErr == nil {
return res, nil
}
err = multierr.Append(err, decryptErr)
}
return nil, errors.Wrap(err, "failed to decrypt in multi master key backend")
}
func (m *MultiMasterKeyBackend) Close() {
for _, backend := range m.backends {
backend.Close()
}
}

View File

@ -0,0 +1,105 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
package encryption
import (
"context"
"testing"
"github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/encryptionpb"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// MockBackend is a mock implementation of the Backend interface
type MockBackend struct {
mock.Mock
}
func (m *MockBackend) Decrypt(ctx context.Context, encryptedContent *encryptionpb.EncryptedContent) ([]byte, error) {
args := m.Called(ctx, encryptedContent)
// The first return value should be []byte or nil
if ret := args.Get(0); ret != nil {
return ret.([]byte), args.Error(1)
}
return nil, args.Error(1)
}
func (m *MockBackend) Close() {
// do nothing
}
func TestMultiMasterKeyBackendDecrypt(t *testing.T) {
ctx := context.Background()
encryptedContent := &encryptionpb.EncryptedContent{Content: []byte("encrypted")}
t.Run("success first backend", func(t *testing.T) {
mock1 := new(MockBackend)
mock1.On("Decrypt", ctx, encryptedContent).Return([]byte("decrypted"), nil)
mock2 := new(MockBackend)
backend := &MultiMasterKeyBackend{
backends: []Backend{mock1, mock2},
}
result, err := backend.Decrypt(ctx, encryptedContent)
require.NoError(t, err)
require.Equal(t, []byte("decrypted"), result)
mock1.AssertExpectations(t)
mock2.AssertNotCalled(t, "Decrypt")
})
t.Run("success second backend", func(t *testing.T) {
mock1 := new(MockBackend)
mock1.On("Decrypt", ctx, encryptedContent).Return(nil, errors.New("failed"))
mock2 := new(MockBackend)
mock2.On("Decrypt", ctx, encryptedContent).Return([]byte("decrypted"), nil)
backend := &MultiMasterKeyBackend{
backends: []Backend{mock1, mock2},
}
result, err := backend.Decrypt(ctx, encryptedContent)
require.NoError(t, err)
require.Equal(t, []byte("decrypted"), result)
mock1.AssertExpectations(t)
mock2.AssertExpectations(t)
})
t.Run("all backends fail", func(t *testing.T) {
mock1 := new(MockBackend)
mock1.On("Decrypt", ctx, encryptedContent).Return(nil, errors.New("failed1"))
mock2 := new(MockBackend)
mock2.On("Decrypt", ctx, encryptedContent).Return(nil, errors.New("failed2"))
backend := &MultiMasterKeyBackend{
backends: []Backend{mock1, mock2},
}
result, err := backend.Decrypt(ctx, encryptedContent)
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "failed1")
require.Contains(t, err.Error(), "failed2")
mock1.AssertExpectations(t)
mock2.AssertExpectations(t)
})
t.Run("no backends", func(t *testing.T) {
backend := &MultiMasterKeyBackend{
backends: []Backend{},
}
result, err := backend.Decrypt(ctx, encryptedContent)
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "internal error")
})
}

br/pkg/kms/BUILD.bazel Normal file
View File

@ -0,0 +1,27 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "aws",
srcs = [
"aws.go",
"common.go",
"gcp.go",
"kms.go",
],
importpath = "github.com/pingcap/tidb/br/pkg/kms",
visibility = ["//visibility:public"],
deps = [
"@com_github_aws_aws_sdk_go//aws",
"@com_github_aws_aws_sdk_go//aws/credentials",
"@com_github_aws_aws_sdk_go//aws/session",
"@com_github_aws_aws_sdk_go//service/kms",
"@com_github_pingcap_errors//:errors",
"@com_github_pingcap_kvproto//pkg/encryptionpb",
"@com_github_pingcap_log//:log",
"@com_google_cloud_go_kms//apiv1",
"@com_google_cloud_go_kms//apiv1/kmspb",
"@org_golang_google_api//option",
"@org_golang_google_protobuf//types/known/wrapperspb",
"@org_uber_go_zap//:zap",
],
)

br/pkg/kms/aws.go Normal file
View File

@ -0,0 +1,92 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
package kms
import (
"context"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
pErrors "github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/encryptionpb"
)
const (
// must be kept exactly the same as ENCRYPTION_VENDOR_NAME_AWS_KMS in TiKV
EncryptionVendorNameAwsKms = "AWS"
)
type AwsKms struct {
client *kms.KMS
currentKeyID string
region string
endpoint string
}
func NewAwsKms(masterKeyConfig *encryptionpb.MasterKeyKms) (*AwsKms, error) {
config := &aws.Config{
Region: aws.String(masterKeyConfig.Region),
Endpoint: aws.String(masterKeyConfig.Endpoint),
}
// Only use static credentials if both access key and secret key are provided
if masterKeyConfig.AwsKms != nil &&
masterKeyConfig.AwsKms.AccessKey != "" &&
masterKeyConfig.AwsKms.SecretAccessKey != "" {
config.Credentials = credentials.NewStaticCredentials(
masterKeyConfig.AwsKms.AccessKey,
masterKeyConfig.AwsKms.SecretAccessKey,
"",
)
}
sess, err := session.NewSession(config)
if err != nil {
return nil, pErrors.Annotate(err, "failed to create AWS session")
}
return &AwsKms{
client: kms.New(sess),
currentKeyID: masterKeyConfig.KeyId,
region: masterKeyConfig.Region,
endpoint: masterKeyConfig.Endpoint,
}, nil
}
func (a *AwsKms) Name() string {
return EncryptionVendorNameAwsKms
}
func (a *AwsKms) DecryptDataKey(ctx context.Context, dataKey []byte) ([]byte, error) {
input := &kms.DecryptInput{
CiphertextBlob: dataKey,
KeyId: aws.String(a.currentKeyID),
}
result, err := a.client.DecryptWithContext(ctx, input)
if err != nil {
return nil, classifyDecryptError(err)
}
return result.Plaintext, nil
}
func (a *AwsKms) Close() {
// don't need to do manual close
}
// classifyDecryptError maps AWS SDK v1 error types to annotated BR errors
func classifyDecryptError(err error) error {
switch err := err.(type) {
case *kms.NotFoundException, *kms.InvalidKeyUsageException:
return pErrors.Annotate(err, "wrong master key")
case *kms.DependencyTimeoutException:
return pErrors.Annotate(err, "API timeout")
case *kms.InternalException:
return pErrors.Annotate(err, "API internal error")
default:
return pErrors.Annotate(err, "KMS error")
}
}

br/pkg/kms/common.go Normal file
View File

@ -0,0 +1,65 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
package kms
import (
"bytes"
"github.com/pingcap/errors"
)
// EncryptedKey is used to mark data as an encrypted key
type EncryptedKey []byte
func NewEncryptedKey(key []byte) (EncryptedKey, error) {
if len(key) == 0 {
return nil, errors.New("encrypted key cannot be empty")
}
return key, nil
}
// Equal method for EncryptedKey
func (e EncryptedKey) Equal(other *EncryptedKey) bool {
return bytes.Equal(e, *other)
}
// CryptographyType represents different cryptography methods
type CryptographyType int
const (
CryptographyTypePlain CryptographyType = iota
CryptographyTypeAesGcm256
)
func (c CryptographyType) TargetKeySize() int {
switch c {
case CryptographyTypePlain:
return 0 // Plain text has no limitation
case CryptographyTypeAesGcm256:
return 32
default:
return 0
}
}
// PlainKey is used to mark a byte slice as a plaintext key
type PlainKey struct {
tag CryptographyType
key []byte
}
func NewPlainKey(key []byte, t CryptographyType) (*PlainKey, error) {
limitation := t.TargetKeySize()
if limitation > 0 && len(key) != limitation {
return nil, errors.Errorf("encryption method and key length mismatch, expect %d got %d", limitation, len(key))
}
return &PlainKey{key: key, tag: t}, nil
}
func (p *PlainKey) KeyTag() CryptographyType {
return p.tag
}
func (p *PlainKey) Key() []byte {
return p.key
}

br/pkg/kms/gcp.go Normal file
View File

@ -0,0 +1,104 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
package kms
import (
"context"
"hash/crc32"
"strings"
"cloud.google.com/go/kms/apiv1"
"cloud.google.com/go/kms/apiv1/kmspb"
"github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/encryptionpb"
"github.com/pingcap/log"
"go.uber.org/zap"
"google.golang.org/api/option"
"google.golang.org/protobuf/types/known/wrapperspb"
)
const (
// must be kept exactly the same as STORAGE_VENDOR_NAME_GCP in TiKV
StorageVendorNameGcp = "gcp"
)
type GcpKms struct {
config *encryptionpb.MasterKeyKms
// the location prefix of key id,
// format: projects/{project_name}/locations/{location}
location string
client *kms.KeyManagementClient
}
func NewGcpKms(config *encryptionpb.MasterKeyKms) (*GcpKms, error) {
if config.GcpKms == nil {
return nil, errors.New("GCP config is missing")
}
// the config string pattern is verified when flags are parsed, so we should have a valid string at this stage.
config.KeyId = strings.TrimSuffix(config.KeyId, "/")
// join the first 4 parts of the key id to get the location
location := strings.Join(strings.Split(config.KeyId, "/")[:4], "/")
ctx := context.Background()
var clientOpt option.ClientOption
if config.GcpKms.Credential != "" {
clientOpt = option.WithCredentialsFile(config.GcpKms.Credential)
}
client, err := kms.NewKeyManagementClient(ctx, clientOpt)
if err != nil {
return nil, errors.Errorf("failed to create GCP KMS client: %v", err)
}
return &GcpKms{
config: config,
location: location,
client: client,
}, nil
}
func (g *GcpKms) Name() string {
return StorageVendorNameGcp
}
func (g *GcpKms) DecryptDataKey(ctx context.Context, dataKey []byte) ([]byte, error) {
req := &kmspb.DecryptRequest{
Name: g.config.KeyId,
Ciphertext: dataKey,
CiphertextCrc32C: wrapperspb.Int64(int64(g.calculateCRC32C(dataKey))),
}
resp, err := g.client.Decrypt(ctx, req)
if err != nil {
return nil, errors.Annotate(err, "gcp kms decrypt request failed")
}
if int64(g.calculateCRC32C(resp.Plaintext)) != resp.PlaintextCrc32C.Value {
return nil, errors.New("response corrupted in-transit")
}
return resp.Plaintext, nil
}
func (g *GcpKms) checkCRC32(data []byte, expected int64) error {
crc := int64(g.calculateCRC32C(data))
if crc != expected {
return errors.Errorf("crc32c mismatch, expected: %d, got: %d", expected, crc)
}
return nil
}
func (g *GcpKms) calculateCRC32C(data []byte) uint32 {
t := crc32.MakeTable(crc32.Castagnoli)
return crc32.Checksum(data, t)
}
func (g *GcpKms) Close() {
err := g.client.Close()
if err != nil {
log.Error("failed to close gcp kms client", zap.Error(err))
}
}
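
For reference, a tiny sketch (with made-up values) of the key id layout the location parsing above relies on; the first four path segments form the location prefix:

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Hypothetical key id; only the path shape matters for the location parsing above.
	keyID := "projects/my-project/locations/global/keyRings/my-ring/cryptoKeys/my-key"
	location := strings.Join(strings.Split(keyID, "/")[:4], "/")
	fmt.Println(location) // projects/my-project/locations/global
}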

br/pkg/kms/kms.go Normal file
View File

@ -0,0 +1,13 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
package kms
import "context"
// Provider is an interface for key management service providers.
// Encrypting data keys can be added in the future if needed.
type Provider interface {
DecryptDataKey(ctx context.Context, dataKey []byte) ([]byte, error)
Name() string
Close()
}

View File

@ -14,6 +14,7 @@ go_library(
"//br/pkg/logutil",
"//br/pkg/storage",
"//br/pkg/summary",
"//br/pkg/utils",
"//pkg/meta/model",
"//pkg/statistics/handle",
"//pkg/statistics/handle/types",
@ -47,6 +48,7 @@ go_test(
shard_count = 9,
deps = [
"//br/pkg/storage",
"//br/pkg/utils",
"//pkg/meta/model",
"//pkg/parser/model",
"//pkg/statistics/handle/types",

View File

@ -23,6 +23,7 @@ import (
"github.com/pingcap/tidb/br/pkg/logutil"
"github.com/pingcap/tidb/br/pkg/storage"
"github.com/pingcap/tidb/br/pkg/summary"
"github.com/pingcap/tidb/br/pkg/utils"
"github.com/pingcap/tidb/pkg/meta/model"
"github.com/pingcap/tidb/pkg/statistics/handle/util"
"github.com/pingcap/tidb/pkg/tablecodec"
@ -87,22 +88,17 @@ func Encrypt(content []byte, cipher *backuppb.CipherInfo) (encryptedContent, iv
}
}
// Decrypt decrypts the content according to CipherInfo and IV.
func Decrypt(content []byte, cipher *backuppb.CipherInfo, iv []byte) ([]byte, error) {
if len(content) == 0 || cipher == nil {
return content, nil
func DecryptFullBackupMetaIfNeeded(metaData []byte, cipherInfo *backuppb.CipherInfo) ([]byte, error) {
if cipherInfo == nil || !utils.IsEffectiveEncryptionMethod(cipherInfo.CipherType) {
return metaData, nil
}
switch cipher.CipherType {
case encryptionpb.EncryptionMethod_PLAINTEXT:
return content, nil
case encryptionpb.EncryptionMethod_AES128_CTR,
encryptionpb.EncryptionMethod_AES192_CTR,
encryptionpb.EncryptionMethod_AES256_CTR:
return encrypt.AESDecryptWithCTR(content, cipher.CipherKey, iv)
default:
return content, errors.Annotate(berrors.ErrInvalidArgument, "cipher type invalid")
// the backup meta file is prefixed with a 16-byte IV for CTR mode when the encryption method is valid
iv := metaData[:CrypterIvLen]
decryptBackupMeta, err := utils.Decrypt(metaData[len(iv):], cipherInfo, iv)
if err != nil {
return nil, errors.Annotate(err, "decrypt failed with wrong key")
}
return decryptBackupMeta, nil
}
// walkLeafMetaFile walks the leaves of the given metafile, and deal with it by calling the function `output`.
@ -130,7 +126,7 @@ func walkLeafMetaFile(
return errors.Trace(err)
}
decryptContent, err := Decrypt(content, cipher, node.CipherIv)
decryptContent, err := utils.Decrypt(content, cipher, node.CipherIv)
if err != nil {
return errors.Trace(err)
}

View File

@ -11,6 +11,7 @@ import (
backuppb "github.com/pingcap/kvproto/pkg/brpb"
"github.com/pingcap/kvproto/pkg/encryptionpb"
"github.com/pingcap/tidb/br/pkg/storage"
"github.com/pingcap/tidb/br/pkg/utils"
"github.com/stretchr/testify/require"
)
@ -203,14 +204,14 @@ func TestEncryptAndDecrypt(t *testing.T) {
require.NoError(t, err)
require.Equal(t, originalData, encryptData)
decryptData, err := Decrypt(encryptData, &cipher, iv)
decryptData, err := utils.Decrypt(encryptData, &cipher, iv)
require.NoError(t, err)
require.Equal(t, decryptData, originalData)
} else {
require.NoError(t, err)
require.NotEqual(t, originalData, encryptData)
decryptData, err := Decrypt(encryptData, &cipher, iv)
decryptData, err := utils.Decrypt(encryptData, &cipher, iv)
require.NoError(t, err)
require.Equal(t, decryptData, originalData)
@ -218,7 +219,7 @@ func TestEncryptAndDecrypt(t *testing.T) {
CipherType: v.method,
CipherKey: []byte(v.wrongKey),
}
decryptData, err = Decrypt(encryptData, &wrongCipher, iv)
decryptData, err = utils.Decrypt(encryptData, &wrongCipher, iv)
if len(v.rightKey) != len(v.wrongKey) {
require.Error(t, err)
} else {

View File

@ -26,6 +26,7 @@ import (
backuppb "github.com/pingcap/kvproto/pkg/brpb"
berrors "github.com/pingcap/tidb/br/pkg/errors"
"github.com/pingcap/tidb/br/pkg/storage"
"github.com/pingcap/tidb/br/pkg/utils"
"github.com/pingcap/tidb/pkg/meta/model"
"github.com/pingcap/tidb/pkg/statistics/handle"
statstypes "github.com/pingcap/tidb/pkg/statistics/handle/types"
@ -201,7 +202,7 @@ func downloadStats(
return errors.Trace(err)
}
decryptContent, err := Decrypt(content, cipher, statsFile.CipherIv)
decryptContent, err := utils.Decrypt(content, cipher, statsFile.CipherIv)
if err != nil {
return errors.Trace(err)
}

View File

@ -16,6 +16,7 @@ go_library(
"//br/pkg/checksum",
"//br/pkg/conn",
"//br/pkg/conn/util",
"//br/pkg/encryption",
"//br/pkg/errors",
"//br/pkg/glue",
"//br/pkg/logutil",
@ -49,6 +50,7 @@ go_library(
"@com_github_pingcap_errors//:errors",
"@com_github_pingcap_failpoint//:failpoint",
"@com_github_pingcap_kvproto//pkg/brpb",
"@com_github_pingcap_kvproto//pkg/encryptionpb",
"@com_github_pingcap_kvproto//pkg/errorpb",
"@com_github_pingcap_kvproto//pkg/import_sstpb",
"@com_github_pingcap_kvproto//pkg/kvrpcpb",
@ -113,6 +115,7 @@ go_test(
"@com_github_pingcap_errors//:errors",
"@com_github_pingcap_failpoint//:failpoint",
"@com_github_pingcap_kvproto//pkg/brpb",
"@com_github_pingcap_kvproto//pkg/encryptionpb",
"@com_github_pingcap_kvproto//pkg/errorpb",
"@com_github_pingcap_kvproto//pkg/import_sstpb",
"@com_github_pingcap_kvproto//pkg/metapb",

View File

@ -33,11 +33,13 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
backuppb "github.com/pingcap/kvproto/pkg/brpb"
"github.com/pingcap/kvproto/pkg/encryptionpb"
"github.com/pingcap/log"
"github.com/pingcap/tidb/br/pkg/checkpoint"
"github.com/pingcap/tidb/br/pkg/checksum"
"github.com/pingcap/tidb/br/pkg/conn"
"github.com/pingcap/tidb/br/pkg/conn/util"
"github.com/pingcap/tidb/br/pkg/encryption"
"github.com/pingcap/tidb/br/pkg/glue"
"github.com/pingcap/tidb/br/pkg/logutil"
"github.com/pingcap/tidb/br/pkg/metautil"
@ -293,13 +295,15 @@ func (rc *LogClient) InitCheckpointMetadataForLogRestore(
return gcRatio, nil
}
func (rc *LogClient) InstallLogFileManager(ctx context.Context, startTS, restoreTS uint64, metadataDownloadBatchSize uint) error {
func (rc *LogClient) InstallLogFileManager(ctx context.Context, startTS, restoreTS uint64, metadataDownloadBatchSize uint,
encryptionManager *encryption.Manager) error {
init := LogFileManagerInit{
StartTS: startTS,
RestoreTS: restoreTS,
Storage: rc.storage,
MetadataDownloadBatchSize: metadataDownloadBatchSize,
EncryptionManager: encryptionManager,
}
var err error
rc.LogFileManager, err = CreateLogFileManager(ctx, init)
@ -436,7 +440,7 @@ func ApplyKVFilesWithBatchMethod(
return nil
}
func ApplyKVFilesWithSingelMethod(
func ApplyKVFilesWithSingleMethod(
ctx context.Context,
files LogIter,
applyFunc func(file []*LogDataFileInfo, kvCount int64, size uint64),
@ -477,6 +481,8 @@ func (rc *LogClient) RestoreKVFiles(
pitrBatchSize uint32,
updateStats func(kvCount uint64, size uint64),
onProgress func(cnt int64),
cipherInfo *backuppb.CipherInfo,
masterKeys []*encryptionpb.MasterKey,
) error {
var (
err error
@ -488,7 +494,7 @@ func (rc *LogClient) RestoreKVFiles(
defer func() {
if err == nil {
elapsed := time.Since(start)
log.Info("Restore KV files", zap.Duration("take", elapsed))
log.Info("Restored KV files", zap.Duration("take", elapsed))
summary.CollectSuccessUnit("files", fileCount, elapsed)
}
}()
@ -548,7 +554,8 @@ func (rc *LogClient) RestoreKVFiles(
}
}()
return rc.fileImporter.ImportKVFiles(ectx, files, rule, rc.shiftStartTS, rc.startTS, rc.restoreTS, supportBatch)
return rc.fileImporter.ImportKVFiles(ectx, files, rule, rc.shiftStartTS, rc.startTS, rc.restoreTS,
supportBatch, cipherInfo, masterKeys)
})
}
}
@ -557,7 +564,7 @@ func (rc *LogClient) RestoreKVFiles(
if supportBatch {
err = ApplyKVFilesWithBatchMethod(ectx, logIter, int(pitrBatchCount), uint64(pitrBatchSize), applyFunc, &applyWg)
} else {
err = ApplyKVFilesWithSingelMethod(ectx, logIter, applyFunc, &applyWg)
err = ApplyKVFilesWithSingleMethod(ectx, logIter, applyFunc, &applyWg)
}
return errors.Trace(err)
})
@ -619,18 +626,25 @@ func initFullBackupTables(
ctx context.Context,
s storage.ExternalStorage,
tableFilter filter.Filter,
cipherInfo *backuppb.CipherInfo,
) (map[int64]*metautil.Table, error) {
metaData, err := s.ReadFile(ctx, metautil.MetaFile)
if err != nil {
return nil, errors.Trace(err)
}
backupMetaBytes, err := metautil.DecryptFullBackupMetaIfNeeded(metaData, cipherInfo)
if err != nil {
return nil, errors.Annotate(err, "decrypt failed with wrong key")
}
backupMeta := &backuppb.BackupMeta{}
if err = backupMeta.Unmarshal(metaData); err != nil {
if err = backupMeta.Unmarshal(backupMetaBytes); err != nil {
return nil, errors.Trace(err)
}
// read full backup databases to get map[table]table.Info
reader := metautil.NewMetaReader(backupMeta, s, nil)
reader := metautil.NewMetaReader(backupMeta, s, cipherInfo)
databases, err := metautil.LoadBackupTables(ctx, reader, false)
if err != nil {
@ -684,6 +698,7 @@ const UnsafePITRLogRestoreStartBeforeAnyUpstreamUserDDL = "UNSAFE_PITR_LOG_RESTO
func (rc *LogClient) generateDBReplacesFromFullBackupStorage(
ctx context.Context,
cfg *InitSchemaConfig,
cipherInfo *backuppb.CipherInfo,
) (map[stream.UpstreamID]*stream.DBReplace, error) {
dbReplaces := make(map[stream.UpstreamID]*stream.DBReplace)
if cfg.FullBackupStorage == nil {
@ -698,7 +713,7 @@ func (rc *LogClient) generateDBReplacesFromFullBackupStorage(
if err != nil {
return nil, errors.Trace(err)
}
fullBackupTables, err := initFullBackupTables(ctx, s, cfg.TableFilter)
fullBackupTables, err := initFullBackupTables(ctx, s, cfg.TableFilter, cipherInfo)
if err != nil {
return nil, errors.Trace(err)
}
@ -741,6 +756,7 @@ func (rc *LogClient) generateDBReplacesFromFullBackupStorage(
func (rc *LogClient) InitSchemasReplaceForDDL(
ctx context.Context,
cfg *InitSchemaConfig,
cipherInfo *backuppb.CipherInfo,
) (*stream.SchemasReplace, error) {
var (
err error
@ -785,7 +801,7 @@ func (rc *LogClient) InitSchemasReplaceForDDL(
if len(dbMaps) <= 0 {
log.Info("no id maps, build the table replaces from cluster and full backup schemas")
needConstructIdMap = true
dbReplaces, err = rc.generateDBReplacesFromFullBackupStorage(ctx, cfg)
dbReplaces, err = rc.generateDBReplacesFromFullBackupStorage(ctx, cfg, cipherInfo)
if err != nil {
return nil, errors.Trace(err)
}

View File

@ -866,7 +866,7 @@ func TestApplyKVFilesWithSingelMethod(t *testing.T) {
}
}
logclient.ApplyKVFilesWithSingelMethod(
logclient.ApplyKVFilesWithSingleMethod(
context.TODO(),
toLogDataFileInfoIter(iter.FromSlice(ds)),
applyFunc,
@ -1293,7 +1293,7 @@ func TestApplyKVFilesWithBatchMethod5(t *testing.T) {
require.Equal(t, backuppb.FileType_Delete, types[len(types)-1])
types = make([]backuppb.FileType, 0)
logclient.ApplyKVFilesWithSingelMethod(
logclient.ApplyKVFilesWithSingleMethod(
context.TODO(),
toLogDataFileInfoIter(iter.FromSlice(ds)),
applyFunc,
@ -1377,7 +1377,7 @@ func TestInitSchemasReplaceForDDL(t *testing.T) {
{
client := logclient.TEST_NewLogClient(123, 1, 2, 1, domain.NewMockDomain(), fakeSession{})
cfg := &logclient.InitSchemaConfig{IsNewTask: false}
_, err := client.InitSchemasReplaceForDDL(ctx, cfg)
_, err := client.InitSchemasReplaceForDDL(ctx, cfg, nil)
require.Error(t, err)
require.Regexp(t, "failed to get pitr id map from mysql.tidb_pitr_id_map.* [2, 1]", err.Error())
}
@ -1385,7 +1385,7 @@ func TestInitSchemasReplaceForDDL(t *testing.T) {
{
client := logclient.TEST_NewLogClient(123, 1, 2, 1, domain.NewMockDomain(), fakeSession{})
cfg := &logclient.InitSchemaConfig{IsNewTask: true}
_, err := client.InitSchemasReplaceForDDL(ctx, cfg)
_, err := client.InitSchemasReplaceForDDL(ctx, cfg, nil)
require.Error(t, err)
require.Regexp(t, "failed to get pitr id map from mysql.tidb_pitr_id_map.* [1, 1]", err.Error())
}
@ -1399,7 +1399,7 @@ func TestInitSchemasReplaceForDDL(t *testing.T) {
require.NoError(t, err)
client := logclient.TEST_NewLogClient(123, 1, 2, 1, domain.NewMockDomain(), se)
cfg := &logclient.InitSchemaConfig{IsNewTask: true}
_, err = client.InitSchemasReplaceForDDL(ctx, cfg)
_, err = client.InitSchemasReplaceForDDL(ctx, cfg, nil)
require.Error(t, err)
require.Contains(t, err.Error(), "miss upstream table information at `start-ts`(1) but the full backup path is not specified")
}

View File

@ -19,6 +19,7 @@ import (
"github.com/pingcap/errors"
backuppb "github.com/pingcap/kvproto/pkg/brpb"
"github.com/pingcap/kvproto/pkg/encryptionpb"
"github.com/pingcap/tidb/br/pkg/glue"
"github.com/pingcap/tidb/br/pkg/storage"
"github.com/pingcap/tidb/br/pkg/stream"
@ -90,6 +91,7 @@ func (helper *FakeStreamMetadataHelper) ReadFile(
length uint64,
compressionType backuppb.CompressionType,
storage storage.ExternalStorage,
encryptionInfo *encryptionpb.FileEncryptionInfo,
) ([]byte, error) {
return helper.Data[offset : offset+length], nil
}

View File

@ -23,6 +23,7 @@ import (
"github.com/pingcap/errors"
backuppb "github.com/pingcap/kvproto/pkg/brpb"
"github.com/pingcap/kvproto/pkg/encryptionpb"
"github.com/pingcap/kvproto/pkg/import_sstpb"
"github.com/pingcap/kvproto/pkg/kvrpcpb"
"github.com/pingcap/kvproto/pkg/metapb"
@ -101,6 +102,8 @@ func (importer *LogFileImporter) ImportKVFiles(
startTS uint64,
restoreTS uint64,
supportBatch bool,
cipherInfo *backuppb.CipherInfo,
masterKeys []*encryptionpb.MasterKey,
) error {
var (
startKey []byte
@ -111,7 +114,7 @@ func (importer *LogFileImporter) ImportKVFiles(
if !supportBatch && len(files) > 1 {
return errors.Annotatef(berrors.ErrInvalidArgument,
"do not support batch apply but files count:%v > 1", len(files))
"do not support batch apply, file count: %v > 1", len(files))
}
log.Debug("import kv files", zap.Int("batch file count", len(files)))
@ -143,7 +146,8 @@ func (importer *LogFileImporter) ImportKVFiles(
if len(subfiles) == 0 {
return RPCResultOK()
}
return importer.importKVFileForRegion(ctx, subfiles, rule, shiftStartTS, startTS, restoreTS, r, supportBatch)
return importer.importKVFileForRegion(ctx, subfiles, rule, shiftStartTS, startTS, restoreTS, r, supportBatch,
cipherInfo, masterKeys)
})
return errors.Trace(err)
}
@ -184,9 +188,11 @@ func (importer *LogFileImporter) importKVFileForRegion(
restoreTS uint64,
info *split.RegionInfo,
supportBatch bool,
cipherInfo *backuppb.CipherInfo,
masterKeys []*encryptionpb.MasterKey,
) RPCResult {
// Try to download file.
result := importer.downloadAndApplyKVFile(ctx, files, rule, info, shiftStartTS, startTS, restoreTS, supportBatch)
result := importer.downloadAndApplyKVFile(ctx, files, rule, info, shiftStartTS, startTS, restoreTS, supportBatch, cipherInfo, masterKeys)
if !result.OK() {
errDownload := result.Err
for _, e := range multierr.Errors(errDownload) {
@ -216,7 +222,8 @@ func (importer *LogFileImporter) downloadAndApplyKVFile(
startTS uint64,
restoreTS uint64,
supportBatch bool,
) RPCResult {
cipherInfo *backuppb.CipherInfo,
masterKeys []*encryptionpb.MasterKey) RPCResult {
leader := regionInfo.Leader
if leader == nil {
return RPCResultFromError(errors.Annotatef(berrors.ErrPDLeaderNotFound,
@ -251,11 +258,12 @@ func (importer *LogFileImporter) downloadAndApplyKVFile(
}
return startTS
}(),
RestoreTs: restoreTS,
StartKey: regionInfo.Region.GetStartKey(),
EndKey: regionInfo.Region.GetEndKey(),
Sha256: file.GetSha256(),
CompressionType: file.CompressionType,
RestoreTs: restoreTS,
StartKey: regionInfo.Region.GetStartKey(),
EndKey: regionInfo.Region.GetEndKey(),
Sha256: file.GetSha256(),
CompressionType: file.CompressionType,
FileEncryptionInfo: file.FileEncryptionInfo,
}
metas = append(metas, meta)
@ -276,6 +284,8 @@ func (importer *LogFileImporter) downloadAndApplyKVFile(
RewriteRules: rewriteRules,
Context: reqCtx,
StorageCacheId: importer.cacheKey,
CipherInfo: cipherInfo,
MasterKeys: masterKeys,
}
} else {
req = &import_sstpb.ApplyRequest{
@ -284,16 +294,18 @@ func (importer *LogFileImporter) downloadAndApplyKVFile(
RewriteRule: *rewriteRules[0],
Context: reqCtx,
StorageCacheId: importer.cacheKey,
CipherInfo: cipherInfo,
MasterKeys: masterKeys,
}
}
log.Debug("apply kv file", logutil.Leader(leader))
log.Debug("applying kv file", logutil.Leader(leader))
resp, err := importer.importClient.ApplyKVFile(ctx, leader.GetStoreId(), req)
if err != nil {
return RPCResultFromError(errors.Trace(err))
}
if resp.GetError() != nil {
logutil.CL(ctx).Warn("import meet error", zap.Stringer("error", resp.GetError()))
logutil.CL(ctx).Warn("import has error", zap.Stringer("error", resp.GetError()))
return RPCResultFromPBError(resp.GetError())
}
return RPCResultOK()

View File

@ -60,6 +60,7 @@ func TestImportKVFiles(t *testing.T) {
startTS,
restoreTS,
false,
nil, nil,
)
require.True(t, berrors.ErrInvalidArgument.Equal(err))
}
@ -268,9 +269,9 @@ func TestFileImporter(t *testing.T) {
require.NoError(t, err)
rewriteRules, encodeKeyFiles := prepareData()
err = importer.ImportKVFiles(ctx, encodeKeyFiles, rewriteRules, 1, 1, 1, true)
err = importer.ImportKVFiles(ctx, encodeKeyFiles, rewriteRules, 1, 1, 1, true, nil, nil)
require.NoError(t, err)
err = importer.ImportKVFiles(ctx, encodeKeyFiles, rewriteRules, 1, 1, 1, false)
err = importer.ImportKVFiles(ctx, encodeKeyFiles, rewriteRules, 1, 1, 1, false, nil, nil)
require.NoError(t, err)
}

View File

@ -12,7 +12,9 @@ import (
"github.com/pingcap/errors"
backuppb "github.com/pingcap/kvproto/pkg/brpb"
"github.com/pingcap/kvproto/pkg/encryptionpb"
"github.com/pingcap/log"
"github.com/pingcap/tidb/br/pkg/encryption"
berrors "github.com/pingcap/tidb/br/pkg/errors"
"github.com/pingcap/tidb/br/pkg/storage"
"github.com/pingcap/tidb/br/pkg/stream"
@ -54,6 +56,7 @@ type streamMetadataHelper interface {
length uint64,
compressionType backuppb.CompressionType,
storage storage.ExternalStorage,
encryptionInfo *encryptionpb.FileEncryptionInfo,
) ([]byte, error)
ParseToMetadata(rawMetaData []byte) (*backuppb.Metadata, error)
}
@ -85,6 +88,7 @@ type LogFileManagerInit struct {
Storage storage.ExternalStorage
MetadataDownloadBatchSize uint
EncryptionManager *encryption.Manager
}
type DDLMetaGroup struct {
@ -99,7 +103,7 @@ func CreateLogFileManager(ctx context.Context, init LogFileManagerInit) (*LogFil
startTS: init.StartTS,
restoreTS: init.RestoreTS,
storage: init.Storage,
helper: stream.NewMetadataHelper(),
helper: stream.NewMetadataHelper(stream.WithEncryptionManager(init.EncryptionManager)),
metadataDownloadBatchSize: init.MetadataDownloadBatchSize,
}
@ -329,7 +333,8 @@ func (rc *LogFileManager) ReadAllEntries(
kvEntries := make([]*KvEntryWithTS, 0)
nextKvEntries := make([]*KvEntryWithTS, 0)
buff, err := rc.helper.ReadFile(ctx, file.Path, file.RangeOffset, file.RangeLength, file.CompressionType, rc.storage)
buff, err := rc.helper.ReadFile(ctx, file.Path, file.RangeOffset, file.RangeLength, file.CompressionType,
rc.storage, file.FileEncryptionInfo)
if err != nil {
return nil, nil, errors.Trace(err)
}

View File

@ -647,7 +647,7 @@ func (rc *SnapClient) CreateDatabases(ctx context.Context, dbs []*metautil.Datab
return nil
}
log.Info("create databases in db pool", zap.Int("pool size", len(rc.dbPool)))
log.Info("create databases in db pool", zap.Int("pool size", len(rc.dbPool)), zap.Int("number of db", len(dbs)))
eg, ectx := errgroup.WithContext(ctx)
workers := tidbutil.NewWorkerPool(uint(len(rc.dbPool)), "DB DDL workers")
for _, db_ := range dbs {

View File

@ -15,6 +15,7 @@ go_library(
importpath = "github.com/pingcap/tidb/br/pkg/stream",
visibility = ["//visibility:public"],
deps = [
"//br/pkg/encryption",
"//br/pkg/errors",
"//br/pkg/glue",
"//br/pkg/httputil",
@ -36,6 +37,7 @@ go_library(
"@com_github_klauspost_compress//zstd",
"@com_github_pingcap_errors//:errors",
"@com_github_pingcap_kvproto//pkg/brpb",
"@com_github_pingcap_kvproto//pkg/encryptionpb",
"@com_github_pingcap_kvproto//pkg/metapb",
"@com_github_pingcap_log//:log",
"@com_github_tikv_client_go_v2//oracle",

View File

@ -16,12 +16,16 @@ package stream
import (
"context"
"crypto/sha256"
"encoding/hex"
"strings"
"github.com/klauspost/compress/zstd"
"github.com/pingcap/errors"
backuppb "github.com/pingcap/kvproto/pkg/brpb"
"github.com/pingcap/kvproto/pkg/encryptionpb"
"github.com/pingcap/log"
"github.com/pingcap/tidb/br/pkg/encryption"
"github.com/pingcap/tidb/br/pkg/storage"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/meta"
@ -157,16 +161,31 @@ type ContentRef struct {
// MetadataHelper make restore/truncate compatible with metadataV1 and metadataV2.
type MetadataHelper struct {
cache map[string]*ContentRef
decoder *zstd.Decoder
cache map[string]*ContentRef
decoder *zstd.Decoder
encryptionManager *encryption.Manager
}
func NewMetadataHelper() *MetadataHelper {
type MetadataHelperOption func(*MetadataHelper)
func WithEncryptionManager(manager *encryption.Manager) MetadataHelperOption {
return func(mh *MetadataHelper) {
mh.encryptionManager = manager
}
}
func NewMetadataHelper(opts ...MetadataHelperOption) *MetadataHelper {
decoder, _ := zstd.NewReader(nil)
return &MetadataHelper{
helper := &MetadataHelper{
cache: make(map[string]*ContentRef),
decoder: decoder,
}
for _, opt := range opts {
opt(helper)
}
return helper
}
func (m *MetadataHelper) InitCacheEntry(path string, ref int) {
@ -191,6 +210,35 @@ func (m *MetadataHelper) decodeCompressedData(data []byte, compressionType backu
"failed to decode compressed data: compression type is unimplemented. type id is %d", compressionType)
}
func (m *MetadataHelper) verifyChecksumAndDecryptIfNeeded(ctx context.Context, data []byte,
encryptionInfo *encryptionpb.FileEncryptionInfo) ([]byte, error) {
// no need to decrypt
if encryptionInfo == nil {
return data, nil
}
if m.encryptionManager == nil {
return nil, errors.New("need to decrypt data but encryption manager not set")
}
// Verify checksum before decryption
if encryptionInfo.Checksum != nil {
actualChecksum := sha256.Sum256(data)
expectedChecksumHex := hex.EncodeToString(encryptionInfo.Checksum)
actualChecksumHex := hex.EncodeToString(actualChecksum[:])
if expectedChecksumHex != actualChecksumHex {
return nil, errors.Errorf("checksum mismatch before decryption, expected %s, actual %s",
expectedChecksumHex, actualChecksumHex)
}
}
decryptedContent, err := m.encryptionManager.Decrypt(ctx, data, encryptionInfo)
if err != nil {
return nil, errors.Trace(err)
}
return decryptedContent, nil
}
func (m *MetadataHelper) ReadFile(
ctx context.Context,
path string,
@ -198,6 +246,7 @@ func (m *MetadataHelper) ReadFile(
length uint64,
compressionType backuppb.CompressionType,
storage storage.ExternalStorage,
encryptionInfo *encryptionpb.FileEncryptionInfo,
) ([]byte, error) {
var err error
cref, exist := m.cache[path]
@ -212,7 +261,12 @@ func (m *MetadataHelper) ReadFile(
if err != nil {
return nil, errors.Trace(err)
}
return m.decodeCompressedData(data, compressionType)
// decrypt if needed
decryptedData, err := m.verifyChecksumAndDecryptIfNeeded(ctx, data, encryptionInfo)
if err != nil {
return nil, errors.Trace(err)
}
return m.decodeCompressedData(decryptedData, compressionType)
}
cref.ref -= 1
@ -223,8 +277,12 @@ func (m *MetadataHelper) ReadFile(
return nil, errors.Trace(err)
}
}
buf, err := m.decodeCompressedData(cref.data[offset:offset+length], compressionType)
// decrypt if needed
decryptedData, err := m.verifyChecksumAndDecryptIfNeeded(ctx, cref.data[offset:offset+length], encryptionInfo)
if err != nil {
return nil, errors.Trace(err)
}
buf, err := m.decodeCompressedData(decryptedData, compressionType)
if cref.ref <= 0 {
// need reset reference information.
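As a reading aid (editor sketch, not part of the diff): a minimal standalone Go sketch of the verify-then-decrypt order that verifyChecksumAndDecryptIfNeeded follows, assuming the checksum is SHA-256 over the ciphertext and the cipher is AES-CTR. The package and helper name are illustrative, not BR code.
package sketch
import (
	"bytes"
	"crypto/aes"
	"crypto/cipher"
	"crypto/sha256"
	"errors"
)
// verifyThenDecrypt checks the checksum against the ciphertext first, then
// runs AES-CTR. The key length selects AES-128/192/256; iv must be one AES block (16 bytes).
func verifyThenDecrypt(ciphertext, checksum, key, iv []byte) ([]byte, error) {
	if len(checksum) > 0 {
		actual := sha256.Sum256(ciphertext)
		if !bytes.Equal(actual[:], checksum) {
			return nil, errors.New("checksum mismatch before decryption")
		}
	}
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	plaintext := make([]byte, len(ciphertext))
	cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
	return plaintext, nil
}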

View File

@ -66,14 +66,14 @@ func TestMetadataHelperReadFile(t *testing.T) {
require.NoError(t, err)
helper.InitCacheEntry(filename2, 2)
get_data, err := helper.ReadFile(ctx, filename1, 0, 0, backuppb.CompressionType_UNKNOWN, s)
get_data, err := helper.ReadFile(ctx, filename1, 0, 0, backuppb.CompressionType_UNKNOWN, s, nil)
require.NoError(t, err)
require.Equal(t, data1, get_data)
get_data, err = helper.ReadFile(ctx, filename2, 0, uint64(len(data1)), backuppb.CompressionType_UNKNOWN, s)
get_data, err = helper.ReadFile(ctx, filename2, 0, uint64(len(data1)), backuppb.CompressionType_UNKNOWN, s, nil)
require.NoError(t, err)
require.Equal(t, data1, get_data)
get_data, err = helper.ReadFile(ctx, filename2, uint64(len(data1)), uint64(len(data2)),
backuppb.CompressionType_ZSTD, s)
backuppb.CompressionType_ZSTD, s, nil)
require.NoError(t, err)
require.Equal(t, data1, get_data)
}

View File

@ -8,6 +8,7 @@ go_library(
"backup_raw.go",
"backup_txn.go",
"common.go",
"encryption.go",
"restore.go",
"restore_data.go",
"restore_ebs_meta.go",
@ -27,6 +28,7 @@ go_library(
"//br/pkg/config",
"//br/pkg/conn",
"//br/pkg/conn/util",
"//br/pkg/encryption",
"//br/pkg/errors",
"//br/pkg/glue",
"//br/pkg/httputil",
@ -106,12 +108,13 @@ go_test(
"backup_test.go",
"common_test.go",
"config_test.go",
"encryption_test.go",
"restore_test.go",
"stream_test.go",
],
embed = [":task"],
flaky = True,
shard_count = 33,
shard_count = 38,
deps = [
"//br/pkg/backup",
"//br/pkg/config",

View File

@ -72,7 +72,6 @@ const (
TableBackupCmd = "Table Backup"
RawBackupCmd = "Raw Backup"
TxnBackupCmd = "Txn Backup"
EBSBackupCmd = "EBS Backup"
)
// CompressionConfig is the configuration for sst file compression.

View File

@ -91,9 +91,13 @@ const (
defaultGRPCKeepaliveTimeout = 3 * time.Second
defaultCloudAPIConcurrency = 8
flagCipherType = "crypter.method"
flagCipherKey = "crypter.key"
flagCipherKeyFile = "crypter.key-file"
flagFullBackupCipherType = "crypter.method"
flagFullBackupCipherKey = "crypter.key"
flagFullBackupCipherKeyFile = "crypter.key-file"
flagLogBackupCipherType = "log.crypter.method"
flagLogBackupCipherKey = "log.crypter.key"
flagLogBackupCipherKeyFile = "log.crypter.key-file"
flagMetadataDownloadBatchSize = "metadata-download-batch-size"
defaultMetadataDownloadBatchSize = 128
@ -104,6 +108,10 @@ const (
crypterAES256KeyLen = 32
flagFullBackupType = "type"
masterKeysDelimiter = ","
flagMasterKeyConfig = "master-key"
flagMasterKeyCipherType = "master-key-crypter-method"
)
const (
@ -260,8 +268,17 @@ type Config struct {
// GRPCKeepaliveTimeout is the max time a grpc conn can stay idle before being killed.
GRPCKeepaliveTimeout time.Duration `json:"grpc-keepalive-timeout" toml:"grpc-keepalive-timeout"`
// Plaintext data key mainly used for full/snapshot backup and restore.
CipherInfo backuppb.CipherInfo `json:"-" toml:"-"`
// Can be used for log backup and restore, but not recommended for production since the data key is stored
// in PD in plaintext.
LogBackupCipherInfo backuppb.CipherInfo `json:"-" toml:"-"`
// Master key based encryption for log restore.
// More than one master key can be specified if the master key was rotated during log backup.
MasterKeyConfig backuppb.MasterKeyConfig `json:"master-key-config" toml:"master-key-config"`
// whether there's explicit filter
ExplicitFilter bool `json:"-" toml:"-"`
@ -310,17 +327,34 @@ func DefineCommonFlags(flags *pflag.FlagSet) {
flags.BoolP(flagSkipCheckPath, "", false, "Skip path verification")
_ = flags.MarkHidden(flagSkipCheckPath)
flags.String(flagCipherType, "plaintext", "Encrypt/decrypt method, "+
flags.String(flagFullBackupCipherType, "plaintext", "Encrypt/decrypt method, "+
"be one of plaintext|aes128-ctr|aes192-ctr|aes256-ctr case-insensitively, "+
"\"plaintext\" represents no encrypt/decrypt")
flags.String(flagCipherKey, "",
flags.String(flagFullBackupCipherKey, "",
"aes-crypter key, used to encrypt/decrypt the data "+
"by the hexadecimal string, eg: \"0123456789abcdef0123456789abcdef\"")
flags.String(flagCipherKeyFile, "", "FilePath, its content is used as the cipher-key")
flags.String(flagFullBackupCipherKeyFile, "", "FilePath, its content is used as the cipher-key")
flags.Uint(flagMetadataDownloadBatchSize, defaultMetadataDownloadBatchSize,
"the batch size of downloading metadata, such as log restore metadata for truncate or restore")
// log backup plaintext key flags
flags.String(flagLogBackupCipherType, "plaintext", "Encrypt/decrypt method, "+
"be one of plaintext|aes128-ctr|aes192-ctr|aes256-ctr case-insensitively, "+
"\"plaintext\" represents no encrypt/decrypt")
flags.String(flagLogBackupCipherKey, "",
"aes-crypter key, used to encrypt/decrypt the data "+
"by the hexadecimal string, eg: \"0123456789abcdef0123456789abcdef\"")
flags.String(flagLogBackupCipherKeyFile, "", "FilePath, its content is used as the cipher-key")
// master key config
flags.String(flagMasterKeyCipherType, "plaintext", "Encrypt/decrypt method, "+
"be one of plaintext|aes128-ctr|aes192-ctr|aes256-ctr case-insensitively, "+
"\"plaintext\" represents no encrypt/decrypt")
flags.String(flagMasterKeyConfig, "", "Master key config for point in time restore, "+
"examples: \"local:///path/to/master/key/file,"+
"aws-kms:///{key-id}?AWS_ACCESS_KEY_ID={access-key}&AWS_SECRET_ACCESS_KEY={secret-key}&REGION={region},"+
"gcp-kms:///projects/{project-id}/locations/{location}/keyRings/{keyring}/cryptoKeys/{key-name}?AUTH=specified&CREDENTIALS={credentials}\"")
_ = flags.MarkHidden(flagMetadataDownloadBatchSize)
storage.DefineFlags(flags)
@ -334,10 +368,15 @@ func HiddenFlagsForStream(flags *pflag.FlagSet) {
_ = flags.MarkHidden(flagRateLimit)
_ = flags.MarkHidden(flagRateLimitUnit)
_ = flags.MarkHidden(flagRemoveTiFlash)
_ = flags.MarkHidden(flagCipherType)
_ = flags.MarkHidden(flagCipherKey)
_ = flags.MarkHidden(flagCipherKeyFile)
_ = flags.MarkHidden(flagFullBackupCipherType)
_ = flags.MarkHidden(flagFullBackupCipherKey)
_ = flags.MarkHidden(flagFullBackupCipherKeyFile)
_ = flags.MarkHidden(flagLogBackupCipherType)
_ = flags.MarkHidden(flagLogBackupCipherKey)
_ = flags.MarkHidden(flagLogBackupCipherKeyFile)
_ = flags.MarkHidden(flagSwitchModeInterval)
_ = flags.MarkHidden(flagMasterKeyConfig)
_ = flags.MarkHidden(flagMasterKeyCipherType)
storage.HiddenFlagsForStream(flags)
}
@ -456,7 +495,7 @@ func checkCipherKeyMatch(cipher *backuppb.CipherInfo) bool {
}
func (cfg *Config) parseCipherInfo(flags *pflag.FlagSet) error {
crypterStr, err := flags.GetString(flagCipherType)
crypterStr, err := flags.GetString(flagFullBackupCipherType)
if err != nil {
return errors.Trace(err)
}
@ -470,12 +509,12 @@ func (cfg *Config) parseCipherInfo(flags *pflag.FlagSet) error {
return nil
}
key, err := flags.GetString(flagCipherKey)
key, err := flags.GetString(flagFullBackupCipherKey)
if err != nil {
return errors.Trace(err)
}
keyFilePath, err := flags.GetString(flagCipherKeyFile)
keyFilePath, err := flags.GetString(flagFullBackupCipherKeyFile)
if err != nil {
return errors.Trace(err)
}
@ -492,6 +531,43 @@ func (cfg *Config) parseCipherInfo(flags *pflag.FlagSet) error {
return nil
}
func (cfg *Config) parseLogBackupCipherInfo(flags *pflag.FlagSet) (bool, error) {
crypterStr, err := flags.GetString(flagLogBackupCipherType)
if err != nil {
return false, errors.Trace(err)
}
cfg.LogBackupCipherInfo.CipherType, err = parseCipherType(crypterStr)
if err != nil {
return false, errors.Trace(err)
}
if !utils.IsEffectiveEncryptionMethod(cfg.LogBackupCipherInfo.CipherType) {
return false, nil
}
key, err := flags.GetString(flagLogBackupCipherKey)
if err != nil {
return false, errors.Trace(err)
}
keyFilePath, err := flags.GetString(flagLogBackupCipherKeyFile)
if err != nil {
return false, errors.Trace(err)
}
cfg.LogBackupCipherInfo.CipherKey, err = GetCipherKeyContent(key, keyFilePath)
if err != nil {
return false, errors.Trace(err)
}
if !checkCipherKeyMatch(&cfg.LogBackupCipherInfo) {
return false, errors.Annotate(berrors.ErrInvalidArgument, "log backup encryption method and key length do not match")
}
return true, nil
}
func (cfg *Config) normalizePDURLs() error {
for i := range cfg.PD {
var err error
@ -618,7 +694,17 @@ func (cfg *Config) ParseFromFlags(flags *pflag.FlagSet) error {
log.L().Info("--skip-check-path is deprecated, need explicitly set it anymore")
}
if err = cfg.parseCipherInfo(flags); err != nil {
err = cfg.parseCipherInfo(flags)
if err != nil {
return errors.Trace(err)
}
hasLogBackupPlaintextKey, err := cfg.parseLogBackupCipherInfo(flags)
if err != nil {
return errors.Trace(err)
}
if err = cfg.parseAndValidateMasterKeyInfo(hasLogBackupPlaintextKey, flags); err != nil {
return errors.Trace(err)
}
@ -629,6 +715,51 @@ func (cfg *Config) ParseFromFlags(flags *pflag.FlagSet) error {
return cfg.normalizePDURLs()
}
func (cfg *Config) parseAndValidateMasterKeyInfo(hasPlaintextKey bool, flags *pflag.FlagSet) error {
masterKeyString, err := flags.GetString(flagMasterKeyConfig)
if err != nil {
return errors.Errorf("master key flag '%s' is not defined: %v", flagMasterKeyConfig, err)
}
if masterKeyString == "" {
return nil
}
if hasPlaintextKey {
return errors.Errorf("invalid argument: both plaintext data key encryption and master key based encryption are set at the same time")
}
encryptionMethodString, err := flags.GetString(flagMasterKeyCipherType)
if err != nil {
return errors.Errorf("encryption method flag '%s' is not defined: %v", flagMasterKeyCipherType, err)
}
encryptionMethod, err := parseCipherType(encryptionMethodString)
if err != nil {
return errors.Errorf("failed to parse encryption method: %v", err)
}
if !utils.IsEffectiveEncryptionMethod(encryptionMethod) {
return errors.Errorf("invalid encryption method: %s", encryptionMethodString)
}
masterKeyStrings := strings.Split(masterKeyString, masterKeysDelimiter)
cfg.MasterKeyConfig = backuppb.MasterKeyConfig{
EncryptionType: encryptionMethod,
MasterKeys: make([]*encryptionpb.MasterKey, 0, len(masterKeyStrings)),
}
for _, keyString := range masterKeyStrings {
masterKey, err := validateAndParseMasterKeyString(strings.TrimSpace(keyString))
if err != nil {
return errors.Wrapf(err, "invalid master key configuration: %s", keyString)
}
cfg.MasterKeyConfig.MasterKeys = append(cfg.MasterKeyConfig.MasterKeys, &masterKey)
}
return nil
}
// NewMgr creates a new mgr at the given PD address.
func NewMgr(ctx context.Context,
g glue.Glue, pds []string,
@ -726,7 +857,7 @@ func ReadBackupMeta(
if cfg.CipherInfo.CipherType != encryptionpb.EncryptionMethod_PLAINTEXT {
iv = metaData[:metautil.CrypterIvLen]
}
decryptBackupMeta, err := metautil.Decrypt(metaData[len(iv):], &cfg.CipherInfo, iv)
decryptBackupMeta, err := utils.Decrypt(metaData[len(iv):], &cfg.CipherInfo, iv)
if err != nil {
return nil, nil, nil, errors.Annotate(err, "decrypt failed with wrong key")
}
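For orientation (editor sketch, not part of the diff): an encrypted backupmeta is laid out as IV || ciphertext, so the first metautil.CrypterIvLen bytes are the CTR IV and the remainder is handed to the new utils.Decrypt helper. A hedged sketch of that read path, using only names that already appear in this change:
package sketch
import (
	backuppb "github.com/pingcap/kvproto/pkg/brpb"
	"github.com/pingcap/kvproto/pkg/encryptionpb"
	"github.com/pingcap/tidb/br/pkg/metautil"
	"github.com/pingcap/tidb/br/pkg/utils"
)
// decodeBackupMeta sketches the layout assumed above: for a non-plaintext
// cipher the file is IV || ciphertext; otherwise it is the raw protobuf.
func decodeBackupMeta(raw []byte, cipherInfo *backuppb.CipherInfo) (*backuppb.BackupMeta, error) {
	var iv []byte
	if cipherInfo.CipherType != encryptionpb.EncryptionMethod_PLAINTEXT {
		iv = raw[:metautil.CrypterIvLen]
	}
	plain, err := utils.Decrypt(raw[len(iv):], cipherInfo, iv)
	if err != nil {
		return nil, err
	}
	meta := &backuppb.BackupMeta{}
	if err := meta.Unmarshal(plain); err != nil {
		return nil, err
	}
	return meta, nil
}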

View File

@ -185,6 +185,7 @@ func expectedDefaultConfig() Config {
GRPCKeepaliveTime: 10000000000,
GRPCKeepaliveTimeout: 3000000000,
CipherInfo: backup.CipherInfo{CipherType: 1},
LogBackupCipherInfo: backup.CipherInfo{CipherType: 1},
MetadataDownloadBatchSize: 0x80,
}
}
@ -241,3 +242,132 @@ func TestDefaultRestore(t *testing.T) {
defaultConfig := expectedDefaultRestoreConfig()
require.Equal(t, defaultConfig, def)
}
func TestParseAndValidateMasterKeyInfo(t *testing.T) {
tests := []struct {
name string
input string
expectedKeys []*encryptionpb.MasterKey
expectError bool
}{
{
name: "Empty input",
input: "",
expectedKeys: nil,
expectError: false,
},
{
name: "Single local config",
input: "local:///path/to/key",
expectedKeys: []*encryptionpb.MasterKey{
{
Backend: &encryptionpb.MasterKey_File{
File: &encryptionpb.MasterKeyFile{Path: "/path/to/key"},
},
},
},
expectError: false,
},
{
name: "Single AWS config",
input: "aws-kms:///key-id?AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE&AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY&REGION=us-west-2",
expectedKeys: []*encryptionpb.MasterKey{
{
Backend: &encryptionpb.MasterKey_Kms{
Kms: &encryptionpb.MasterKeyKms{
Vendor: "aws",
KeyId: "key-id",
Region: "us-west-2",
},
},
},
},
expectError: false,
},
{
name: "Single Azure config",
input: "azure-kms:///key-name/key-version?AZURE_TENANT_ID=tenant-id&AZURE_CLIENT_ID=client-id&AZURE_CLIENT_SECRET=client-secret&AZURE_VAULT_NAME=vault-name",
expectedKeys: []*encryptionpb.MasterKey{
{
Backend: &encryptionpb.MasterKey_Kms{
Kms: &encryptionpb.MasterKeyKms{
Vendor: "azure",
KeyId: "key-name/key-version",
AzureKms: &encryptionpb.AzureKms{
TenantId: "tenant-id",
ClientId: "client-id",
ClientSecret: "client-secret",
KeyVaultUrl: "vault-name",
},
},
},
},
},
expectError: false,
},
{
name: "Single GCP config",
input: "gcp-kms:///projects/project-id/locations/global/keyRings/ring-name/cryptoKeys/key-name?CREDENTIALS=credentials",
expectedKeys: []*encryptionpb.MasterKey{
{
Backend: &encryptionpb.MasterKey_Kms{
Kms: &encryptionpb.MasterKeyKms{
Vendor: "gcp",
KeyId: "projects/project-id/locations/global/keyRings/ring-name/cryptoKeys/key-name",
GcpKms: &encryptionpb.GcpKms{
Credential: "credentials",
},
},
},
},
},
expectError: false,
},
{
name: "Multiple configs",
input: "local:///path/to/key," +
"aws-kms:///key-id?AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE&AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY&REGION=us-west-2",
expectedKeys: []*encryptionpb.MasterKey{
{
Backend: &encryptionpb.MasterKey_File{
File: &encryptionpb.MasterKeyFile{Path: "/path/to/key"},
},
},
{
Backend: &encryptionpb.MasterKey_Kms{
Kms: &encryptionpb.MasterKeyKms{
Vendor: "aws",
KeyId: "key-id",
Region: "us-west-2",
},
},
},
},
expectError: false,
},
{
name: "Invalid config",
input: "invalid:///config",
expectedKeys: nil,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &Config{}
flags := pflag.NewFlagSet("test", pflag.ContinueOnError)
flags.String(flagMasterKeyConfig, tt.input, "")
flags.String(flagMasterKeyCipherType, "aes256-ctr", "")
err := cfg.parseAndValidateMasterKeyInfo(false, flags)
if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tt.expectedKeys, cfg.MasterKeyConfig.MasterKeys)
}
})
}
}

br/pkg/task/encryption.go Normal file
View File

@ -0,0 +1,166 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
package task
import (
"fmt"
"net/url"
"path"
"regexp"
"github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/encryptionpb"
)
const (
SchemeLocal = "local"
SchemeAWS = "aws-kms"
SchemeAzure = "azure-kms"
SchemeGCP = "gcp-kms"
AWSVendor = "aws"
AWSRegion = "REGION"
AWSEndpoint = "ENDPOINT"
AWSAccessKeyId = "AWS_ACCESS_KEY_ID"
AWSSecretKey = "AWS_SECRET_ACCESS_KEY"
AzureVendor = "azure"
AzureTenantID = "AZURE_TENANT_ID"
AzureClientID = "AZURE_CLIENT_ID"
AzureClientSecret = "AZURE_CLIENT_SECRET"
AzureVaultName = "AZURE_VAULT_NAME"
GCPVendor = "gcp"
GCPCredentials = "CREDENTIALS"
)
var (
awsRegex = regexp.MustCompile(`^/([^/]+)$`)
azureRegex = regexp.MustCompile(`^/(.+)$`)
gcpRegex = regexp.MustCompile(`^/projects/([^/]+)/locations/([^/]+)/keyRings/([^/]+)/cryptoKeys/([^/]+)/?$`)
)
func validateAndParseMasterKeyString(keyString string) (encryptionpb.MasterKey, error) {
u, err := url.Parse(keyString)
if err != nil {
return encryptionpb.MasterKey{}, errors.Trace(err)
}
switch u.Scheme {
case SchemeLocal:
return parseLocalDiskConfig(u)
case SchemeAWS:
return parseAwsKmsConfig(u)
case SchemeAzure:
return parseAzureKmsConfig(u)
case SchemeGCP:
return parseGcpKmsConfig(u)
default:
return encryptionpb.MasterKey{}, errors.Errorf("unsupported master key type: %s", u.Scheme)
}
}
func parseLocalDiskConfig(u *url.URL) (encryptionpb.MasterKey, error) {
if !path.IsAbs(u.Path) {
return encryptionpb.MasterKey{}, errors.New("local master key path must be absolute")
}
return encryptionpb.MasterKey{
Backend: &encryptionpb.MasterKey_File{
File: &encryptionpb.MasterKeyFile{
Path: u.Path,
},
},
}, nil
}
func parseAwsKmsConfig(u *url.URL) (encryptionpb.MasterKey, error) {
matches := awsRegex.FindStringSubmatch(u.Path)
if matches == nil {
return encryptionpb.MasterKey{}, errors.New("invalid AWS KMS key ID format")
}
keyID := matches[1]
q := u.Query()
region := q.Get(AWSRegion)
accessKey := q.Get(AWSAccessKeyId)
secretAccessKey := q.Get(AWSSecretKey)
if region == "" {
return encryptionpb.MasterKey{}, errors.New("missing AWS KMS region info")
}
var awsKms *encryptionpb.AwsKms
if accessKey != "" && secretAccessKey != "" {
awsKms = &encryptionpb.AwsKms{
AccessKey: accessKey,
SecretAccessKey: secretAccessKey,
}
}
return encryptionpb.MasterKey{
Backend: &encryptionpb.MasterKey_Kms{
Kms: &encryptionpb.MasterKeyKms{
Vendor: AWSVendor,
KeyId: keyID,
Region: region,
Endpoint: q.Get(AWSEndpoint), // Optional
AwsKms: awsKms, // Optional, can read from env
},
},
}, nil
}
func parseAzureKmsConfig(u *url.URL) (encryptionpb.MasterKey, error) {
matches := azureRegex.FindStringSubmatch(u.Path)
if matches == nil {
return encryptionpb.MasterKey{}, errors.New("invalid Azure KMS path format")
}
keyID := matches[1] // the entire path is captured as the key ID
q := u.Query()
azureKms := &encryptionpb.AzureKms{
TenantId: q.Get(AzureTenantID),
ClientId: q.Get(AzureClientID),
ClientSecret: q.Get(AzureClientSecret),
KeyVaultUrl: q.Get(AzureVaultName),
}
if azureKms.TenantId == "" || azureKms.ClientId == "" || azureKms.ClientSecret == "" || azureKms.KeyVaultUrl == "" {
return encryptionpb.MasterKey{}, errors.New("missing required Azure KMS parameters")
}
return encryptionpb.MasterKey{
Backend: &encryptionpb.MasterKey_Kms{
Kms: &encryptionpb.MasterKeyKms{
Vendor: AzureVendor,
KeyId: keyID,
AzureKms: azureKms,
},
},
}, nil
}
func parseGcpKmsConfig(u *url.URL) (encryptionpb.MasterKey, error) {
matches := gcpRegex.FindStringSubmatch(u.Path)
if matches == nil {
return encryptionpb.MasterKey{}, errors.New("invalid GCP KMS path format")
}
projectID, location, keyRing, keyName := matches[1], matches[2], matches[3], matches[4]
q := u.Query()
gcpKms := &encryptionpb.GcpKms{
Credential: q.Get(GCPCredentials),
}
return encryptionpb.MasterKey{
Backend: &encryptionpb.MasterKey_Kms{
Kms: &encryptionpb.MasterKeyKms{
Vendor: GCPVendor,
KeyId: fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s", projectID, location, keyRing, keyName),
GcpKms: gcpKms,
},
},
}, nil
}

View File

@ -0,0 +1,192 @@
// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0.
package task
import (
"net/url"
"testing"
"github.com/pingcap/kvproto/pkg/encryptionpb"
"github.com/stretchr/testify/assert"
)
func TestParseLocalDiskConfig(t *testing.T) {
tests := []struct {
name string
input string
expected encryptionpb.MasterKey
expectError bool
}{
{
name: "Valid local path",
input: "local:///path/to/key",
expected: encryptionpb.MasterKey{Backend: &encryptionpb.MasterKey_File{File: &encryptionpb.MasterKeyFile{Path: "/path/to/key"}}},
expectError: false,
},
{
name: "Invalid local path",
input: "local://relative/path",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u, _ := url.Parse(tt.input)
result, err := parseLocalDiskConfig(u)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}
func TestParseAwsKmsConfig(t *testing.T) {
tests := []struct {
name string
input string
expected encryptionpb.MasterKey
expectError bool
}{
{
name: "Valid AWS config",
input: "aws-kms:///key-id?AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE&AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY&REGION=us-west-2",
expected: encryptionpb.MasterKey{
Backend: &encryptionpb.MasterKey_Kms{
Kms: &encryptionpb.MasterKeyKms{
Vendor: "aws",
KeyId: "key-id",
Region: "us-west-2",
},
},
},
expectError: false,
},
{
name: "Missing key ID",
input: "aws-kms:///?AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE&AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY&REGION=us-west-2",
expectError: true,
},
{
name: "Missing required parameter",
input: "aws-kms:///key-id?AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE&REGION=us-west-2",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u, _ := url.Parse(tt.input)
result, err := parseAwsKmsConfig(u)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}
func TestParseAzureKmsConfig(t *testing.T) {
tests := []struct {
name string
input string
expected encryptionpb.MasterKey
expectError bool
}{
{
name: "Valid Azure config",
input: "azure-kms:///key-name/key-version?AZURE_TENANT_ID=tenant-id&AZURE_CLIENT_ID=client-id&AZURE_CLIENT_SECRET=client-secret&AZURE_VAULT_NAME=vault-name",
expected: encryptionpb.MasterKey{
Backend: &encryptionpb.MasterKey_Kms{
Kms: &encryptionpb.MasterKeyKms{
Vendor: "azure",
KeyId: "key-name/key-version",
AzureKms: &encryptionpb.AzureKms{
TenantId: "tenant-id",
ClientId: "client-id",
ClientSecret: "client-secret",
KeyVaultUrl: "vault-name",
},
},
},
},
expectError: false,
},
{
name: "Missing required parameter",
input: "azure-kms:///key-name/key-version?AZURE_TENANT_ID=tenant-id&AZURE_CLIENT_ID=client-id&AZURE_VAULT_NAME=vault-name",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u, _ := url.Parse(tt.input)
result, err := parseAzureKmsConfig(u)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}
func TestParseGcpKmsConfig(t *testing.T) {
tests := []struct {
name string
input string
expected encryptionpb.MasterKey
expectError bool
}{
{
name: "Valid GCP config",
input: "gcp-kms:///projects/project-id/locations/global/keyRings/ring-name/cryptoKeys/key-name?CREDENTIALS=credentials",
expected: encryptionpb.MasterKey{
Backend: &encryptionpb.MasterKey_Kms{
Kms: &encryptionpb.MasterKeyKms{
Vendor: "gcp",
KeyId: "projects/project-id/locations/global/keyRings/ring-name/cryptoKeys/key-name",
GcpKms: &encryptionpb.GcpKms{
Credential: "credentials",
},
},
},
},
expectError: false,
},
{
name: "Invalid path format",
input: "gcp-kms:///invalid/path?CREDENTIALS=credentials",
expectError: true,
},
{
name: "Missing credentials",
input: "gcp-kms:///projects/project-id/locations/global/keyRings/ring-name/cryptoKeys/key-name",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u, _ := url.Parse(tt.input)
result, err := parseGcpKmsConfig(u)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}

View File

@ -247,7 +247,8 @@ type RestoreConfig struct {
AllowPITRFromIncremental bool `json:"allow-pitr-from-incremental" toml:"allow-pitr-from-incremental"`
// [startTs, RestoreTS] is used to `restore log` from StartTS to RestoreTS.
StartTS uint64 `json:"start-ts" toml:"start-ts"`
StartTS uint64 `json:"start-ts" toml:"start-ts"`
// if not specified, the system restores to the max TS available
RestoreTS uint64 `json:"restore-ts" toml:"restore-ts"`
tiflashRecorder *tiflashrec.TiFlashRecorder `json:"-" toml:"-"`
PitrBatchCount uint32 `json:"pitr-batch-count" toml:"pitr-batch-count"`
@ -695,12 +696,12 @@ func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf
if err := version.CheckClusterVersion(c, mgr.GetPDClient(), version.CheckVersionForBRPiTR); err != nil {
return errors.Trace(err)
}
restoreError = RunStreamRestore(c, mgr, g, cmdName, cfg)
restoreError = RunStreamRestore(c, mgr, g, cfg)
} else {
if err := version.CheckClusterVersion(c, mgr.GetPDClient(), version.CheckVersionForBR); err != nil {
return errors.Trace(err)
}
restoreError = runRestore(c, mgr, g, cmdName, cfg, nil)
restoreError = runSnapshotRestore(c, mgr, g, cmdName, cfg, nil)
}
if restoreError != nil {
return errors.Trace(restoreError)
@ -733,12 +734,13 @@ func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf
return nil
}
func runRestore(c context.Context, mgr *conn.Mgr, g glue.Glue, cmdName string, cfg *RestoreConfig, checkInfo *PiTRTaskInfo) error {
func runSnapshotRestore(c context.Context, mgr *conn.Mgr, g glue.Glue, cmdName string, cfg *RestoreConfig, checkInfo *PiTRTaskInfo) error {
cfg.Adjust()
defer summary.Summary(cmdName)
ctx, cancel := context.WithCancel(c)
defer cancel()
log.Info("starting snapshot restore")
if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
span1 := span.Tracer().StartSpan("task.RunRestore", opentracing.ChildOf(span.Context()))
defer span1.Finish()
@ -836,7 +838,7 @@ func runRestore(c context.Context, mgr *conn.Mgr, g glue.Glue, cmdName string, c
if client.IsIncremental() {
// don't support checkpoint for the ddl restore
log.Info("the incremental snapshot restore doesn't support checkpoint mode, so unuse checkpoint.")
log.Info("the incremental snapshot restore doesn't support checkpoint mode, disable checkpoint.")
cfg.UseCheckpoint = false
}
@ -860,7 +862,7 @@ func runRestore(c context.Context, mgr *conn.Mgr, g glue.Glue, cmdName string, c
log.Info("finish removing pd scheduler")
}()
var checkpointFirstRun bool = true
var checkpointFirstRun = true
if cfg.UseCheckpoint {
// if the checkpoint metadata exists in the checkpoint storage, the restore is not
// for the first time.

View File

@ -36,6 +36,7 @@ import (
"github.com/pingcap/tidb/br/pkg/backup"
"github.com/pingcap/tidb/br/pkg/checkpoint"
"github.com/pingcap/tidb/br/pkg/conn"
"github.com/pingcap/tidb/br/pkg/encryption"
berrors "github.com/pingcap/tidb/br/pkg/errors"
"github.com/pingcap/tidb/br/pkg/glue"
"github.com/pingcap/tidb/br/pkg/httputil"
@ -302,8 +303,8 @@ func NewStreamMgr(ctx context.Context, cfg *StreamConfig, g glue.Glue, isStreamS
}
}()
// just stream start need Storage
s := &streamMgr{
// only stream start command needs Storage
streamManager := &streamMgr{
cfg: cfg,
mgr: mgr,
}
@ -323,12 +324,12 @@ func NewStreamMgr(ctx context.Context, cfg *StreamConfig, g glue.Glue, isStreamS
if err = client.SetStorage(ctx, backend, &opts); err != nil {
return nil, errors.Trace(err)
}
s.bc = client
streamManager.bc = client
// create http client to do some requirements check.
s.httpCli = httputil.NewClient(mgr.GetTLSConfig())
streamManager.httpCli = httputil.NewClient(mgr.GetTLSConfig())
}
return s, nil
return streamManager, nil
}
func (s *streamMgr) close() {
@ -346,7 +347,7 @@ func (s *streamMgr) setLock(ctx context.Context) error {
// adjustAndCheckStartTS checks that startTS should be smaller than currentTS,
// and endTS is larger than currentTS.
func (s *streamMgr) adjustAndCheckStartTS(ctx context.Context) error {
currentTS, err := s.mgr.GetTS(ctx)
currentTS, err := s.mgr.GetCurrentTsFromPD(ctx)
if err != nil {
return errors.Trace(err)
}
@ -499,7 +500,7 @@ func RunStreamCommand(
}
if err := commandFn(ctx, g, cmdName, cfg); err != nil {
log.Error("failed to stream", zap.String("command", cmdName), zap.Error(err))
log.Error("failed to run stream command", zap.String("command", cmdName), zap.Error(err))
summary.SetSuccessStatus(false)
summary.CollectFailureUnit(cmdName, err)
return err
@ -547,22 +548,28 @@ func RunStreamStart(
log.Warn("failed to close etcd client", zap.Error(closeErr))
}
}()
// check if any import/restore task is running; starting log backup is not allowed
// while a restore is ongoing.
if err = streamMgr.checkImportTaskRunning(ctx, cli.Client); err != nil {
return errors.Trace(err)
}
// Only a single stream log task is supported currently.
if count, err := cli.GetTaskCount(ctx); err != nil {
return errors.Trace(err)
} else if count > 0 {
return errors.Annotate(berrors.ErrStreamLogTaskExist, "It supports single stream log task currently")
return errors.Annotate(berrors.ErrStreamLogTaskExist, "failed to start the log backup: only one running task is allowed")
}
exist, err := streamMgr.checkLock(ctx)
// make sure external file lock is available
locked, err := streamMgr.checkLock(ctx)
if err != nil {
return errors.Trace(err)
}
// exist is true, which represents restart a stream task. Or create a new stream task.
if exist {
// locked means a stream task is being restarted; otherwise a new stream task is created.
if locked {
logInfo, err := getLogRange(ctx, &cfg.Config)
if err != nil {
return errors.Trace(err)
@ -614,6 +621,7 @@ func RunStreamStart(
return errors.Annotate(berrors.ErrInvalidArgument, "nothing need to observe")
}
securityConfig := generateSecurityConfig(cfg)
ti := streamhelper.TaskInfo{
PBInfo: backuppb.StreamBackupTaskInfo{
Storage: streamMgr.bc.GetStorageBackend(),
@ -622,6 +630,7 @@ func RunStreamStart(
Name: cfg.TaskName,
TableFilter: cfg.FilterStr,
CompressionType: backuppb.CompressionType_ZSTD,
SecurityConfig: &securityConfig,
},
Ranges: ranges,
Pausing: false,
@ -633,6 +642,30 @@ func RunStreamStart(
return nil
}
func generateSecurityConfig(cfg *StreamConfig) backuppb.StreamBackupTaskSecurityConfig {
if len(cfg.LogBackupCipherInfo.CipherKey) > 0 && utils.IsEffectiveEncryptionMethod(cfg.LogBackupCipherInfo.CipherType) {
return backuppb.StreamBackupTaskSecurityConfig{
Encryption: &backuppb.StreamBackupTaskSecurityConfig_PlaintextDataKey{
PlaintextDataKey: &backuppb.CipherInfo{
CipherType: cfg.LogBackupCipherInfo.CipherType,
CipherKey: cfg.LogBackupCipherInfo.CipherKey,
},
},
}
}
if len(cfg.MasterKeyConfig.MasterKeys) > 0 && utils.IsEffectiveEncryptionMethod(cfg.MasterKeyConfig.EncryptionType) {
return backuppb.StreamBackupTaskSecurityConfig{
Encryption: &backuppb.StreamBackupTaskSecurityConfig_MasterKeyConfig{
MasterKeyConfig: &backuppb.MasterKeyConfig{
EncryptionType: cfg.MasterKeyConfig.EncryptionType,
MasterKeys: cfg.MasterKeyConfig.MasterKeys,
},
},
}
}
return backuppb.StreamBackupTaskSecurityConfig{}
}
func RunStreamMetadata(
c context.Context,
g glue.Glue,
@ -995,13 +1028,13 @@ func RunStreamTruncate(c context.Context, g glue.Glue, cmdName string, cfg *Stre
}
if cfg.Until < sp {
console.Println("According to the log, you have truncated backup data before", em(formatTS(sp)))
console.Println("According to the log, you have truncated log backup data before", em(formatTS(sp)))
if !cfg.SkipPrompt && !console.PromptBool("Continue? ") {
return nil
}
}
readMetaDone := console.ShowTask("Reading Metadata... ", glue.WithTimeCost())
readMetaDone := console.ShowTask("Reading log backup metadata... ", glue.WithTimeCost())
metas := stream.StreamMetadataSet{
MetadataDownloadBatchSize: cfg.MetadataDownloadBatchSize,
Helper: stream.NewMetadataHelper(),
@ -1025,11 +1058,11 @@ func RunStreamTruncate(c context.Context, g glue.Glue, cmdName string, cfg *Stre
kvCount += d.KVCount
return
})
console.Printf("We are going to remove %s files, until %s.\n",
console.Printf("We are going to truncate %s files, up to TS %s.\n",
em(fileCount),
em(formatTS(cfg.Until)),
)
if !cfg.SkipPrompt && !console.PromptBool(warn("Sure? ")) {
if !cfg.SkipPrompt && !console.PromptBool(warn("Are you sure? ")) {
return nil
}
@ -1042,7 +1075,7 @@ func RunStreamTruncate(c context.Context, g glue.Glue, cmdName string, cfg *Stre
// begin to remove
p := console.StartProgressBar(
"Clearing Data Files and Metadata", fileCount,
"Truncating Data Files and Metadata", fileCount,
glue.WithTimeCost(),
glue.WithConstExtraField("kv-count", kvCount),
glue.WithConstExtraField("kv-size", fmt.Sprintf("%d(%s)", totalSize, units.HumanSize(float64(totalSize)))),
@ -1113,7 +1146,6 @@ func RunStreamRestore(
c context.Context,
mgr *conn.Mgr,
g glue.Glue,
cmdName string,
cfg *RestoreConfig,
) (err error) {
ctx, cancelFn := context.WithCancel(c)
@ -1132,6 +1164,8 @@ func RunStreamRestore(
if err != nil {
return errors.Trace(err)
}
// if not set by user, restore to the max TS available
if cfg.RestoreTS == 0 {
cfg.RestoreTS = logInfo.logMaxTS
}
@ -1157,7 +1191,7 @@ func RunStreamRestore(
}
}
log.Info("start restore on point",
log.Info("start point in time restore",
zap.Uint64("restore-from", cfg.StartTS), zap.Uint64("restore-to", cfg.RestoreTS),
zap.Uint64("log-min-ts", logInfo.logMinTS), zap.Uint64("log-max-ts", logInfo.logMaxTS))
if err := checkLogRange(cfg.StartTS, cfg.RestoreTS, logInfo.logMinTS, logInfo.logMaxTS); err != nil {
@ -1180,7 +1214,7 @@ func RunStreamRestore(
logStorage := cfg.Config.Storage
cfg.Config.Storage = cfg.FullBackupStorage
// TiFlash replica is restored to down-stream on 'pitr' currently.
if err = runRestore(ctx, mgr, g, FullRestoreCmd, cfg, checkInfo); err != nil {
if err = runSnapshotRestore(ctx, mgr, g, FullRestoreCmd, cfg, checkInfo); err != nil {
return errors.Trace(err)
}
cfg.Config.Storage = logStorage
@ -1191,9 +1225,6 @@ func RunStreamRestore(
}
if checkInfo.CheckpointInfo != nil && checkInfo.CheckpointInfo.Metadata != nil && checkInfo.CheckpointInfo.Metadata.TiFlashItems != nil {
log.Info("load tiflash records of snapshot restore from checkpoint")
if err != nil {
return errors.Trace(err)
}
cfg.tiflashRecorder.Load(checkInfo.CheckpointInfo.Metadata.TiFlashItems)
}
}
@ -1329,7 +1360,12 @@ func restoreStream(
}()
}
err = client.InstallLogFileManager(ctx, cfg.StartTS, cfg.RestoreTS, cfg.MetadataDownloadBatchSize)
encryptionManager, err := encryption.NewManager(&cfg.LogBackupCipherInfo, &cfg.MasterKeyConfig)
if err != nil {
return errors.Annotate(err, "failed to create encryption manager for log restore")
}
defer encryptionManager.Close()
err = client.InstallLogFileManager(ctx, cfg.StartTS, cfg.RestoreTS, cfg.MetadataDownloadBatchSize, encryptionManager)
if err != nil {
return err
}
@ -1345,12 +1381,13 @@ func restoreStream(
newTask = false
}
// get the schema ID replace information.
// since this reads from the full backup storage, the full backup cipher is needed
schemasReplace, err := client.InitSchemasReplaceForDDL(ctx, &logclient.InitSchemaConfig{
IsNewTask: newTask,
TableFilter: cfg.TableFilter,
TiFlashRecorder: cfg.tiflashRecorder,
FullBackupStorage: fullBackupStorage,
})
}, &cfg.Config.CipherInfo)
if err != nil {
return errors.Trace(err)
}
@ -1431,7 +1468,8 @@ func restoreStream(
return errors.Trace(err)
}
return client.RestoreKVFiles(ctx, rewriteRules, idrules, logFilesIterWithSplit, checkpointRunner, cfg.PitrBatchCount, cfg.PitrBatchSize, updateStats, p.IncBy)
return client.RestoreKVFiles(ctx, rewriteRules, idrules, logFilesIterWithSplit, checkpointRunner,
cfg.PitrBatchCount, cfg.PitrBatchSize, updateStats, p.IncBy, &cfg.LogBackupCipherInfo, cfg.MasterKeyConfig.MasterKeys)
})
if err != nil {
return errors.Annotate(err, "failed to restore kv files")
@ -1547,15 +1585,15 @@ func getExternalStorageOptions(cfg *Config, u *backuppb.StorageBackend) storage.
}
}
func checkLogRange(restoreFrom, restoreTo, logMinTS, logMaxTS uint64) error {
// serveral ts constraint:
// logMinTS <= restoreFrom <= restoreTo <= logMaxTS
if logMinTS > restoreFrom || restoreFrom > restoreTo || restoreTo > logMaxTS {
func checkLogRange(restoreFromTS, restoreToTS, logMinTS, logMaxTS uint64) error {
// several ts constraints:
// logMinTS <= restoreFromTS <= restoreToTS <= logMaxTS
if logMinTS > restoreFromTS || restoreFromTS > restoreToTS || restoreToTS > logMaxTS {
return errors.Annotatef(berrors.ErrInvalidArgument,
"restore log from %d(%s) to %d(%s), "+
" but the current existed log from %d(%s) to %d(%s)",
restoreFrom, oracle.GetTimeFromTS(restoreFrom),
restoreTo, oracle.GetTimeFromTS(restoreTo),
restoreFromTS, oracle.GetTimeFromTS(restoreFromTS),
restoreToTS, oracle.GetTimeFromTS(restoreToTS),
logMinTS, oracle.GetTimeFromTS(logMinTS),
logMaxTS, oracle.GetTimeFromTS(logMaxTS),
)
@ -1648,7 +1686,7 @@ func getGlobalCheckpointFromStorage(ctx context.Context, s storage.ExternalStora
return globalCheckPointTS, errors.Trace(err)
}
// getFullBackupTS gets the snapshot-ts of full bakcup
// getFullBackupTS gets the snapshot-ts of full backup
func getFullBackupTS(
ctx context.Context,
cfg *RestoreConfig,
@ -1663,11 +1701,17 @@ func getFullBackupTS(
return 0, 0, errors.Trace(err)
}
backupmeta := &backuppb.BackupMeta{}
if err = backupmeta.Unmarshal(metaData); err != nil {
decryptedMetaData, err := metautil.DecryptFullBackupMetaIfNeeded(metaData, &cfg.CipherInfo)
if err != nil {
return 0, 0, errors.Trace(err)
}
backupmeta := &backuppb.BackupMeta{}
if err = backupmeta.Unmarshal(decryptedMetaData); err != nil {
return 0, 0, errors.Trace(err)
}
// start and end versions are identical in a full backup, pick either one
return backupmeta.GetEndVersion(), backupmeta.GetClusterId(), nil
}
@ -1757,7 +1801,7 @@ func checkPiTRTaskInfo(
cfg *RestoreConfig,
) (*PiTRTaskInfo, error) {
var (
doFullRestore = (len(cfg.FullBackupStorage) > 0)
doFullRestore = len(cfg.FullBackupStorage) > 0
curTaskInfo *checkpoint.CheckpointTaskInfoForLogRestore
)
checkInfo := &PiTRTaskInfo{}

View File

@ -110,35 +110,6 @@ func TestCheckLogRange(t *testing.T) {
}
}
type fakeResolvedInfo struct {
storeID int64
resolvedTS uint64
}
func fakeMetaFiles(ctx context.Context, tempDir string, infos []fakeResolvedInfo) error {
backupMetaDir := filepath.Join(tempDir, stream.GetStreamBackupMetaPrefix())
s, err := storage.NewLocalStorage(backupMetaDir)
if err != nil {
return errors.Trace(err)
}
for _, info := range infos {
meta := &backuppb.Metadata{
StoreId: info.storeID,
ResolvedTs: info.resolvedTS,
}
buff, err := meta.Marshal()
if err != nil {
return errors.Trace(err)
}
filename := fmt.Sprintf("%d_%d.meta", info.storeID, info.resolvedTS)
if err = s.WriteFile(ctx, filename, buff); err != nil {
return errors.Trace(err)
}
}
return nil
}
func fakeCheckpointFiles(
ctx context.Context,
tmpDir string,
@ -154,7 +125,7 @@ func fakeCheckpointFiles(
for _, info := range infos {
filename := fmt.Sprintf("%v.ts", info.storeID)
buff := make([]byte, 8)
binary.LittleEndian.PutUint64(buff, info.global_checkpoint)
binary.LittleEndian.PutUint64(buff, info.globalCheckpoint)
if _, err := s.Create(ctx, filename, nil); err != nil {
return errors.Trace(err)
}
@ -170,8 +141,8 @@ func fakeCheckpointFiles(
}
type fakeGlobalCheckPoint struct {
storeID int64
global_checkpoint uint64
storeID int64
globalCheckpoint uint64
}
func TestGetGlobalCheckpointFromStorage(t *testing.T) {
@ -182,16 +153,16 @@ func TestGetGlobalCheckpointFromStorage(t *testing.T) {
infos := []fakeGlobalCheckPoint{
{
storeID: 1,
global_checkpoint: 98,
storeID: 1,
globalCheckpoint: 98,
},
{
storeID: 2,
global_checkpoint: 90,
storeID: 2,
globalCheckpoint: 90,
},
{
storeID: 2,
global_checkpoint: 99,
storeID: 2,
globalCheckpoint: 99,
},
}

View File

@ -7,6 +7,7 @@ go_library(
"db.go",
"dyn_pprof_other.go",
"dyn_pprof_unix.go",
"encryption.go",
"error_handling.go",
"json.go",
"key.go",
@ -35,6 +36,7 @@ go_library(
"//pkg/parser/types",
"//pkg/sessionctx",
"//pkg/util",
"//pkg/util/encrypt",
"//pkg/util/logutil",
"//pkg/util/sqlexec",
"@com_github_cheggaaa_pb_v3//:pb",
@ -43,6 +45,7 @@ go_library(
"@com_github_pingcap_errors//:errors",
"@com_github_pingcap_failpoint//:failpoint",
"@com_github_pingcap_kvproto//pkg/brpb",
"@com_github_pingcap_kvproto//pkg/encryptionpb",
"@com_github_pingcap_kvproto//pkg/metapb",
"@com_github_pingcap_log//:log",
"@com_github_tikv_client_go_v2//oracle",

View File

@ -0,0 +1,30 @@
package utils
import (
"github.com/pingcap/errors"
backuppb "github.com/pingcap/kvproto/pkg/brpb"
"github.com/pingcap/kvproto/pkg/encryptionpb"
berrors "github.com/pingcap/tidb/br/pkg/errors"
"github.com/pingcap/tidb/pkg/util/encrypt"
)
func Decrypt(content []byte, cipher *backuppb.CipherInfo, iv []byte) ([]byte, error) {
if len(content) == 0 || cipher == nil {
return content, nil
}
switch cipher.CipherType {
case encryptionpb.EncryptionMethod_PLAINTEXT:
return content, nil
case encryptionpb.EncryptionMethod_AES128_CTR,
encryptionpb.EncryptionMethod_AES192_CTR,
encryptionpb.EncryptionMethod_AES256_CTR:
return encrypt.AESDecryptWithCTR(content, cipher.CipherKey, iv)
default:
return content, errors.Annotatef(berrors.ErrInvalidArgument, "cipher type invalid %s", cipher.CipherType)
}
}
func IsEffectiveEncryptionMethod(method encryptionpb.EncryptionMethod) bool {
return method != encryptionpb.EncryptionMethod_UNKNOWN && method != encryptionpb.EncryptionMethod_PLAINTEXT
}
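For context (editor sketch, not BR code): Decrypt above only undoes AES-CTR; the encrypt side generates a fresh random IV and stores it next to the ciphertext (e.g. as a prefix), since CTR decryption needs the same IV. A stdlib-only sketch with illustrative names:
package sketch
import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
)
// encryptCTR encrypts content with AES-CTR under key using a freshly
// generated random IV, and returns the IV so the caller can persist it
// alongside the ciphertext for later decryption.
func encryptCTR(content, key []byte) (ciphertext, iv []byte, err error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, nil, err
	}
	iv = make([]byte, block.BlockSize())
	if _, err := rand.Read(iv); err != nil {
		return nil, nil, err
	}
	ciphertext = make([]byte, len(content))
	cipher.NewCTR(block, iv).XORKeyStream(ciphertext, content)
	return ciphertext, iv, nil
}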

View File

@ -33,7 +33,7 @@ This folder contains all tests which relies on external processes such as TiDB.
## Preparations
1. The following 9 executables must be copied or linked into these locations:
1. The following 9 executables must be copied or linked into the `bin` folder under the TiDB root dir:
* `bin/tidb-server`
* `bin/tikv-server`
@ -80,7 +80,7 @@ If you have docker installed, you can skip step 1 and step 2 by running
1. Build `br.test` using `make build_for_br_integration_test`
2. Check that all 9 required executables and `br` executable exist
3. Select the tests to run using `export TEST_NAME="<test_name1> <test_name2> ..."`
3. Execute `tests/run.sh`
4. Execute `br/tests/run.sh`
<!-- 4. To start cluster with tiflash, please run `TIFLASH=1 tests/run.sh` -->
If the first two steps have already been done, you can also run `tests/run.sh` directly.

br/tests/br_encryption/run.sh Executable file
View File

@ -0,0 +1,435 @@
#!/bin/sh
#
# Copyright 2024 PingCAP, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -eu
. run_services
CUR=$(cd "$(dirname "$0")" && pwd)
# const value
PREFIX="encryption_backup"
res_file="$TEST_DIR/sql_res.$TEST_NAME.txt"
DB="$TEST_NAME"
TABLE="usertable"
DB_COUNT=3
TASK_NAME="encryption_test"
create_db_with_table() {
for i in $(seq $DB_COUNT); do
run_sql "CREATE DATABASE $DB${i};"
go-ycsb load mysql -P $CUR/workload -p mysql.host=$TIDB_IP -p mysql.port=$TIDB_PORT -p mysql.user=root -p mysql.db=$DB${i} -p recordcount=1000
done
}
start_log_backup() {
_storage=$1
_encryption_args=$2
echo "start log backup task"
run_br --pd "$PD_ADDR" log start --task-name $TASK_NAME -s "$_storage" $_encryption_args
}
drop_db() {
for i in $(seq $DB_COUNT); do
run_sql "DROP DATABASE IF EXISTS $DB${i};"
done
}
insert_additional_data() {
local prefix=$1
for i in $(seq $DB_COUNT); do
go-ycsb load mysql -P $CUR/workload -p mysql.host=$TIDB_IP -p mysql.port=$TIDB_PORT -p mysql.user=root -p mysql.db=$DB${i} -p insertcount=1000 -p insertstart=1000000 -p recordcount=1001000 -p workload=core
done
}
wait_log_checkpoint_advance() {
echo "wait for log checkpoint to advance"
sleep 10
local current_ts=$(python3 -c "import time; print(int(time.time() * 1000) << 18)")
echo "current ts: $current_ts"
i=0
while true; do
# extract the checkpoint ts of the log backup task. If there is some error, the checkpoint ts should be empty
log_backup_status=$(unset BR_LOG_TO_TERM && run_br --skip-goleak --pd $PD_ADDR log status --task-name $TASK_NAME --json 2>br.log)
echo "log backup status: $log_backup_status"
local checkpoint_ts=$(echo "$log_backup_status" | head -n 1 | jq 'if .[0].last_errors | length == 0 then .[0].checkpoint else empty end')
echo "checkpoint ts: $checkpoint_ts"
# check whether the checkpoint ts is a number
if [ $checkpoint_ts -gt 0 ] 2>/dev/null; then
if [ $checkpoint_ts -gt $current_ts ]; then
echo "the checkpoint has advanced"
break
fi
echo "the checkpoint hasn't advanced"
i=$((i+1))
if [ "$i" -gt 50 ]; then
echo 'the checkpoint lag is too large'
exit 1
fi
sleep 10
else
echo "TEST: [$TEST_NAME] failed to wait checkpoint advance!"
exit 1
fi
done
}
calculate_checksum() {
local db=$1
local checksum=$(run_sql "USE $db; ADMIN CHECKSUM TABLE $TABLE;" | awk '/CHECKSUM/{print $2}')
echo $checksum
}
check_db_consistency() {
fail=false
for i in $(seq $DB_COUNT); do
local original_checksum=${checksum_ori[i]}
local new_checksum=$(calculate_checksum "$DB${i}")
if [ "$original_checksum" != "$new_checksum" ]; then
fail=true
echo "TEST: [$TEST_NAME] checksum mismatch on database $DB${i}"
echo "Original checksum: $original_checksum, New checksum: $new_checksum"
else
echo "Database $DB${i} checksum match: $new_checksum"
fi
done
if $fail; then
echo "TEST: [$TEST_NAME] data consistency check failed!"
return 1
fi
echo "TEST: [$TEST_NAME] data consistency check passed."
return 0
}
verify_dbs_empty() {
echo "Verifying databases are empty"
for i in $(seq $DB_COUNT); do
db_name="$DB${i}"
table_count=$(run_sql "USE $db_name; SHOW TABLES;" | wc -l)
if [ "$table_count" -ne 0 ]; then
echo "ERROR: Database $db_name is not empty"
return 1
fi
done
echo "All databases are empty"
return 0
}
run_backup_restore_test() {
local encryption_mode=$1
local full_encryption_args=$2
local log_encryption_args=$3
echo "===== run_backup_restore_test $encryption_mode $full_encryption_args $log_encryption_args ====="
restart_services || { echo "Failed to restart services"; exit 1; }
# Drop existing databases before starting the test
drop_db || { echo "Failed to drop databases"; exit 1; }
# Start log backup
start_log_backup "local://$TEST_DIR/$PREFIX/log" "$log_encryption_args" || { echo "Failed to start log backup"; exit 1; }
# Create test databases and insert initial data
create_db_with_table || { echo "Failed to create databases and tables"; exit 1; }
# Calculate and store original checksums
for i in $(seq $DB_COUNT); do
checksum_ori[${i}]=$(calculate_checksum "$DB${i}") || { echo "Failed to calculate initial checksum"; exit 1; }
done
# Full backup
echo "run full backup with $encryption_mode"
run_br --pd "$PD_ADDR" backup full -s "local://$TEST_DIR/$PREFIX/full" $full_encryption_args || { echo "Full backup failed"; exit 1; }
# Insert additional test data
insert_additional_data "${encryption_mode}_after_full_backup" || { echo "Failed to insert additional data"; exit 1; }
# Update checksums after inserting additional data
for i in $(seq $DB_COUNT); do
checksum_ori[${i}]=$(calculate_checksum "$DB${i}") || { echo "Failed to calculate checksum after insertion"; exit 1; }
done
wait_log_checkpoint_advance || { echo "Failed to wait for log checkpoint"; exit 1; }
#sanity check pause still works
run_br log pause --task-name $TASK_NAME --pd $PD_ADDR || { echo "Failed to pause log backup"; exit 1; }
#sanity check resume still works
run_br log resume --task-name $TASK_NAME --pd $PD_ADDR || { echo "Failed to resume log backup"; exit 1; }
#sanity check stop still works
run_br log stop --task-name $TASK_NAME --pd $PD_ADDR || { echo "Failed to stop log backup"; exit 1; }
# restart service should clean up everything
restart_services || { echo "Failed to restart services"; exit 1; }
verify_dbs_empty || { echo "Failed to verify databases are empty"; exit 1; }
# Run pitr restore and measure the performance
echo "restore log backup with $full_encryption_args and $log_encryption_args"
local start_time=$(date +%s.%N)
timeout 300 run_br --pd "$PD_ADDR" restore point -s "local://$TEST_DIR/$PREFIX/log" --full-backup-storage "local://$TEST_DIR/$PREFIX/full" $full_encryption_args $log_encryption_args || {
echo "Log backup restore failed or timed out after 5 minutes"
exit 1
}
local end_time=$(date +%s.%N)
local duration=$(echo "$end_time - $start_time" | bc | awk '{printf "%.3f", $0}')
echo "${encryption_mode} took ${duration} seconds"
echo "${encryption_mode},${duration}" >> "$TEST_DIR/performance_results.csv"
# Check data consistency after restore
echo "check data consistency after restore"
check_db_consistency || { echo "TEST: [$TEST_NAME] $encryption_mode backup and restore (including log) failed"; exit 1; }
# sanity check truncate still works
# make sure some files exist in the log dir
log_dir="$TEST_DIR/$PREFIX/log"
if [ -z "$(ls -A $log_dir)" ]; then
echo "Error: No files found in the log directory $log_dir"
exit 1
else
echo "Files exist in the log directory $log_dir"
fi
current_time=$(date -u +"%Y-%m-%d %H:%M:%S+0000")
run_br log truncate -s "local://$TEST_DIR/$PREFIX/log" --until "$current_time" -y || { echo "Failed to truncate log backup"; exit 1; }
# make sure no files exist in log dir
if [ -z "$(ls -A $log_dir)" ]; then
echo "Error: Files still exist in the log directory $log_dir"
exit 1
else
echo "No files exist in the log directory $log_dir"
fi
# Clean up after the test
drop_db || { echo "Failed to drop databases after test"; exit 1; }
rm -rf "$TEST_DIR/$PREFIX"
echo "TEST: [$TEST_NAME] $encryption_mode backup and restore (including log) passed"
}
start_and_wait_for_localstack() {
# Start LocalStack in the background with only the required services
SERVICES=s3,ec2,kms localstack start -d
echo "Waiting for LocalStack services to be ready..."
max_attempts=30
attempt=0
while [ $attempt -lt $max_attempts ]; do
response=$(curl -s "http://localhost:4566/_localstack/health")
if echo "$response" | jq -e '.services.s3 == "running" or .services.s3 == "available"' > /dev/null && \
echo "$response" | jq -e '.services.ec2 == "running" or .services.ec2 == "available"' > /dev/null && \
echo "$response" | jq -e '.services.kms == "running" or .services.kms == "available"' > /dev/null; then
echo "LocalStack services are ready"
return 0
fi
attempt=$((attempt+1))
echo "Waiting for LocalStack services... Attempt $attempt of $max_attempts"
sleep 2
done
echo "LocalStack services did not become ready in time"
localstack stop
return 1
}
test_backup_encrypted_restore_unencrypted() {
echo "===== Testing backup with encryption, restore without encryption ====="
restart_services || { echo "Failed to restart services"; exit 1; }
# Start log backup
start_log_backup "local://$TEST_DIR/$PREFIX/log" "--log.crypter.method AES256-CTR --log.crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" || { echo "Failed to start log backup"; exit 1; }
# Create test databases and insert initial data
create_db_with_table || { echo "Failed to create databases and tables"; exit 1; }
# Backup with encryption
run_br --pd "$PD_ADDR" backup full -s "local://$TEST_DIR/$PREFIX/full" --crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef || { echo "Full backup with encryption failed"; exit 1; }
# Insert additional test data
insert_additional_data "insert_after_full_backup" || { echo "Failed to insert additional data"; exit 1; }
wait_log_checkpoint_advance || { echo "Failed to wait for log checkpoint"; exit 1; }
# Stop and clean the cluster
restart_services || { echo "Failed to restart services"; exit 1; }
# Try to restore without encryption (this should fail)
if run_br --pd "$PD_ADDR" restore point -s "local://$TEST_DIR/$PREFIX/log" --full-backup-storage "local://$TEST_DIR/$PREFIX/full --crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"; then
echo "Error: Restore without encryption should have failed, but it succeeded"
exit 1
else
echo "Restore without encryption failed as expected"
fi
# Clean up after the test
drop_db || { echo "Failed to drop databases after test"; exit 1; }
rm -rf "$TEST_DIR/$PREFIX"
echo "TEST: test_backup_encrypted_restore_unencrypted passed"
}
test_plaintext() {
run_backup_restore_test "plaintext" "" ""
}
test_plaintext_data_key() {
run_backup_restore_test "plaintext-data-key" "--crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" "--log.crypter.method AES256-CTR --log.crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
}
test_local_master_key() {
_MASTER_KEY_DIR="$TEST_DIR/$PREFIX/master_key"
mkdir -p "$_MASTER_KEY_DIR"
openssl rand -hex 32 > "$_MASTER_KEY_DIR/master.key"
_MASTER_KEY_PATH="local:///$_MASTER_KEY_DIR/master.key"
run_backup_restore_test "local_master_key" "" "--master-key-crypter-method AES256-CTR --master-key $_MASTER_KEY_PATH"
rm -rf "$_MASTER_KEY_DIR"
}
test_aws_kms() {
# Start LocalStack and wait for services to be ready
if ! start_and_wait_for_localstack; then
echo "Failed to start LocalStack services"
exit 1
fi
# localstack listening port
ENDPOINT="http://localhost:4566"
# Create KMS key using curl
KMS_RESPONSE=$(curl -X POST "$ENDPOINT/kms/" \
-H "Content-Type: application/x-amz-json-1.1" \
-H "X-Amz-Target: TrentService.CreateKey" \
-d '{
"Description": "My test key",
"KeyUsage": "ENCRYPT_DECRYPT",
"Origin": "AWS_KMS"
}')
echo "KMS CreateKey response: $KMS_RESPONSE"
AWS_KMS_KEY_ID=$(echo "$KMS_RESPONSE" | jq -r '.KeyMetadata.KeyId')
AWS_ACCESS_KEY_ID="TEST"
AWS_SECRET_ACCESS_KEY="TEST"
REGION="us-east-1"
AWS_KMS_URI="aws-kms:///${AWS_KMS_KEY_ID}?AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}&AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}&REGION=${REGION}&ENDPOINT=${ENDPOINT}"
run_backup_restore_test "aws_kms" "--crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" "--master-key-crypter-method AES256-CTR --master-key $AWS_KMS_URI"
# Stop LocalStack
localstack stop
}
test_aws_kms_with_iam() {
# Start LocalStack and wait for services to be ready
if ! start_and_wait_for_localstack; then
echo "Failed to start LocalStack services"
exit 1
fi
# localstack listening port
ENDPOINT="http://localhost:4566"
# Create KMS key using curl
KMS_RESPONSE=$(curl -X POST "$ENDPOINT/kms/" \
-H "Content-Type: application/x-amz-json-1.1" \
-H "X-Amz-Target: TrentService.CreateKey" \
-d '{
"Description": "My test key",
"KeyUsage": "ENCRYPT_DECRYPT",
"Origin": "AWS_KMS"
}')
echo "KMS CreateKey response: $KMS_RESPONSE"
AWS_KMS_KEY_ID=$(echo "$KMS_RESPONSE" | jq -r '.KeyMetadata.KeyId')
# export these two as required by aws-kms backend
export AWS_ACCESS_KEY_ID="TEST"
export AWS_SECRET_ACCESS_KEY="TEST"
REGION="us-east-1"
AWS_KMS_URI="aws-kms:///${AWS_KMS_KEY_ID}?&REGION=${REGION}&ENDPOINT=${ENDPOINT}"
run_backup_restore_test "aws_kms" "--crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" "--master-key-crypter-method AES256-CTR --master-key $AWS_KMS_URI"
# Stop LocalStack
localstack stop
}
test_gcp_kms() {
# Ensure GCP credentials are set
if [ -z "$GOOGLE_APPLICATION_CREDENTIALS" ]; then
echo "GCP credentials not set. Skipping GCP KMS test."
return
fi
# Replace these with your actual GCP KMS details
GCP_PROJECT_ID="carbide-network-435219-q3"
GCP_LOCATION="us-west1"
GCP_KEY_RING="local-kms-testing"
GCP_KEY_NAME="kms-testing-key"
GCP_CREDENTIALS="$GOOGLE_APPLICATION_CREDENTIALS"
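# Build the gcp-kms master key URI; AUTH=specified presumably tells the backend to use the service account file given by CREDENTIALS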
GCP_KMS_URI="gcp-kms:///projects/$GCP_PROJECT_ID/locations/$GCP_LOCATION/keyRings/$GCP_KEY_RING/cryptoKeys/$GCP_KEY_NAME?AUTH=specified&CREDENTIALS=$GCP_CREDENTIALS"
run_backup_restore_test "gcp_kms" "--crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" "--master-key-crypter-method AES256-CTR --master-key $GCP_KMS_URI"
}
test_mixed_full_encrypted_log_plain() {
local full_encryption_args="--crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
local log_encryption_args=""
run_backup_restore_test "mixed_full_encrypted_log_plain" "$full_encryption_args" "$log_encryption_args"
}
test_mixed_full_plain_log_encrypted() {
local full_encryption_args=""
local log_encryption_args="--log.crypter.method AES256-CTR --log.crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
run_backup_restore_test "mixed_full_plain_log_encrypted" "$full_encryption_args" "$log_encryption_args"
}
# Initialize performance results file
echo "Operation,Encryption Mode,Duration (seconds)" > "$TEST_DIR/performance_results.csv"
# Run tests
test_backup_encrypted_restore_unencrypted
test_plaintext
test_plaintext_data_key
test_local_master_key
# The AWS KMS tests have issues running in CI; they will be fixed and re-enabled later
#test_aws_kms
#test_aws_kms_with_iam
test_mixed_full_encrypted_log_plain
test_mixed_full_plain_log_encrypted
# uncomment for manual GCP KMS testing
#test_gcp_kms
echo "All encryption tests passed successfully"
# Display performance results
echo "Performance Results:"
cat "$TEST_DIR/performance_results.csv"

View File

@ -0,0 +1,12 @@
recordcount=1000
operationcount=0
workload=core
readallfields=true
readproportion=0
updateproportion=0
scanproportion=0
insertproportion=0
requestdistribution=uniform

View File

@ -36,3 +36,6 @@ path = "/tmp/backup_restore_test/master-key-file"
[log-backup]
max-flush-interval = "50s"
[gc]
ratio-threshold = 1.1

View File

@ -43,6 +43,8 @@ minio_cli_url="${file_server_url}/download/builds/minio/minio/RELEASE.2020-02-27
kes_url="${file_server_url}/download/kes"
fake_gcs_server_url="${file_server_url}/download/builds/fake-gcs-server"
brv_url="${file_server_url}/download/builds/brv4.0.8"
# The LocalStack CLI tarball was manually uploaded to the file server beforehand
localstack_url="${file_server_url}/download/localstack-cli.tar.gz"
set -o nounset
@ -103,6 +105,13 @@ function main() {
download "$fake_gcs_server_url" "fake-gcs-server" "third_bin/fake-gcs-server"
download "$brv_url" "brv4.0.8" "third_bin/brv4.0.8"
# Download and set up LocalStack
download "$localstack_url" "localstack-cli.tar.gz" "tmp/localstack-cli.tar.gz"
mkdir -p tmp/localstack_extract
tar -xzf tmp/localstack-cli.tar.gz -C tmp/localstack_extract
mv tmp/localstack_extract/localstack/* third_bin/
rm -rf tmp/localstack_extract
chmod +x third_bin/*
rm -rf tmp
rm -rf third_bin/bin

View File

@ -26,7 +26,7 @@ groups=(
["G03"]='br_incompatible_tidb_config br_incremental br_incremental_index br_incremental_only_ddl br_incremental_same_table br_insert_after_restore br_key_locked br_log_test br_move_backup br_mv_index br_other br_partition_add_index br_tidb_placement_policy br_tiflash br_tiflash_conflict'
["G04"]='br_range br_replica_read br_restore_TDE_enable br_restore_log_task_enable br_s3 br_shuffle_leader br_shuffle_region br_single_table'
["G05"]='br_skip_checksum br_split_region_fail br_systables br_table_filter br_txn br_stats br_clustered_index br_crypter'
["G06"]='br_tikv_outage br_tikv_outage3 br_restore_checkpoint'
["G06"]='br_tikv_outage br_tikv_outage3 br_restore_checkpoint br_encryption'
["G07"]='br_pitr'
["G08"]='br_tikv_outage2 br_ttl br_views_and_sequences br_z_gc_safepoint br_autorandom br_file_corruption'
)

View File

@ -1,9 +1,9 @@
package(default_visibility = ["//visibility:public"])
load("@io_bazel_rules_go//go:def.bzl", "go_library", "nogo")
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
load("@io_bazel_rules_go//go:def.bzl", "go_library", "nogo")
load("//build/linter/staticcheck:def.bzl", "staticcheck_analyzers")
package(default_visibility = ["//visibility:public"])
bool_flag(
name = "with_nogo_flag",
build_setting_default = False,

go.mod
View File

@ -3,6 +3,7 @@ module github.com/pingcap/tidb
go 1.21
require (
cloud.google.com/go/kms v1.15.7
cloud.google.com/go/storage v1.38.0
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.12.0
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0
@ -305,7 +306,7 @@ require (
google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 // indirect
google.golang.org/protobuf v1.34.2 // indirect
google.golang.org/protobuf v1.34.2
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

View File

@ -47,8 +47,10 @@ stop() {
}
restart_services() {
echo "Restarting services"
stop_services
start_services
echo "Services restarted"
}
stop_services() {