lightning: refactor schema importer (#52334)

ref pingcap/tidb#52142
D3Hunter
2024-04-24 13:52:41 +08:00
committed by GitHub
parent 0805e850d4
commit 773ce7e07e
9 changed files with 761 additions and 711 deletions

View File

@@ -54,7 +54,6 @@ go_library(
"//pkg/meta/autoid",
"//pkg/parser",
"//pkg/parser/ast",
"//pkg/parser/format",
"//pkg/parser/model",
"//pkg/parser/mysql",
"//pkg/planner/core",
@@ -116,7 +115,6 @@ go_test(
"meta_manager_test.go",
"precheck_impl_test.go",
"precheck_test.go",
"restore_schema_test.go",
"table_import_test.go",
"tidb_test.go",
],

View File

@@ -56,7 +56,6 @@ import (
"github.com/pingcap/tidb/pkg/lightning/tikv"
"github.com/pingcap/tidb/pkg/lightning/worker"
"github.com/pingcap/tidb/pkg/meta/autoid"
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/session"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
@@ -583,324 +582,19 @@ outside:
return errors.Trace(err)
}
type schemaStmtType int
// String implements the fmt.Stringer interface.
func (stmtType schemaStmtType) String() string {
switch stmtType {
case schemaCreateDatabase:
return "restore database schema"
case schemaCreateTable:
return "restore table schema"
case schemaCreateView:
return "restore view schema"
}
return "unknown statement of schema"
}
const (
schemaCreateDatabase schemaStmtType = iota
schemaCreateTable
schemaCreateView
)
type schemaJob struct {
dbName string
tblName string // empty for create db jobs
stmtType schemaStmtType
stmts []string
}
type restoreSchemaWorker struct {
ctx context.Context
quit context.CancelFunc
logger log.Logger
jobCh chan *schemaJob
errCh chan error
wg sync.WaitGroup
db *sql.DB
parser *parser.Parser
store storage.ExternalStorage
}
func (worker *restoreSchemaWorker) addJob(sqlStr string, job *schemaJob) error {
stmts, err := createIfNotExistsStmt(worker.parser, sqlStr, job.dbName, job.tblName)
if err != nil {
return errors.Trace(err)
}
job.stmts = stmts
return worker.appendJob(job)
}
func (worker *restoreSchemaWorker) makeJobs(
dbMetas []*mydump.MDDatabaseMeta,
getDBs func(context.Context) ([]*model.DBInfo, error),
getTables func(context.Context, string) ([]*model.TableInfo, error),
) error {
defer func() {
close(worker.jobCh)
worker.quit()
}()
if len(dbMetas) == 0 {
return nil
}
// 1. restore databases, execute statements concurrently
dbs, err := getDBs(worker.ctx)
if err != nil {
worker.logger.Warn("get databases from downstream failed", zap.Error(err))
}
dbSet := make(set.StringSet, len(dbs))
for _, db := range dbs {
dbSet.Insert(db.Name.L)
}
for _, dbMeta := range dbMetas {
// if the downstream already has this database, we can skip the DDL job
if dbSet.Exist(strings.ToLower(dbMeta.Name)) {
worker.logger.Info(
"database already exists in downstream, skip processing the source file",
zap.String("db", dbMeta.Name),
)
continue
}
sql := dbMeta.GetSchema(worker.ctx, worker.store)
err = worker.addJob(sql, &schemaJob{
dbName: dbMeta.Name,
tblName: "",
stmtType: schemaCreateDatabase,
})
if err != nil {
return err
}
}
err = worker.wait()
if err != nil {
return err
}
// 2. restore tables, execute statements concurrently
for _, dbMeta := range dbMetas {
// we can ignore the error here, and let the later check fail if the schema doesn't match
tables, err := getTables(worker.ctx, dbMeta.Name)
if err != nil {
worker.logger.Warn("get tables from downstream failed", zap.Error(err))
}
tableSet := make(set.StringSet, len(tables))
for _, t := range tables {
tableSet.Insert(t.Name.L)
}
for _, tblMeta := range dbMeta.Tables {
if tableSet.Exist(strings.ToLower(tblMeta.Name)) {
// we already have this table in TiDB.
// we should skip the DDL job and let the SchemaValid check handle it.
worker.logger.Info(
"table already exists in downstream, skip processing the source file",
zap.String("db", dbMeta.Name),
zap.String("table", tblMeta.Name),
)
continue
} else if tblMeta.SchemaFile.FileMeta.Path == "" {
return common.ErrSchemaNotExists.GenWithStackByArgs(dbMeta.Name, tblMeta.Name)
}
sql, err := tblMeta.GetSchema(worker.ctx, worker.store)
if err != nil {
return err
}
if sql != "" {
err = worker.addJob(sql, &schemaJob{
dbName: dbMeta.Name,
tblName: tblMeta.Name,
stmtType: schemaCreateTable,
})
if err != nil {
return err
}
}
}
}
err = worker.wait()
if err != nil {
return err
}
// 3. restore views. Since views can reference tables across databases, we must restore views after all table schemas are restored.
for _, dbMeta := range dbMetas {
for _, viewMeta := range dbMeta.Views {
sql, err := viewMeta.GetSchema(worker.ctx, worker.store)
if sql != "" {
err = worker.addJob(sql, &schemaJob{
dbName: dbMeta.Name,
tblName: viewMeta.Name,
stmtType: schemaCreateView,
})
if err != nil {
return err
}
// we don't restore views concurrently, because doing so may raise an error
err = worker.wait()
if err != nil {
return err
}
}
if err != nil {
return err
}
}
}
return nil
}
func (worker *restoreSchemaWorker) doJob() {
var session *sql.Conn
defer func() {
if session != nil {
_ = session.Close()
}
}()
loop:
for {
select {
case <-worker.ctx.Done():
// don't `return` or throw `worker.ctx.Err()` here:
// if we `return`, we can't mark cancelled jobs as done;
// if we `throw(worker.ctx.Err())`, it will block forever
break loop
case job := <-worker.jobCh:
if job == nil {
// successful exit
return
}
var err error
if session == nil {
session, err = func() (*sql.Conn, error) {
return worker.db.Conn(worker.ctx)
}()
if err != nil {
worker.wg.Done()
worker.throw(err)
// don't return
break loop
}
}
logger := worker.logger.With(zap.String("db", job.dbName), zap.String("table", job.tblName))
sqlWithRetry := common.SQLWithRetry{
Logger: worker.logger,
DB: session,
}
for _, stmt := range job.stmts {
task := logger.Begin(zap.DebugLevel, fmt.Sprintf("execute SQL: %s", stmt))
err = sqlWithRetry.Exec(worker.ctx, "run create schema job", stmt)
if err != nil {
// try to imitate IF NOT EXISTS behavior for parsing errors
exists := false
switch job.stmtType {
case schemaCreateDatabase:
var err2 error
exists, err2 = common.SchemaExists(worker.ctx, session, job.dbName)
if err2 != nil {
task.Error("failed to check database existence", zap.Error(err2))
}
case schemaCreateTable:
exists, _ = common.TableExists(worker.ctx, session, job.dbName, job.tblName)
}
if exists {
err = nil
}
}
task.End(zap.ErrorLevel, err)
if err != nil {
err = common.ErrCreateSchema.Wrap(err).GenWithStackByArgs(common.UniqueTable(job.dbName, job.tblName), job.stmtType.String())
worker.wg.Done()
worker.throw(err)
// don't return
break loop
}
}
worker.wg.Done()
}
}
// mark the cancelled jobs as `Done`, a little tricky,
// because we need to make sure `worker.wg.Wait()` won't block forever
for range worker.jobCh {
worker.wg.Done()
}
}
func (worker *restoreSchemaWorker) wait() error {
// avoid `worker.wg.Wait()` blocking forever after all `doJob` goroutines have exited.
// don't worry about the goroutine below, it never becomes a zombie,
// because we have a mechanism to drain cancelled jobs from `worker.jobCh`,
// so every job sent to `worker.jobCh` will eventually be marked done.
waitCh := make(chan struct{})
go func() {
worker.wg.Wait()
close(waitCh)
}()
select {
case err := <-worker.errCh:
return err
case <-worker.ctx.Done():
return worker.ctx.Err()
case <-waitCh:
return nil
}
}
func (worker *restoreSchemaWorker) throw(err error) {
select {
case <-worker.ctx.Done():
// don't throw `worker.ctx.Err()` again, it would block forever.
return
case worker.errCh <- err:
worker.quit()
}
}
func (worker *restoreSchemaWorker) appendJob(job *schemaJob) error {
worker.wg.Add(1)
select {
case err := <-worker.errCh:
// cancel the job
worker.wg.Done()
return err
case <-worker.ctx.Done():
// cancel the job
worker.wg.Done()
return errors.Trace(worker.ctx.Err())
case worker.jobCh <- job:
return nil
}
}
func (rc *Controller) restoreSchema(ctx context.Context) error {
// create databases and tables from the schema files.
// duplicate creation is handled via the CREATE ... IF NOT EXISTS statements,
// and DataCheck will later verify that the schema in TiDB matches the data files.
logTask := log.FromContext(ctx).Begin(zap.InfoLevel, "restore all schema")
logger := log.FromContext(ctx)
concurrency := min(rc.cfg.App.RegionConcurrency, 8)
childCtx, cancel := context.WithCancel(ctx)
p := parser.New()
p.SetSQLMode(rc.cfg.TiDB.SQLMode)
worker := restoreSchemaWorker{
ctx: childCtx,
quit: cancel,
logger: log.FromContext(ctx),
jobCh: make(chan *schemaJob, concurrency),
errCh: make(chan error),
db: rc.db,
parser: p,
store: rc.store,
}
for i := 0; i < concurrency; i++ {
go worker.doJob()
}
err := worker.makeJobs(rc.dbMetas, rc.preInfoGetter.FetchRemoteDBModels, rc.preInfoGetter.FetchRemoteTableModels)
logTask.End(zap.ErrorLevel, err)
// sql.DB is a connection pool; we set MaxIdleConns to concurrency + 1 (one for the job generator)
// to reuse connections, as we may call db.Conn/conn.Close many times.
// there's no API to read sql.DB's MaxIdleConns, so we later restore it to its default, which is 2
rc.db.SetMaxIdleConns(concurrency + 1)
defer rc.db.SetMaxIdleConns(2)
schemaImp := mydump.NewSchemaImporter(logger, rc.cfg.TiDB.SQLMode, rc.db, rc.store, concurrency)
err := schemaImp.Run(ctx, rc.dbMetas)
if err != nil {
return err
}

View File

@@ -1,247 +0,0 @@
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importer
import (
"context"
stderrors "errors"
"fmt"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/br/pkg/mock"
"github.com/pingcap/tidb/br/pkg/storage"
"github.com/pingcap/tidb/pkg/ddl"
"github.com/pingcap/tidb/pkg/lightning/backend"
"github.com/pingcap/tidb/pkg/lightning/checkpoints"
"github.com/pingcap/tidb/pkg/lightning/config"
"github.com/pingcap/tidb/pkg/lightning/mydump"
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
tmock "github.com/pingcap/tidb/pkg/util/mock"
filter "github.com/pingcap/tidb/pkg/util/table-filter"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
)
type restoreSchemaSuite struct {
suite.Suite
ctx context.Context
rc *Controller
controller *gomock.Controller
dbMock sqlmock.Sqlmock
tableInfos []*model.TableInfo
infoGetter *PreImportInfoGetterImpl
targetInfoGetter *TargetInfoGetterImpl
}
func TestRestoreSchemaSuite(t *testing.T) {
suite.Run(t, new(restoreSchemaSuite))
}
func (s *restoreSchemaSuite) SetupSuite() {
ctx := context.Background()
fakeDataDir := s.T().TempDir()
store, err := storage.NewLocalStorage(fakeDataDir)
require.NoError(s.T(), err)
// restore database schema file
fakeDBName := "fakedb"
// please follow `mydump.defaultFileRouteRules`; it matches files like '{schema}-schema-create.sql'
fakeFileName := fmt.Sprintf("%s-schema-create.sql", fakeDBName)
err = store.WriteFile(ctx, fakeFileName, []byte(fmt.Sprintf("CREATE DATABASE %s;", fakeDBName)))
require.NoError(s.T(), err)
// restore table schema files
fakeTableFilesCount := 8
p := parser.New()
p.SetSQLMode(mysql.ModeANSIQuotes)
se := tmock.NewContext()
tableInfos := make([]*model.TableInfo, 0, fakeTableFilesCount)
for i := 1; i <= fakeTableFilesCount; i++ {
fakeTableName := fmt.Sprintf("tbl%d", i)
// please follow `mydump.defaultFileRouteRules`; it matches files like '{schema}.{table}-schema.sql'
fakeFileName := fmt.Sprintf("%s.%s-schema.sql", fakeDBName, fakeTableName)
fakeFileContent := fmt.Sprintf("CREATE TABLE %s(i TINYINT);", fakeTableName)
err = store.WriteFile(ctx, fakeFileName, []byte(fakeFileContent))
require.NoError(s.T(), err)
node, err := p.ParseOneStmt(fakeFileContent, "", "")
require.NoError(s.T(), err)
core, err := ddl.MockTableInfo(se, node.(*ast.CreateTableStmt), 0xabcdef)
require.NoError(s.T(), err)
core.State = model.StatePublic
tableInfos = append(tableInfos, core)
}
s.tableInfos = tableInfos
// restore view schema files
fakeViewFilesCount := 8
for i := 1; i <= fakeViewFilesCount; i++ {
fakeViewName := fmt.Sprintf("tbl%d", i)
// please follow `mydump.defaultFileRouteRules`; it matches files like '{schema}.{table}-schema-view.sql'
fakeFileName := fmt.Sprintf("%s.%s-schema-view.sql", fakeDBName, fakeViewName)
fakeFileContent := []byte(fmt.Sprintf("CREATE ALGORITHM=UNDEFINED VIEW `%s` (`i`) AS SELECT `i` FROM `%s`.`%s`;", fakeViewName, fakeDBName, fmt.Sprintf("tbl%d", i)))
err = store.WriteFile(ctx, fakeFileName, fakeFileContent)
require.NoError(s.T(), err)
}
config := config.NewConfig()
config.Mydumper.DefaultFileRules = true
config.Mydumper.CharacterSet = "utf8mb4"
config.App.RegionConcurrency = 8
mydumpLoader, err := mydump.NewLoaderWithStore(ctx, mydump.NewLoaderCfg(config), store)
s.Require().NoError(err)
dbMetas := mydumpLoader.GetDatabases()
targetInfoGetter := &TargetInfoGetterImpl{
cfg: config,
}
preInfoGetter := &PreImportInfoGetterImpl{
cfg: config,
srcStorage: store,
targetInfoGetter: targetInfoGetter,
dbMetas: dbMetas,
}
preInfoGetter.Init()
s.rc = &Controller{
checkTemplate: NewSimpleTemplate(),
cfg: config,
store: store,
dbMetas: dbMetas,
checkpointsDB: &checkpoints.NullCheckpointsDB{},
preInfoGetter: preInfoGetter,
}
s.infoGetter = preInfoGetter
s.targetInfoGetter = targetInfoGetter
}
//nolint:interfacer // changing the test case signature might cause Check to fail to find this test case?
func (s *restoreSchemaSuite) SetupTest() {
s.controller, s.ctx = gomock.WithContext(context.Background(), s.T())
mockTargetInfoGetter := mock.NewMockTargetInfoGetter(s.controller)
mockBackend := mock.NewMockBackend(s.controller)
mockTargetInfoGetter.EXPECT().
FetchRemoteDBModels(gomock.Any()).
AnyTimes().
Return([]*model.DBInfo{{Name: model.NewCIStr("fakedb")}}, nil)
mockTargetInfoGetter.EXPECT().
FetchRemoteTableModels(gomock.Any(), gomock.Any()).
AnyTimes().
Return(s.tableInfos, nil)
mockBackend.EXPECT().Close()
theBackend := backend.MakeEngineManager(mockBackend)
s.rc.engineMgr = theBackend
s.rc.backend = mockBackend
s.targetInfoGetter.backend = mockTargetInfoGetter
mockDB, sqlMock, err := sqlmock.New()
require.NoError(s.T(), err)
for i := 0; i < 17; i++ {
sqlMock.ExpectExec(".*").WillReturnResult(sqlmock.NewResult(int64(i), 1))
}
s.targetInfoGetter.db = mockDB
s.rc.db = mockDB
s.dbMock = sqlMock
}
func (s *restoreSchemaSuite) TearDownTest() {
s.rc.Close()
s.controller.Finish()
}
func (s *restoreSchemaSuite) TestRestoreSchemaSuccessful() {
// before restore, if sysVars was initialized by another test, time_zone should be the default value
if len(s.rc.sysVars) > 0 {
tz, ok := s.rc.sysVars["time_zone"]
require.True(s.T(), ok)
require.Equal(s.T(), "SYSTEM", tz)
}
s.dbMock.ExpectQuery(".*").WillReturnRows(sqlmock.NewRows([]string{"time_zone"}).AddRow("SYSTEM"))
s.rc.cfg.TiDB.Vars = map[string]string{
"time_zone": "UTC",
}
err := s.rc.restoreSchema(s.ctx)
require.NoError(s.T(), err)
// after restoring the schema, check that sysVars has been updated
tz, ok := s.rc.sysVars["time_zone"]
require.True(s.T(), ok)
require.Equal(s.T(), "UTC", tz)
}
func (s *restoreSchemaSuite) TestRestoreSchemaFailed() {
// use injectErr which cannot be retried
injectErr := stderrors.New("could not match actual sql")
mockDB, sqlMock, err := sqlmock.New()
require.NoError(s.T(), err)
sqlMock.ExpectExec(".*").WillReturnError(injectErr)
for i := 0; i < 16; i++ {
sqlMock.ExpectExec(".*").WillReturnResult(sqlmock.NewResult(int64(i), 1))
}
s.rc.db = mockDB
s.targetInfoGetter.db = mockDB
err = s.rc.restoreSchema(s.ctx)
require.Error(s.T(), err)
require.True(s.T(), errors.ErrorEqual(err, injectErr))
}
// When restoring a CSV with `-no-schema` and the target table doesn't exist,
// we can't restore the schema as the `Path` is empty. This test makes sure
// that this results in the correct error.
// https://github.com/pingcap/br/issues/1394
func (s *restoreSchemaSuite) TestNoSchemaPath() {
fakeTable := mydump.MDTableMeta{
DB: "fakedb",
Name: "fake1",
SchemaFile: mydump.FileInfo{
TableName: filter.Table{
Schema: "fakedb",
Name: "fake1",
},
FileMeta: mydump.SourceFileMeta{
Path: "",
},
},
DataFiles: []mydump.FileInfo{},
TotalSize: 0,
}
s.rc.dbMetas[0].Tables = append(s.rc.dbMetas[0].Tables, &fakeTable)
err := s.rc.restoreSchema(s.ctx)
require.Error(s.T(), err)
require.Regexp(s.T(), `table .* schema not found`, err.Error())
s.rc.dbMetas[0].Tables = s.rc.dbMetas[0].Tables[:len(s.rc.dbMetas[0].Tables)-1]
}
func (s *restoreSchemaSuite) TestRestoreSchemaContextCancel() {
childCtx, cancel := context.WithCancel(s.ctx)
mockDB, sqlMock, err := sqlmock.New()
require.NoError(s.T(), err)
for i := 0; i < 17; i++ {
sqlMock.ExpectExec(".*").WillReturnResult(sqlmock.NewResult(int64(i), 1))
}
s.rc.db = mockDB
s.targetInfoGetter.db = mockDB
cancel()
err = s.rc.restoreSchema(childCtx)
require.Error(s.T(), err)
err = errors.Cause(err)
require.Equal(s.T(), childCtx.Err(), err)
}

View File

@@ -31,8 +31,6 @@ import (
"github.com/pingcap/tidb/pkg/lightning/metric"
"github.com/pingcap/tidb/pkg/lightning/mydump"
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/format"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
@@ -131,47 +129,6 @@ func (timgr *TiDBManager) Close() {
timgr.db.Close()
}
func createIfNotExistsStmt(p *parser.Parser, createTable, dbName, tblName string) ([]string, error) {
stmts, _, err := p.ParseSQL(createTable)
if err != nil {
return []string{}, common.ErrInvalidSchemaStmt.Wrap(err).GenWithStackByArgs(createTable)
}
var res strings.Builder
ctx := format.NewRestoreCtx(format.DefaultRestoreFlags|format.RestoreTiDBSpecialComment|format.RestoreWithTTLEnableOff, &res)
retStmts := make([]string, 0, len(stmts))
for _, stmt := range stmts {
switch node := stmt.(type) {
case *ast.CreateDatabaseStmt:
node.Name = model.NewCIStr(dbName)
node.IfNotExists = true
case *ast.DropDatabaseStmt:
node.Name = model.NewCIStr(dbName)
node.IfExists = true
case *ast.CreateTableStmt:
node.Table.Schema = model.NewCIStr(dbName)
node.Table.Name = model.NewCIStr(tblName)
node.IfNotExists = true
case *ast.CreateViewStmt:
node.ViewName.Schema = model.NewCIStr(dbName)
node.ViewName.Name = model.NewCIStr(tblName)
case *ast.DropTableStmt:
node.Tables[0].Schema = model.NewCIStr(dbName)
node.Tables[0].Name = model.NewCIStr(tblName)
node.IfExists = true
}
if err := stmt.Restore(ctx); err != nil {
return []string{}, common.ErrInvalidSchemaStmt.Wrap(err).GenWithStackByArgs(createTable)
}
ctx.WritePlain(";")
retStmts = append(retStmts, res.String())
res.Reset()
}
return retStmts, nil
}
// DropTable drops a table.
func (timgr *TiDBManager) DropTable(ctx context.Context, tableName string) error {
sql := common.SQLWithRetry{

View File

@@ -28,7 +28,6 @@ import (
"github.com/pingcap/tidb/pkg/lightning/checkpoints"
"github.com/pingcap/tidb/pkg/lightning/metric"
"github.com/pingcap/tidb/pkg/lightning/mydump"
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/model"
tmysql "github.com/pingcap/tidb/pkg/parser/mysql"
@@ -61,109 +60,6 @@ func newTiDBSuite(t *testing.T) *tidbSuite {
return &s
}
func TestCreateTableIfNotExistsStmt(t *testing.T) {
dbName := "testdb"
p := parser.New()
createSQLIfNotExistsStmt := func(createTable, tableName string) []string {
res, err := createIfNotExistsStmt(p, createTable, dbName, tableName)
require.NoError(t, err)
return res
}
require.Equal(t, []string{"CREATE DATABASE IF NOT EXISTS `testdb` CHARACTER SET = utf8 COLLATE = utf8_general_ci;"},
createSQLIfNotExistsStmt("CREATE DATABASE `foo` CHARACTER SET = utf8 COLLATE = utf8_general_ci;", ""))
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` TINYINT(1));"},
createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` TINYINT(1));", "foo"))
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` TINYINT(1));"},
createSQLIfNotExistsStmt("CREATE TABLE IF NOT EXISTS `foo`(`bar` TINYINT(1));", "foo"))
// case insensitive
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`fOo` (`bar` TINYINT(1));"},
createSQLIfNotExistsStmt("/* cOmmEnt */ creAte tablE `fOo`(`bar` TinyinT(1));", "fOo"))
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`FoO` (`bAR` TINYINT(1));"},
createSQLIfNotExistsStmt("/* coMMenT */ crEatE tAble If not EXISts `FoO`(`bAR` tiNyInT(1));", "FoO"))
// only one "CREATE TABLE" is replaced
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE');"},
createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE');", "foo"))
// test clustered index consistency
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![clustered_index] CLUSTERED */ COMMENT 'CREATE TABLE');"},
createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY CLUSTERED COMMENT 'CREATE TABLE');", "foo"))
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE',PRIMARY KEY(`bar`) /*T![clustered_index] NONCLUSTERED */);"},
createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE', PRIMARY KEY (`bar`) NONCLUSTERED);", "foo"))
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![clustered_index] NONCLUSTERED */ COMMENT 'CREATE TABLE');"},
createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY /*T![clustered_index] NONCLUSTERED */ COMMENT 'CREATE TABLE');", "foo"))
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE',PRIMARY KEY(`bar`) /*T![clustered_index] CLUSTERED */);"},
createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE', PRIMARY KEY (`bar`) /*T![clustered_index] CLUSTERED */);", "foo"))
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![auto_rand] AUTO_RANDOM(2) */ COMMENT 'CREATE TABLE');"},
createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY AUTO_RANDOM(2) COMMENT 'CREATE TABLE');", "foo"))
// upper case becomes shorter
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ſ` (`ı` TINYINT(1));"},
createSQLIfNotExistsStmt("CREATE TABLE `ſ`(`ı` TINYINT(1));", "ſ"))
// upper case becomes longer
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ɑ` (`ȿ` TINYINT(1));"},
createSQLIfNotExistsStmt("CREATE TABLE `ɑ`(`ȿ` TINYINT(1));", "ɑ"))
// non-utf-8
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`\xcc\xcc\xcc` (`???` TINYINT(1));"},
createSQLIfNotExistsStmt("CREATE TABLE `\xcc\xcc\xcc`(`\xdd\xdd\xdd` TINYINT(1));", "\xcc\xcc\xcc"))
// renaming a table
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ba``r` (`x` INT);"},
createSQLIfNotExistsStmt("create table foo(x int);", "ba`r"))
// conditional comments
require.Equal(t, []string{
"SET NAMES 'binary';",
"SET @@SESSION.`FOREIGN_KEY_CHECKS`=0;",
"CREATE TABLE IF NOT EXISTS `testdb`.`m` (`z` DOUBLE) ENGINE = InnoDB AUTO_INCREMENT = 8343230 DEFAULT CHARACTER SET = UTF8;",
},
createSQLIfNotExistsStmt(`
/*!40101 SET NAMES binary*/;
/*!40014 SET FOREIGN_KEY_CHECKS=0*/;
CREATE TABLE x.y (z double) ENGINE=InnoDB AUTO_INCREMENT=8343230 DEFAULT CHARSET=utf8;
`, "m"))
// create view
require.Equal(t, []string{
"SET NAMES 'binary';",
"DROP TABLE IF EXISTS `testdb`.`m`;",
"DROP VIEW IF EXISTS `testdb`.`m`;",
"SET @`PREV_CHARACTER_SET_CLIENT`=@@`character_set_client`;",
"SET @`PREV_CHARACTER_SET_RESULTS`=@@`character_set_results`;",
"SET @`PREV_COLLATION_CONNECTION`=@@`collation_connection`;",
"SET @@SESSION.`character_set_client`=`utf8`;",
"SET @@SESSION.`character_set_results`=`utf8`;",
"SET @@SESSION.`collation_connection`=`utf8_general_ci`;",
"CREATE ALGORITHM = UNDEFINED DEFINER = `root`@`192.168.198.178` SQL SECURITY DEFINER VIEW `testdb`.`m` (`s`) AS SELECT `s` FROM `db1`.`v1` WHERE `i`<2;",
"SET @@SESSION.`character_set_client`=@`PREV_CHARACTER_SET_CLIENT`;",
"SET @@SESSION.`character_set_results`=@`PREV_CHARACTER_SET_RESULTS`;",
"SET @@SESSION.`collation_connection`=@`PREV_COLLATION_CONNECTION`;",
},
createSQLIfNotExistsStmt(`
/*!40101 SET NAMES binary*/;
DROP TABLE IF EXISTS v2;
DROP VIEW IF EXISTS v2;
SET @PREV_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT;
SET @PREV_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS;
SET @PREV_COLLATION_CONNECTION=@@COLLATION_CONNECTION;
SET character_set_client = utf8;
SET character_set_results = utf8;
SET collation_connection = utf8_general_ci;
CREATE ALGORITHM=UNDEFINED DEFINER=root@192.168.198.178 SQL SECURITY DEFINER VIEW v2 (s) AS SELECT s FROM db1.v1 WHERE i<2;
SET character_set_client = @PREV_CHARACTER_SET_CLIENT;
SET character_set_results = @PREV_CHARACTER_SET_RESULTS;
SET collation_connection = @PREV_COLLATION_CONNECTION;
`, "m"))
}
func TestDropTable(t *testing.T) {
s := newTiDBSuite(t)
ctx := context.Background()

View File

@@ -13,6 +13,7 @@ go_library(
"reader.go",
"region.go",
"router.go",
"schema_import.go",
],
importpath = "github.com/pingcap/tidb/pkg/lightning/mydump",
visibility = ["//visibility:public"],
@@ -24,11 +25,18 @@ go_library(
"//pkg/lightning/log",
"//pkg/lightning/metric",
"//pkg/lightning/worker",
"//pkg/parser",
"//pkg/parser/ast",
"//pkg/parser/format",
"//pkg/parser/model",
"//pkg/parser/mysql",
"//pkg/types",
"//pkg/util",
"//pkg/util/filter",
"//pkg/util/regexpr-router",
"//pkg/util/set",
"//pkg/util/slice",
"//pkg/util/sqlescape",
"//pkg/util/table-filter",
"//pkg/util/zeropool",
"@com_github_pingcap_errors//:errors",
@@ -59,6 +67,7 @@ go_test(
"reader_test.go",
"region_test.go",
"router_test.go",
"schema_import_test.go",
],
data = glob([
"csv/*",
@@ -75,12 +84,14 @@ go_test(
"//pkg/lightning/config",
"//pkg/lightning/log",
"//pkg/lightning/worker",
"//pkg/parser",
"//pkg/parser/mysql",
"//pkg/testkit/testsetup",
"//pkg/types",
"//pkg/util/filter",
"//pkg/util/table-filter",
"//pkg/util/table-router",
"@com_github_data_dog_go_sqlmock//:go-sqlmock",
"@com_github_pingcap_errors//:errors",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",

View File

@@ -130,7 +130,7 @@ func (m *MDTableMeta) GetSchema(ctx context.Context, store storage.ExternalStora
zap.String("Path", m.SchemaFile.FileMeta.Path),
log.ShortError(err),
)
return "", err
return "", errors.Trace(err)
}
return string(schema), nil
}

View File

@@ -0,0 +1,371 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mydump
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/br/pkg/storage"
"github.com/pingcap/tidb/pkg/lightning/common"
"github.com/pingcap/tidb/pkg/lightning/log"
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/format"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/util"
"github.com/pingcap/tidb/pkg/util/set"
"github.com/pingcap/tidb/pkg/util/sqlescape"
"go.uber.org/zap"
)
type schemaStmtType int
// String implements fmt.Stringer interface.
func (stmtType schemaStmtType) String() string {
switch stmtType {
case schemaCreateDatabase:
return "restore database schema"
case schemaCreateTable:
return "restore table schema"
case schemaCreateView:
return "restore view schema"
}
return "unknown statement of schema"
}
const (
schemaCreateDatabase schemaStmtType = iota
schemaCreateTable
schemaCreateView
)
type schemaJob struct {
dbName string
tblName string // empty for create db jobs
stmtType schemaStmtType
sqlStr string
}
// SchemaImporter is used to import schema from dump files.
type SchemaImporter struct {
logger log.Logger
db *sql.DB
sqlMode mysql.SQLMode
store storage.ExternalStorage
concurrency int
}
// NewSchemaImporter creates a new SchemaImporter instance.
func NewSchemaImporter(logger log.Logger, sqlMode mysql.SQLMode, db *sql.DB, store storage.ExternalStorage, concurrency int) *SchemaImporter {
return &SchemaImporter{
logger: logger,
db: db,
sqlMode: sqlMode,
store: store,
concurrency: concurrency,
}
}
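For orientation, here is a minimal usage sketch of the new importer from a caller's point of view. This wiring is hypothetical and not part of this diff: dsn, dumpDir, cfg, and logger are assumed inputs, and the loader calls mirror the tests further down.

package main

import (
	"context"
	"database/sql"

	"github.com/pingcap/tidb/br/pkg/storage"
	"github.com/pingcap/tidb/pkg/lightning/config"
	"github.com/pingcap/tidb/pkg/lightning/log"
	"github.com/pingcap/tidb/pkg/lightning/mydump"
)

// importSchemas is a hypothetical helper showing how the pieces fit together.
func importSchemas(ctx context.Context, dsn, dumpDir string, cfg *config.Config, logger log.Logger) error {
	db, err := sql.Open("mysql", dsn) // target TiDB, e.g. via go-sql-driver/mysql
	if err != nil {
		return err
	}
	defer db.Close()
	store, err := storage.NewLocalStorage(dumpDir) // dump files on local disk
	if err != nil {
		return err
	}
	loader, err := mydump.NewLoaderWithStore(ctx, mydump.NewLoaderCfg(cfg), store)
	if err != nil {
		return err
	}
	importer := mydump.NewSchemaImporter(logger, cfg.TiDB.SQLMode, db, store, 8)
	return importer.Run(ctx, loader.GetDatabases())
}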
// Run imports all schemas from the given database metas.
func (si *SchemaImporter) Run(ctx context.Context, dbMetas []*MDDatabaseMeta) (err error) {
logTask := si.logger.Begin(zap.InfoLevel, "restore all schema")
defer func() {
logTask.End(zap.ErrorLevel, err)
}()
if len(dbMetas) == 0 {
return nil
}
if err = si.importDatabases(ctx, dbMetas); err != nil {
return errors.Trace(err)
}
if err = si.importTables(ctx, dbMetas); err != nil {
return errors.Trace(err)
}
return errors.Trace(si.importViews(ctx, dbMetas))
}
func (si *SchemaImporter) importDatabases(ctx context.Context, dbMetas []*MDDatabaseMeta) error {
existingSchemas, err := si.getExistingDatabases(ctx)
if err != nil {
return err
}
ch := make(chan *MDDatabaseMeta)
eg, egCtx := util.NewErrorGroupWithRecoverWithCtx(ctx)
for i := 0; i < si.concurrency; i++ {
eg.Go(func() error {
p := parser.New()
p.SetSQLMode(si.sqlMode)
for dbMeta := range ch {
sqlStr := dbMeta.GetSchema(egCtx, si.store)
if err2 := si.runJob(egCtx, p, &schemaJob{
dbName: dbMeta.Name,
stmtType: schemaCreateDatabase,
sqlStr: sqlStr,
}); err2 != nil {
return err2
}
}
return nil
})
}
eg.Go(func() error {
defer close(ch)
for i := range dbMetas {
dbMeta := dbMetas[i]
// if the downstream already has this database, we can skip the DDL job
if existingSchemas.Exist(strings.ToLower(dbMeta.Name)) {
si.logger.Info("database already exists in downstream, skip",
zap.String("db", dbMeta.Name),
)
continue
}
select {
case ch <- dbMeta:
case <-egCtx.Done():
}
}
return nil
})
return eg.Wait()
}
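The producer/consumer shape used in importDatabases (and again in importTables below) is the standard errgroup pattern; as far as this diff shows, util.NewErrorGroupWithRecoverWithCtx is TiDB's panic-recovering wrapper around the same idea. A self-contained sketch of the pattern with the stock golang.org/x/sync/errgroup, for illustration only:

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func process(ctx context.Context, items []string, workers int) error {
	ch := make(chan string)
	eg, egCtx := errgroup.WithContext(ctx)
	for i := 0; i < workers; i++ {
		eg.Go(func() error {
			for it := range ch { // exits once the producer closes ch
				fmt.Println("handling", it) // stand-in for runJob
			}
			return nil
		})
	}
	eg.Go(func() error {
		defer close(ch) // unblocks every consumer on normal or early exit
		for _, it := range items {
			select {
			case ch <- it:
			case <-egCtx.Done(): // some goroutine failed; stop producing
				return egCtx.Err()
			}
		}
		return nil
	})
	return eg.Wait() // first error wins; egCtx is cancelled on failure
}

func main() {
	_ = process(context.Background(), []string{"db1", "db2", "db3"}, 2)
}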
func (si *SchemaImporter) importTables(ctx context.Context, dbMetas []*MDDatabaseMeta) error {
ch := make(chan *MDTableMeta)
eg, egCtx := util.NewErrorGroupWithRecoverWithCtx(ctx)
for i := 0; i < si.concurrency; i++ {
eg.Go(func() error {
p := parser.New()
p.SetSQLMode(si.sqlMode)
for tableMeta := range ch {
sqlStr, err := tableMeta.GetSchema(egCtx, si.store)
if err != nil {
return err
}
if err = si.runJob(egCtx, p, &schemaJob{
dbName: tableMeta.DB,
tblName: tableMeta.Name,
stmtType: schemaCreateTable,
sqlStr: sqlStr,
}); err != nil {
return err
}
}
return nil
})
}
eg.Go(func() error {
defer close(ch)
for _, dbMeta := range dbMetas {
if len(dbMeta.Tables) == 0 {
continue
}
tables, err := si.getExistingTables(egCtx, dbMeta.Name)
if err != nil {
return err
}
for i := range dbMeta.Tables {
tblMeta := dbMeta.Tables[i]
if tables.Exist(strings.ToLower(tblMeta.Name)) {
// we already have this table in TiDB.
// we should skip the DDL job and let the SchemaValid check handle it.
si.logger.Info("table already exists in downstream, skip",
zap.String("db", dbMeta.Name),
zap.String("table", tblMeta.Name),
)
continue
} else if tblMeta.SchemaFile.FileMeta.Path == "" {
return common.ErrSchemaNotExists.GenWithStackByArgs(dbMeta.Name, tblMeta.Name)
}
select {
case ch <- tblMeta:
case <-egCtx.Done():
return egCtx.Err()
}
}
}
return nil
})
return eg.Wait()
}
// dumpling dumps a view as a table-schema SQL file, which creates a table of the same name
// as the view, plus a view-schema SQL file, which drops that table and creates the view.
func (si *SchemaImporter) importViews(ctx context.Context, dbMetas []*MDDatabaseMeta) error {
// restore views. Since views can reference tables across databases, we must restore views after all table schemas are restored.
// we don't restore views concurrently, because doing so may raise an error
p := parser.New()
p.SetSQLMode(si.sqlMode)
for _, dbMeta := range dbMetas {
if len(dbMeta.Views) == 0 {
continue
}
existingViews, err := si.getExistingViews(ctx, dbMeta.Name)
if err != nil {
return err
}
for _, viewMeta := range dbMeta.Views {
if existingViews.Exist(strings.ToLower(viewMeta.Name)) {
si.logger.Info("view already exists in downstream, skip",
zap.String("db", dbMeta.Name),
zap.String("view-name", viewMeta.Name))
continue
}
sqlStr, err := viewMeta.GetSchema(ctx, si.store)
if err != nil {
return err
}
if strings.TrimSpace(sqlStr) == "" {
si.logger.Info("view schema is empty, skip",
zap.String("db", dbMeta.Name),
zap.String("view-name", viewMeta.Name))
continue
}
if err = si.runJob(ctx, p, &schemaJob{
dbName: dbMeta.Name,
tblName: viewMeta.Name,
stmtType: schemaCreateView,
sqlStr: sqlStr,
}); err != nil {
return err
}
}
}
return nil
}
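To make the dumpling layout described above concrete, here is a hypothetical pair of files for a view v1 in database db1; the file names follow mydump's default routing rules, the contents are invented for the example, and dumpDir is a placeholder. importTables creates the placeholder table from the first file; importViews then executes the second, which swaps the placeholder for the real view.

// Sketch only: writes an example view dump to dumpDir.
// db1.v1-schema.sql: placeholder table with the view's name and columns.
_ = os.WriteFile(filepath.Join(dumpDir, "db1.v1-schema.sql"),
	[]byte("CREATE TABLE `v1` (`i` int);"), 0o644)
// db1.v1-schema-view.sql: drop the placeholder, then create the real view.
_ = os.WriteFile(filepath.Join(dumpDir, "db1.v1-schema-view.sql"),
	[]byte("DROP TABLE IF EXISTS `v1`;\nDROP VIEW IF EXISTS `v1`;\nCREATE VIEW `v1` AS SELECT `i` FROM `db1`.`t`;"), 0o644)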
func (si *SchemaImporter) runJob(ctx context.Context, p *parser.Parser, job *schemaJob) error {
stmts, err := createIfNotExistsStmt(p, job.sqlStr, job.dbName, job.tblName)
if err != nil {
return errors.Trace(err)
}
conn, err := si.db.Conn(ctx)
if err != nil {
return err
}
defer func() {
_ = conn.Close()
}()
logger := si.logger.With(zap.String("db", job.dbName), zap.String("table", job.tblName))
sqlWithRetry := common.SQLWithRetry{
Logger: logger,
DB: conn,
}
for _, stmt := range stmts {
task := logger.Begin(zap.DebugLevel, fmt.Sprintf("execute SQL: %s", stmt))
err = sqlWithRetry.Exec(ctx, "run create schema job", stmt)
task.End(zap.ErrorLevel, err)
if err != nil {
return common.ErrCreateSchema.Wrap(err).GenWithStackByArgs(common.UniqueTable(job.dbName, job.tblName), job.stmtType.String())
}
}
return nil
}
func (si *SchemaImporter) getExistingDatabases(ctx context.Context) (set.StringSet, error) {
return si.getExistingSchemas(ctx, `SELECT SCHEMA_NAME FROM information_schema.SCHEMATA`)
}
// the result contains views too, but since tables and views share the same namespace, that's OK.
func (si *SchemaImporter) getExistingTables(ctx context.Context, dbName string) (set.StringSet, error) {
sb := new(strings.Builder)
sqlescape.MustFormatSQL(sb, `SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = %?`, dbName)
return si.getExistingSchemas(ctx, sb.String())
}
func (si *SchemaImporter) getExistingViews(ctx context.Context, dbName string) (set.StringSet, error) {
sb := new(strings.Builder)
sqlescape.MustFormatSQL(sb, `SELECT TABLE_NAME FROM information_schema.VIEWS WHERE TABLE_SCHEMA = %?`, dbName)
return si.getExistingSchemas(ctx, sb.String())
}
// get existing databases/tables/views using the given query; the first column of
// the query result should be the name.
// The returned names are converted to lower case.
func (si *SchemaImporter) getExistingSchemas(ctx context.Context, query string) (set.StringSet, error) {
conn, err := si.db.Conn(ctx)
if err != nil {
return nil, errors.Trace(err)
}
defer func() {
_ = conn.Close()
}()
sqlWithRetry := common.SQLWithRetry{
Logger: si.logger,
DB: conn,
}
stringRows, err := sqlWithRetry.QueryStringRows(ctx, "get existing schemas", query)
if err != nil {
return nil, errors.Trace(err)
}
res := make(set.StringSet, len(stringRows))
for _, row := range stringRows {
res.Insert(strings.ToLower(row[0]))
}
return res, nil
}
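A note on the %? placeholder used in the two queries above: sqlescape.MustFormatSQL interpolates the argument as an escaped, quoted SQL literal before the query is sent, so a hostile database name cannot break out of the string. A small illustrative sketch (the example name is invented):

// Illustrative only: %? renders the Go string as a quoted, escaped literal.
sb := new(strings.Builder)
sqlescape.MustFormatSQL(sb,
	`SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = %?`,
	"my'db")
query := sb.String() // e.g. ... WHERE TABLE_SCHEMA = 'my\'db'
_ = query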
func createIfNotExistsStmt(p *parser.Parser, createTable, dbName, tblName string) ([]string, error) {
stmts, _, err := p.ParseSQL(createTable)
if err != nil {
return []string{}, common.ErrInvalidSchemaStmt.Wrap(err).GenWithStackByArgs(createTable)
}
var res strings.Builder
ctx := format.NewRestoreCtx(format.DefaultRestoreFlags|format.RestoreTiDBSpecialComment|format.RestoreWithTTLEnableOff, &res)
retStmts := make([]string, 0, len(stmts))
for _, stmt := range stmts {
switch node := stmt.(type) {
case *ast.CreateDatabaseStmt:
node.Name = model.NewCIStr(dbName)
node.IfNotExists = true
case *ast.DropDatabaseStmt:
node.Name = model.NewCIStr(dbName)
node.IfExists = true
case *ast.CreateTableStmt:
node.Table.Schema = model.NewCIStr(dbName)
node.Table.Name = model.NewCIStr(tblName)
node.IfNotExists = true
case *ast.CreateViewStmt:
node.ViewName.Schema = model.NewCIStr(dbName)
node.ViewName.Name = model.NewCIStr(tblName)
case *ast.DropTableStmt:
node.Tables[0].Schema = model.NewCIStr(dbName)
node.Tables[0].Name = model.NewCIStr(tblName)
node.IfExists = true
}
if err := stmt.Restore(ctx); err != nil {
return []string{}, common.ErrInvalidSchemaStmt.Wrap(err).GenWithStackByArgs(createTable)
}
ctx.WritePlain(";")
retStmts = append(retStmts, res.String())
res.Reset()
}
return retStmts, nil
}

View File

@@ -0,0 +1,370 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mydump
import (
"context"
"fmt"
"os"
"path"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/br/pkg/storage"
"github.com/pingcap/tidb/pkg/lightning/common"
"github.com/pingcap/tidb/pkg/lightning/log"
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)
func TestSchemaImporter(t *testing.T) {
db, mock, err := sqlmock.New()
mock.MatchExpectationsInOrder(false)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, mock.ExpectationsWereMet())
// have to ignore the error here, as sqlmock doesn't allow setting the number of
// expectations, and each opened connection requires a Close() call.
_ = db.Close()
})
ctx := context.Background()
tempDir := t.TempDir()
store, err := storage.NewLocalStorage(tempDir)
require.NoError(t, err)
logger := log.Logger{Logger: zap.NewExample()}
importer := NewSchemaImporter(logger, mysql.SQLMode(0), db, store, 4)
require.NoError(t, importer.Run(ctx, nil))
t.Run("get existing schema err", func(t *testing.T) {
mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnError(errors.New("non retryable error"))
require.ErrorContains(t, importer.Run(ctx, []*MDDatabaseMeta{{Name: "test"}}), "non retryable error")
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("database already exists", func(t *testing.T) {
mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
sqlmock.NewRows([]string{"SCHEMA_NAME"}).AddRow("test"))
require.NoError(t, importer.Run(ctx, []*MDDatabaseMeta{{Name: "test"}}))
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("create non exist database", func(t *testing.T) {
mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
sqlmock.NewRows([]string{"SCHEMA_NAME"}))
dbMetas := make([]*MDDatabaseMeta, 0, 10)
for i := 0; i < 10; i++ {
mock.ExpectExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `test%02d`", i)).
WillReturnResult(sqlmock.NewResult(0, 0))
dbMetas = append(dbMetas, &MDDatabaseMeta{Name: fmt.Sprintf("test%02d", i)})
}
require.NoError(t, importer.Run(ctx, dbMetas))
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("break on database error", func(t *testing.T) {
importer2 := NewSchemaImporter(logger, mysql.SQLMode(0), db, store, 1)
mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
sqlmock.NewRows([]string{"SCHEMA_NAME"}))
fileName := "invalid-schema.sql"
require.NoError(t, os.WriteFile(path.Join(tempDir, fileName), []byte("CREATE invalid;"), 0o644))
dbMetas := []*MDDatabaseMeta{
{Name: "test", charSet: "auto", SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileName}}},
{Name: "test2"}, // not chance to run
}
require.ErrorContains(t, importer2.Run(ctx, dbMetas), "invalid schema statement")
require.NoError(t, mock.ExpectationsWereMet())
require.NoError(t, os.Remove(path.Join(tempDir, fileName)))
dbMetas = append([]*MDDatabaseMeta{{Name: "ttt"}}, dbMetas...)
mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
sqlmock.NewRows([]string{"SCHEMA_NAME"}))
mock.ExpectExec("CREATE DATABASE IF NOT EXISTS `ttt`").
WillReturnError(errors.New("non retryable error"))
err2 := importer2.Run(ctx, dbMetas)
require.ErrorIs(t, err2, common.ErrCreateSchema)
require.ErrorContains(t, err2, "non retryable error")
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("table: get existing schema err", func(t *testing.T) {
mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
sqlmock.NewRows([]string{"SCHEMA_NAME"}).
AddRow("test01").AddRow("test02").AddRow("test03").
AddRow("test04").AddRow("test05"))
mock.ExpectQuery("TABLES WHERE TABLE_SCHEMA = 'test02'").
WillReturnError(errors.New("non retryable error"))
dbMetas := []*MDDatabaseMeta{
{Name: "test01"},
{Name: "test02", Tables: []*MDTableMeta{{DB: "test02", Name: "t"}}},
}
require.ErrorContains(t, importer.Run(ctx, dbMetas), "non retryable error")
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("table: invalid schema file", func(t *testing.T) {
mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
sqlmock.NewRows([]string{"SCHEMA_NAME"}).
AddRow("test01").AddRow("test02").AddRow("test03").
AddRow("test04").AddRow("test05"))
mock.ExpectQuery("TABLES WHERE TABLE_SCHEMA = 'test01'").
WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}).AddRow("t1"))
fileName := "t2-invalid-schema.sql"
require.NoError(t, os.WriteFile(path.Join(tempDir, fileName), []byte("CREATE table t2 whatever;"), 0o644))
dbMetas := []*MDDatabaseMeta{
{Name: "test01", Tables: []*MDTableMeta{
{DB: "test01", Name: "t1"},
{DB: "test01", Name: "T2", charSet: "auto",
SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileName}}},
}},
}
require.ErrorContains(t, importer.Run(ctx, dbMetas), "line 1 column 24 near")
require.NoError(t, mock.ExpectationsWereMet())
// create table t2 downstream manually as a workaround
mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
sqlmock.NewRows([]string{"SCHEMA_NAME"}).
AddRow("test01").AddRow("test02").AddRow("test03").
AddRow("test04").AddRow("test05"))
mock.ExpectQuery("TABLES WHERE TABLE_SCHEMA = 'test01'").
WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}).AddRow("t1").AddRow("t2"))
require.NoError(t, importer.Run(ctx, dbMetas))
require.NoError(t, mock.ExpectationsWereMet())
require.NoError(t, os.Remove(path.Join(tempDir, fileName)))
})
t.Run("table: break on error", func(t *testing.T) {
importer2 := NewSchemaImporter(logger, mysql.SQLMode(0), db, store, 1)
mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
sqlmock.NewRows([]string{"SCHEMA_NAME"}).
AddRow("test01").AddRow("test02").AddRow("test03").
AddRow("test04").AddRow("test05"))
fileNameT1 := "test01.t1-schema.sql"
fileNameT2 := "test01.t2-schema.sql"
require.NoError(t, os.WriteFile(path.Join(tempDir, fileNameT1), []byte("CREATE table t1(a int);"), 0o644))
require.NoError(t, os.WriteFile(path.Join(tempDir, fileNameT2), []byte("CREATE table t2(a int);"), 0o644))
dbMetas := []*MDDatabaseMeta{
{Name: "test01", Tables: []*MDTableMeta{
{DB: "test01", Name: "t1", charSet: "auto", SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileNameT1}}},
{DB: "test01", Name: "t2", charSet: "auto", SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileNameT2}}},
}},
}
mock.ExpectQuery("TABLES WHERE TABLE_SCHEMA = 'test01'").
WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}))
mock.ExpectExec("CREATE TABLE IF NOT EXISTS `test01`.`t1`").
WillReturnError(errors.New("non retryable create table error"))
require.ErrorContains(t, importer2.Run(ctx, dbMetas), "non retryable create table error")
require.NoError(t, mock.ExpectationsWereMet())
require.NoError(t, os.Remove(path.Join(tempDir, fileNameT1)))
require.NoError(t, os.Remove(path.Join(tempDir, fileNameT2)))
})
t.Run("view: get existing schema err", func(t *testing.T) {
mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
sqlmock.NewRows([]string{"SCHEMA_NAME"}).AddRow("test01").AddRow("test02"))
mock.ExpectQuery("VIEWS WHERE TABLE_SCHEMA = 'test02'").
WillReturnError(errors.New("non retryable error"))
dbMetas := []*MDDatabaseMeta{
{Name: "test01"},
{Name: "test02", Views: []*MDTableMeta{{DB: "test02", Name: "v"}}},
}
require.ErrorContains(t, importer.Run(ctx, dbMetas), "non retryable error")
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("view: fail on create", func(t *testing.T) {
mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
sqlmock.NewRows([]string{"SCHEMA_NAME"}).AddRow("test01").AddRow("test02"))
mock.ExpectQuery("VIEWS WHERE TABLE_SCHEMA = 'test02'").
WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}))
fileNameV0 := "empty-file.sql"
fileNameV1 := "invalid-schema.sql"
fileNameV2 := "test02.v2-schema-view.sql"
require.NoError(t, os.WriteFile(path.Join(tempDir, fileNameV0), []byte(""), 0o644))
require.NoError(t, os.WriteFile(path.Join(tempDir, fileNameV1), []byte("xxxx;"), 0o644))
require.NoError(t, os.WriteFile(path.Join(tempDir, fileNameV2), []byte("create view v2 as select * from t;"), 0o644))
dbMetas := []*MDDatabaseMeta{
{Name: "test01"},
{Name: "test02", Views: []*MDTableMeta{
{DB: "test02", Name: "V0", charSet: "auto", SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileNameV0}}},
{DB: "test02", Name: "v1", charSet: "auto", SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileNameV1}}},
{DB: "test02", Name: "V2", charSet: "auto", SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileNameV2}}}}},
}
require.ErrorContains(t, importer.Run(ctx, dbMetas), `line 1 column 4 near "xxxx;"`)
require.NoError(t, mock.ExpectationsWereMet())
// create view v2 downstream manually as a workaround
mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
sqlmock.NewRows([]string{"SCHEMA_NAME"}).AddRow("test01").AddRow("test02"))
mock.ExpectQuery("VIEWS WHERE TABLE_SCHEMA = 'test02'").
WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}).AddRow("V1"))
mock.ExpectExec("VIEW `test02`.`V2` AS SELECT").
WillReturnResult(sqlmock.NewResult(0, 0))
require.NoError(t, importer.Run(ctx, dbMetas))
require.NoError(t, mock.ExpectationsWereMet())
require.NoError(t, os.Remove(path.Join(tempDir, fileNameV0)))
require.NoError(t, os.Remove(path.Join(tempDir, fileNameV1)))
require.NoError(t, os.Remove(path.Join(tempDir, fileNameV2)))
})
}
func TestSchemaImporterManyTables(t *testing.T) {
db, mock, err := sqlmock.New()
mock.MatchExpectationsInOrder(false)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, mock.ExpectationsWereMet())
// have to ignore the error here, as sqlmock doesn't allow setting the number of
// expectations, and each opened connection requires a Close() call.
_ = db.Close()
})
ctx := context.Background()
tempDir := t.TempDir()
store, err := storage.NewLocalStorage(tempDir)
require.NoError(t, err)
logger := log.Logger{Logger: zap.NewExample()}
importer := NewSchemaImporter(logger, mysql.SQLMode(0), db, store, 8)
mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
sqlmock.NewRows([]string{"SCHEMA_NAME"}))
dbMetas := make([]*MDDatabaseMeta, 0, 30)
for i := 0; i < 30; i++ {
dbName := fmt.Sprintf("test%02d", i)
dbMeta := &MDDatabaseMeta{Name: dbName, Tables: make([]*MDTableMeta, 0, 100)}
mock.ExpectExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", dbName)).
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectQuery(fmt.Sprintf("TABLES WHERE TABLE_SCHEMA = '%s'", dbName)).
WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}))
for j := 0; j < 50; j++ {
tblName := fmt.Sprintf("t%03d", j)
fileName := fmt.Sprintf("%s.%s-schema.sql", dbName, tblName)
require.NoError(t, os.WriteFile(path.Join(tempDir, fileName), []byte(fmt.Sprintf("CREATE TABLE %s(a int);", tblName)), 0o644))
mock.ExpectExec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s`.`%s`", dbName, tblName)).
WillReturnResult(sqlmock.NewResult(0, 0))
dbMeta.Tables = append(dbMeta.Tables, &MDTableMeta{
DB: dbName, Name: tblName, charSet: "auto",
SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileName}},
})
}
dbMetas = append(dbMetas, dbMeta)
}
require.NoError(t, importer.Run(ctx, dbMetas))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestCreateTableIfNotExistsStmt(t *testing.T) {
dbName := "testdb"
p := parser.New()
createSQLIfNotExistsStmt := func(createTable, tableName string) []string {
res, err := createIfNotExistsStmt(p, createTable, dbName, tableName)
require.NoError(t, err)
return res
}
require.Equal(t, []string{"CREATE DATABASE IF NOT EXISTS `testdb` CHARACTER SET = utf8 COLLATE = utf8_general_ci;"},
createSQLIfNotExistsStmt("CREATE DATABASE `foo` CHARACTER SET = utf8 COLLATE = utf8_general_ci;", ""))
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` TINYINT(1));"},
createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` TINYINT(1));", "foo"))
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` TINYINT(1));"},
createSQLIfNotExistsStmt("CREATE TABLE IF NOT EXISTS `foo`(`bar` TINYINT(1));", "foo"))
// case insensitive
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`fOo` (`bar` TINYINT(1));"},
createSQLIfNotExistsStmt("/* cOmmEnt */ creAte tablE `fOo`(`bar` TinyinT(1));", "fOo"))
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`FoO` (`bAR` TINYINT(1));"},
createSQLIfNotExistsStmt("/* coMMenT */ crEatE tAble If not EXISts `FoO`(`bAR` tiNyInT(1));", "FoO"))
// only one "CREATE TABLE" is replaced
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE');"},
createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE');", "foo"))
// test clustered index consistency
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![clustered_index] CLUSTERED */ COMMENT 'CREATE TABLE');"},
createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY CLUSTERED COMMENT 'CREATE TABLE');", "foo"))
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE',PRIMARY KEY(`bar`) /*T![clustered_index] NONCLUSTERED */);"},
createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE', PRIMARY KEY (`bar`) NONCLUSTERED);", "foo"))
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![clustered_index] NONCLUSTERED */ COMMENT 'CREATE TABLE');"},
createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY /*T![clustered_index] NONCLUSTERED */ COMMENT 'CREATE TABLE');", "foo"))
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE',PRIMARY KEY(`bar`) /*T![clustered_index] CLUSTERED */);"},
createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE', PRIMARY KEY (`bar`) /*T![clustered_index] CLUSTERED */);", "foo"))
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![auto_rand] AUTO_RANDOM(2) */ COMMENT 'CREATE TABLE');"},
createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY AUTO_RANDOM(2) COMMENT 'CREATE TABLE');", "foo"))
// upper case becomes shorter
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ſ` (`ı` TINYINT(1));"},
createSQLIfNotExistsStmt("CREATE TABLE `ſ`(`ı` TINYINT(1));", "ſ"))
// upper case becomes longer
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ɑ` (`ȿ` TINYINT(1));"},
createSQLIfNotExistsStmt("CREATE TABLE `ɑ`(`ȿ` TINYINT(1));", "ɑ"))
// non-utf-8
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`\xcc\xcc\xcc` (`???` TINYINT(1));"},
createSQLIfNotExistsStmt("CREATE TABLE `\xcc\xcc\xcc`(`\xdd\xdd\xdd` TINYINT(1));", "\xcc\xcc\xcc"))
// renaming a table
require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ba``r` (`x` INT);"},
createSQLIfNotExistsStmt("create table foo(x int);", "ba`r"))
// conditional comments
require.Equal(t, []string{
"SET NAMES 'binary';",
"SET @@SESSION.`FOREIGN_KEY_CHECKS`=0;",
"CREATE TABLE IF NOT EXISTS `testdb`.`m` (`z` DOUBLE) ENGINE = InnoDB AUTO_INCREMENT = 8343230 DEFAULT CHARACTER SET = UTF8;",
},
createSQLIfNotExistsStmt(`
/*!40101 SET NAMES binary*/;
/*!40014 SET FOREIGN_KEY_CHECKS=0*/;
CREATE TABLE x.y (z double) ENGINE=InnoDB AUTO_INCREMENT=8343230 DEFAULT CHARSET=utf8;
`, "m"))
// create view
require.Equal(t, []string{
"SET NAMES 'binary';",
"DROP TABLE IF EXISTS `testdb`.`m`;",
"DROP VIEW IF EXISTS `testdb`.`m`;",
"SET @`PREV_CHARACTER_SET_CLIENT`=@@`character_set_client`;",
"SET @`PREV_CHARACTER_SET_RESULTS`=@@`character_set_results`;",
"SET @`PREV_COLLATION_CONNECTION`=@@`collation_connection`;",
"SET @@SESSION.`character_set_client`=`utf8`;",
"SET @@SESSION.`character_set_results`=`utf8`;",
"SET @@SESSION.`collation_connection`=`utf8_general_ci`;",
"CREATE ALGORITHM = UNDEFINED DEFINER = `root`@`192.168.198.178` SQL SECURITY DEFINER VIEW `testdb`.`m` (`s`) AS SELECT `s` FROM `db1`.`v1` WHERE `i`<2;",
"SET @@SESSION.`character_set_client`=@`PREV_CHARACTER_SET_CLIENT`;",
"SET @@SESSION.`character_set_results`=@`PREV_CHARACTER_SET_RESULTS`;",
"SET @@SESSION.`collation_connection`=@`PREV_COLLATION_CONNECTION`;",
},
createSQLIfNotExistsStmt(`
/*!40101 SET NAMES binary*/;
DROP TABLE IF EXISTS v2;
DROP VIEW IF EXISTS v2;
SET @PREV_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT;
SET @PREV_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS;
SET @PREV_COLLATION_CONNECTION=@@COLLATION_CONNECTION;
SET character_set_client = utf8;
SET character_set_results = utf8;
SET collation_connection = utf8_general_ci;
CREATE ALGORITHM=UNDEFINED DEFINER=root@192.168.198.178 SQL SECURITY DEFINER VIEW v2 (s) AS SELECT s FROM db1.v1 WHERE i<2;
SET character_set_client = @PREV_CHARACTER_SET_CLIENT;
SET character_set_results = @PREV_CHARACTER_SET_RESULTS;
SET collation_connection = @PREV_COLLATION_CONNECTION;
`, "m"))
}