diff --git a/lightning/pkg/importer/BUILD.bazel b/lightning/pkg/importer/BUILD.bazel index 0075a4db17..83797e9553 100644 --- a/lightning/pkg/importer/BUILD.bazel +++ b/lightning/pkg/importer/BUILD.bazel @@ -54,7 +54,6 @@ go_library( "//pkg/meta/autoid", "//pkg/parser", "//pkg/parser/ast", - "//pkg/parser/format", "//pkg/parser/model", "//pkg/parser/mysql", "//pkg/planner/core", @@ -116,7 +115,6 @@ go_test( "meta_manager_test.go", "precheck_impl_test.go", "precheck_test.go", - "restore_schema_test.go", "table_import_test.go", "tidb_test.go", ], diff --git a/lightning/pkg/importer/import.go b/lightning/pkg/importer/import.go index 20ad2e3b06..a1cd053023 100644 --- a/lightning/pkg/importer/import.go +++ b/lightning/pkg/importer/import.go @@ -56,7 +56,6 @@ import ( "github.com/pingcap/tidb/pkg/lightning/tikv" "github.com/pingcap/tidb/pkg/lightning/worker" "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/parser" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/session" "github.com/pingcap/tidb/pkg/sessionctx/variable" @@ -583,324 +582,19 @@ outside: return errors.Trace(err) } -type schemaStmtType int - -// String implements fmt.Stringer interface. -func (stmtType schemaStmtType) String() string { - switch stmtType { - case schemaCreateDatabase: - return "restore database schema" - case schemaCreateTable: - return "restore table schema" - case schemaCreateView: - return "restore view schema" - } - return "unknown statement of schema" -} - -const ( - schemaCreateDatabase schemaStmtType = iota - schemaCreateTable - schemaCreateView -) - -type schemaJob struct { - dbName string - tblName string // empty for create db jobs - stmtType schemaStmtType - stmts []string -} - -type restoreSchemaWorker struct { - ctx context.Context - quit context.CancelFunc - logger log.Logger - jobCh chan *schemaJob - errCh chan error - wg sync.WaitGroup - db *sql.DB - parser *parser.Parser - store storage.ExternalStorage -} - -func (worker *restoreSchemaWorker) addJob(sqlStr string, job *schemaJob) error { - stmts, err := createIfNotExistsStmt(worker.parser, sqlStr, job.dbName, job.tblName) - if err != nil { - return errors.Trace(err) - } - job.stmts = stmts - return worker.appendJob(job) -} - -func (worker *restoreSchemaWorker) makeJobs( - dbMetas []*mydump.MDDatabaseMeta, - getDBs func(context.Context) ([]*model.DBInfo, error), - getTables func(context.Context, string) ([]*model.TableInfo, error), -) error { - defer func() { - close(worker.jobCh) - worker.quit() - }() - - if len(dbMetas) == 0 { - return nil - } - - // 1. restore databases, execute statements concurrency - - dbs, err := getDBs(worker.ctx) - if err != nil { - worker.logger.Warn("get databases from downstream failed", zap.Error(err)) - } - dbSet := make(set.StringSet, len(dbs)) - for _, db := range dbs { - dbSet.Insert(db.Name.L) - } - - for _, dbMeta := range dbMetas { - // if downstream already has this database, we can skip ddl job - if dbSet.Exist(strings.ToLower(dbMeta.Name)) { - worker.logger.Info( - "database already exists in downstream, skip processing the source file", - zap.String("db", dbMeta.Name), - ) - continue - } - - sql := dbMeta.GetSchema(worker.ctx, worker.store) - err = worker.addJob(sql, &schemaJob{ - dbName: dbMeta.Name, - tblName: "", - stmtType: schemaCreateDatabase, - }) - if err != nil { - return err - } - } - err = worker.wait() - if err != nil { - return err - } - - // 2. 
restore tables, execute statements concurrency - - for _, dbMeta := range dbMetas { - // we can ignore error here, and let check failed later if schema not match - tables, err := getTables(worker.ctx, dbMeta.Name) - if err != nil { - worker.logger.Warn("get tables from downstream failed", zap.Error(err)) - } - tableSet := make(set.StringSet, len(tables)) - for _, t := range tables { - tableSet.Insert(t.Name.L) - } - for _, tblMeta := range dbMeta.Tables { - if tableSet.Exist(strings.ToLower(tblMeta.Name)) { - // we already has this table in TiDB. - // we should skip ddl job and let SchemaValid check. - worker.logger.Info( - "table already exists in downstream, skip processing the source file", - zap.String("db", dbMeta.Name), - zap.String("table", tblMeta.Name), - ) - continue - } else if tblMeta.SchemaFile.FileMeta.Path == "" { - return common.ErrSchemaNotExists.GenWithStackByArgs(dbMeta.Name, tblMeta.Name) - } - sql, err := tblMeta.GetSchema(worker.ctx, worker.store) - if err != nil { - return err - } - if sql != "" { - err = worker.addJob(sql, &schemaJob{ - dbName: dbMeta.Name, - tblName: tblMeta.Name, - stmtType: schemaCreateTable, - }) - if err != nil { - return err - } - } - } - } - err = worker.wait() - if err != nil { - return err - } - // 3. restore views. Since views can cross database we must restore views after all table schemas are restored. - for _, dbMeta := range dbMetas { - for _, viewMeta := range dbMeta.Views { - sql, err := viewMeta.GetSchema(worker.ctx, worker.store) - if sql != "" { - err = worker.addJob(sql, &schemaJob{ - dbName: dbMeta.Name, - tblName: viewMeta.Name, - stmtType: schemaCreateView, - }) - if err != nil { - return err - } - // we don't support restore views concurrency, cauz it maybe will raise a error - err = worker.wait() - if err != nil { - return err - } - } - if err != nil { - return err - } - } - } - return nil -} - -func (worker *restoreSchemaWorker) doJob() { - var session *sql.Conn - defer func() { - if session != nil { - _ = session.Close() - } - }() -loop: - for { - select { - case <-worker.ctx.Done(): - // don't `return` or throw `worker.ctx.Err()`here, - // if we `return`, we can't mark cancelled jobs as done, - // if we `throw(worker.ctx.Err())`, it will be blocked to death - break loop - case job := <-worker.jobCh: - if job == nil { - // successful exit - return - } - var err error - if session == nil { - session, err = func() (*sql.Conn, error) { - return worker.db.Conn(worker.ctx) - }() - if err != nil { - worker.wg.Done() - worker.throw(err) - // don't return - break loop - } - } - logger := worker.logger.With(zap.String("db", job.dbName), zap.String("table", job.tblName)) - sqlWithRetry := common.SQLWithRetry{ - Logger: worker.logger, - DB: session, - } - for _, stmt := range job.stmts { - task := logger.Begin(zap.DebugLevel, fmt.Sprintf("execute SQL: %s", stmt)) - err = sqlWithRetry.Exec(worker.ctx, "run create schema job", stmt) - if err != nil { - // try to imitate IF NOT EXISTS behavior for parsing errors - exists := false - switch job.stmtType { - case schemaCreateDatabase: - var err2 error - exists, err2 = common.SchemaExists(worker.ctx, session, job.dbName) - if err2 != nil { - task.Error("failed to check database existence", zap.Error(err2)) - } - case schemaCreateTable: - exists, _ = common.TableExists(worker.ctx, session, job.dbName, job.tblName) - } - if exists { - err = nil - } - } - task.End(zap.ErrorLevel, err) - - if err != nil { - err = common.ErrCreateSchema.Wrap(err).GenWithStackByArgs(common.UniqueTable(job.dbName, 
job.tblName), job.stmtType.String())
-					worker.wg.Done()
-					worker.throw(err)
-					// don't return
-					break loop
-				}
-			}
-			worker.wg.Done()
-		}
-	}
-	// mark the cancelled job as `Done`, a little tricky,
-	// cauz we need make sure `worker.wg.Wait()` wouldn't blocked forever
-	for range worker.jobCh {
-		worker.wg.Done()
-	}
-}
-
-func (worker *restoreSchemaWorker) wait() error {
-	// avoid to `worker.wg.Wait()` blocked forever when all `doJob`'s goroutine exited.
-	// don't worry about goroutine below, it never become a zombie,
-	// cauz we have mechanism to clean cancelled jobs in `worker.jobCh`.
-	// means whole jobs has been send to `worker.jobCh` would be done.
-	waitCh := make(chan struct{})
-	go func() {
-		worker.wg.Wait()
-		close(waitCh)
-	}()
-	select {
-	case err := <-worker.errCh:
-		return err
-	case <-worker.ctx.Done():
-		return worker.ctx.Err()
-	case <-waitCh:
-		return nil
-	}
-}
-
-func (worker *restoreSchemaWorker) throw(err error) {
-	select {
-	case <-worker.ctx.Done():
-		// don't throw `worker.ctx.Err()` again, it will be blocked to death.
-		return
-	case worker.errCh <- err:
-		worker.quit()
-	}
-}
-
-func (worker *restoreSchemaWorker) appendJob(job *schemaJob) error {
-	worker.wg.Add(1)
-	select {
-	case err := <-worker.errCh:
-		// cancel the job
-		worker.wg.Done()
-		return err
-	case <-worker.ctx.Done():
-		// cancel the job
-		worker.wg.Done()
-		return errors.Trace(worker.ctx.Err())
-	case worker.jobCh <- job:
-		return nil
-	}
-}
-
 func (rc *Controller) restoreSchema(ctx context.Context) error {
 	// create table with schema file
 	// we can handle the duplicated created with createIfNotExist statement
 	// and we will check the schema in TiDB is valid with the datafile in DataCheck later.
-	logTask := log.FromContext(ctx).Begin(zap.InfoLevel, "restore all schema")
+	logger := log.FromContext(ctx)
 	concurrency := min(rc.cfg.App.RegionConcurrency, 8)
-	childCtx, cancel := context.WithCancel(ctx)
-	p := parser.New()
-	p.SetSQLMode(rc.cfg.TiDB.SQLMode)
-	worker := restoreSchemaWorker{
-		ctx:    childCtx,
-		quit:   cancel,
-		logger: log.FromContext(ctx),
-		jobCh:  make(chan *schemaJob, concurrency),
-		errCh:  make(chan error),
-		db:     rc.db,
-		parser: p,
-		store:  rc.store,
-	}
-	for i := 0; i < concurrency; i++ {
-		go worker.doJob()
-	}
-	err := worker.makeJobs(rc.dbMetas, rc.preInfoGetter.FetchRemoteDBModels, rc.preInfoGetter.FetchRemoteTableModels)
-	logTask.End(zap.ErrorLevel, err)
+	// sql.DB is a connection pool; we set its idle limit to concurrency + 1
+	// (one extra for the job generator) so connections are reused, as we may
+	// call db.Conn/conn.Close many times.
+	// there's no API to read back sql.DB's MaxIdleConns, so we restore it to
+	// its default, which is 2.
+	rc.db.SetMaxIdleConns(concurrency + 1)
+	defer rc.db.SetMaxIdleConns(2)
+	schemaImp := mydump.NewSchemaImporter(logger, rc.cfg.TiDB.SQLMode, rc.db, rc.store, concurrency)
+	err := schemaImp.Run(ctx, rc.dbMetas)
 	if err != nil {
 		return err
 	}
diff --git a/lightning/pkg/importer/restore_schema_test.go b/lightning/pkg/importer/restore_schema_test.go
deleted file mode 100644
index d8a4026cb9..0000000000
--- a/lightning/pkg/importer/restore_schema_test.go
+++ /dev/null
@@ -1,247 +0,0 @@
-// Copyright 2022 PingCAP, Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package importer - -import ( - "context" - stderrors "errors" - "fmt" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/pingcap/errors" - "github.com/pingcap/tidb/br/pkg/mock" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/pkg/ddl" - "github.com/pingcap/tidb/pkg/lightning/backend" - "github.com/pingcap/tidb/pkg/lightning/checkpoints" - "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/lightning/mydump" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - tmock "github.com/pingcap/tidb/pkg/util/mock" - filter "github.com/pingcap/tidb/pkg/util/table-filter" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" - "go.uber.org/mock/gomock" -) - -type restoreSchemaSuite struct { - suite.Suite - ctx context.Context - rc *Controller - controller *gomock.Controller - dbMock sqlmock.Sqlmock - tableInfos []*model.TableInfo - infoGetter *PreImportInfoGetterImpl - targetInfoGetter *TargetInfoGetterImpl -} - -func TestRestoreSchemaSuite(t *testing.T) { - suite.Run(t, new(restoreSchemaSuite)) -} - -func (s *restoreSchemaSuite) SetupSuite() { - ctx := context.Background() - fakeDataDir := s.T().TempDir() - - store, err := storage.NewLocalStorage(fakeDataDir) - require.NoError(s.T(), err) - // restore database schema file - fakeDBName := "fakedb" - // please follow the `mydump.defaultFileRouteRules`, matches files like '{schema}-schema-create.sql' - fakeFileName := fmt.Sprintf("%s-schema-create.sql", fakeDBName) - err = store.WriteFile(ctx, fakeFileName, []byte(fmt.Sprintf("CREATE DATABASE %s;", fakeDBName))) - require.NoError(s.T(), err) - // restore table schema files - fakeTableFilesCount := 8 - - p := parser.New() - p.SetSQLMode(mysql.ModeANSIQuotes) - se := tmock.NewContext() - - tableInfos := make([]*model.TableInfo, 0, fakeTableFilesCount) - for i := 1; i <= fakeTableFilesCount; i++ { - fakeTableName := fmt.Sprintf("tbl%d", i) - // please follow the `mydump.defaultFileRouteRules`, matches files like '{schema}.{table}-schema.sql' - fakeFileName := fmt.Sprintf("%s.%s-schema.sql", fakeDBName, fakeTableName) - fakeFileContent := fmt.Sprintf("CREATE TABLE %s(i TINYINT);", fakeTableName) - err = store.WriteFile(ctx, fakeFileName, []byte(fakeFileContent)) - require.NoError(s.T(), err) - - node, err := p.ParseOneStmt(fakeFileContent, "", "") - require.NoError(s.T(), err) - core, err := ddl.MockTableInfo(se, node.(*ast.CreateTableStmt), 0xabcdef) - require.NoError(s.T(), err) - core.State = model.StatePublic - tableInfos = append(tableInfos, core) - } - s.tableInfos = tableInfos - // restore view schema files - fakeViewFilesCount := 8 - for i := 1; i <= fakeViewFilesCount; i++ { - fakeViewName := fmt.Sprintf("tbl%d", i) - // please follow the `mydump.defaultFileRouteRules`, matches files like '{schema}.{table}-schema-view.sql' - fakeFileName := fmt.Sprintf("%s.%s-schema-view.sql", fakeDBName, fakeViewName) - fakeFileContent := []byte(fmt.Sprintf("CREATE ALGORITHM=UNDEFINED VIEW `%s` 
(`i`) AS SELECT `i` FROM `%s`.`%s`;", fakeViewName, fakeDBName, fmt.Sprintf("tbl%d", i))) - err = store.WriteFile(ctx, fakeFileName, fakeFileContent) - require.NoError(s.T(), err) - } - config := config.NewConfig() - config.Mydumper.DefaultFileRules = true - config.Mydumper.CharacterSet = "utf8mb4" - config.App.RegionConcurrency = 8 - mydumpLoader, err := mydump.NewLoaderWithStore(ctx, mydump.NewLoaderCfg(config), store) - s.Require().NoError(err) - - dbMetas := mydumpLoader.GetDatabases() - targetInfoGetter := &TargetInfoGetterImpl{ - cfg: config, - } - preInfoGetter := &PreImportInfoGetterImpl{ - cfg: config, - srcStorage: store, - targetInfoGetter: targetInfoGetter, - dbMetas: dbMetas, - } - preInfoGetter.Init() - s.rc = &Controller{ - checkTemplate: NewSimpleTemplate(), - cfg: config, - store: store, - dbMetas: dbMetas, - checkpointsDB: &checkpoints.NullCheckpointsDB{}, - preInfoGetter: preInfoGetter, - } - s.infoGetter = preInfoGetter - s.targetInfoGetter = targetInfoGetter -} - -//nolint:interfacer // change test case signature might cause Check failed to find this test case? -func (s *restoreSchemaSuite) SetupTest() { - s.controller, s.ctx = gomock.WithContext(context.Background(), s.T()) - mockTargetInfoGetter := mock.NewMockTargetInfoGetter(s.controller) - mockBackend := mock.NewMockBackend(s.controller) - mockTargetInfoGetter.EXPECT(). - FetchRemoteDBModels(gomock.Any()). - AnyTimes(). - Return([]*model.DBInfo{{Name: model.NewCIStr("fakedb")}}, nil) - mockTargetInfoGetter.EXPECT(). - FetchRemoteTableModels(gomock.Any(), gomock.Any()). - AnyTimes(). - Return(s.tableInfos, nil) - mockBackend.EXPECT().Close() - theBackend := backend.MakeEngineManager(mockBackend) - s.rc.engineMgr = theBackend - s.rc.backend = mockBackend - s.targetInfoGetter.backend = mockTargetInfoGetter - - mockDB, sqlMock, err := sqlmock.New() - require.NoError(s.T(), err) - for i := 0; i < 17; i++ { - sqlMock.ExpectExec(".*").WillReturnResult(sqlmock.NewResult(int64(i), 1)) - } - s.targetInfoGetter.db = mockDB - s.rc.db = mockDB - s.dbMock = sqlMock -} - -func (s *restoreSchemaSuite) TearDownTest() { - s.rc.Close() - s.controller.Finish() -} - -func (s *restoreSchemaSuite) TestRestoreSchemaSuccessful() { - // before restore, if sysVars is initialized by other test, the time_zone should be default value - if len(s.rc.sysVars) > 0 { - tz, ok := s.rc.sysVars["time_zone"] - require.True(s.T(), ok) - require.Equal(s.T(), "SYSTEM", tz) - } - - s.dbMock.ExpectQuery(".*").WillReturnRows(sqlmock.NewRows([]string{"time_zone"}).AddRow("SYSTEM")) - s.rc.cfg.TiDB.Vars = map[string]string{ - "time_zone": "UTC", - } - err := s.rc.restoreSchema(s.ctx) - require.NoError(s.T(), err) - - // test after restore schema, sysVars has been updated - tz, ok := s.rc.sysVars["time_zone"] - require.True(s.T(), ok) - require.Equal(s.T(), "UTC", tz) -} - -func (s *restoreSchemaSuite) TestRestoreSchemaFailed() { - // use injectErr which cannot be retried - injectErr := stderrors.New("could not match actual sql") - mockDB, sqlMock, err := sqlmock.New() - require.NoError(s.T(), err) - sqlMock.ExpectExec(".*").WillReturnError(injectErr) - for i := 0; i < 16; i++ { - sqlMock.ExpectExec(".*").WillReturnResult(sqlmock.NewResult(int64(i), 1)) - } - - s.rc.db = mockDB - s.targetInfoGetter.db = mockDB - err = s.rc.restoreSchema(s.ctx) - require.Error(s.T(), err) - require.True(s.T(), errors.ErrorEqual(err, injectErr)) -} - -// When restoring a CSV with `-no-schema` and the target table doesn't exist -// then we can't restore the schema as the `Path` is 
empty. This is to make -// sure this results in the correct error. -// https://github.com/pingcap/br/issues/1394 -func (s *restoreSchemaSuite) TestNoSchemaPath() { - fakeTable := mydump.MDTableMeta{ - DB: "fakedb", - Name: "fake1", - SchemaFile: mydump.FileInfo{ - TableName: filter.Table{ - Schema: "fakedb", - Name: "fake1", - }, - FileMeta: mydump.SourceFileMeta{ - Path: "", - }, - }, - DataFiles: []mydump.FileInfo{}, - TotalSize: 0, - } - s.rc.dbMetas[0].Tables = append(s.rc.dbMetas[0].Tables, &fakeTable) - err := s.rc.restoreSchema(s.ctx) - require.Error(s.T(), err) - require.Regexp(s.T(), `table .* schema not found`, err.Error()) - s.rc.dbMetas[0].Tables = s.rc.dbMetas[0].Tables[:len(s.rc.dbMetas[0].Tables)-1] -} - -func (s *restoreSchemaSuite) TestRestoreSchemaContextCancel() { - childCtx, cancel := context.WithCancel(s.ctx) - mockDB, sqlMock, err := sqlmock.New() - require.NoError(s.T(), err) - for i := 0; i < 17; i++ { - sqlMock.ExpectExec(".*").WillReturnResult(sqlmock.NewResult(int64(i), 1)) - } - s.rc.db = mockDB - s.targetInfoGetter.db = mockDB - cancel() - err = s.rc.restoreSchema(childCtx) - require.Error(s.T(), err) - err = errors.Cause(err) - require.Equal(s.T(), childCtx.Err(), err) -} diff --git a/lightning/pkg/importer/tidb.go b/lightning/pkg/importer/tidb.go index 9f083f76fe..a732bee093 100644 --- a/lightning/pkg/importer/tidb.go +++ b/lightning/pkg/importer/tidb.go @@ -31,8 +31,6 @@ import ( "github.com/pingcap/tidb/pkg/lightning/metric" "github.com/pingcap/tidb/pkg/lightning/mydump" "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/format" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/sessionctx/variable" @@ -131,47 +129,6 @@ func (timgr *TiDBManager) Close() { timgr.db.Close() } -func createIfNotExistsStmt(p *parser.Parser, createTable, dbName, tblName string) ([]string, error) { - stmts, _, err := p.ParseSQL(createTable) - if err != nil { - return []string{}, common.ErrInvalidSchemaStmt.Wrap(err).GenWithStackByArgs(createTable) - } - - var res strings.Builder - ctx := format.NewRestoreCtx(format.DefaultRestoreFlags|format.RestoreTiDBSpecialComment|format.RestoreWithTTLEnableOff, &res) - - retStmts := make([]string, 0, len(stmts)) - for _, stmt := range stmts { - switch node := stmt.(type) { - case *ast.CreateDatabaseStmt: - node.Name = model.NewCIStr(dbName) - node.IfNotExists = true - case *ast.DropDatabaseStmt: - node.Name = model.NewCIStr(dbName) - node.IfExists = true - case *ast.CreateTableStmt: - node.Table.Schema = model.NewCIStr(dbName) - node.Table.Name = model.NewCIStr(tblName) - node.IfNotExists = true - case *ast.CreateViewStmt: - node.ViewName.Schema = model.NewCIStr(dbName) - node.ViewName.Name = model.NewCIStr(tblName) - case *ast.DropTableStmt: - node.Tables[0].Schema = model.NewCIStr(dbName) - node.Tables[0].Name = model.NewCIStr(tblName) - node.IfExists = true - } - if err := stmt.Restore(ctx); err != nil { - return []string{}, common.ErrInvalidSchemaStmt.Wrap(err).GenWithStackByArgs(createTable) - } - ctx.WritePlain(";") - retStmts = append(retStmts, res.String()) - res.Reset() - } - - return retStmts, nil -} - // DropTable drops a table. 
func (timgr *TiDBManager) DropTable(ctx context.Context, tableName string) error { sql := common.SQLWithRetry{ diff --git a/lightning/pkg/importer/tidb_test.go b/lightning/pkg/importer/tidb_test.go index 830ab7b242..e10f3bede6 100644 --- a/lightning/pkg/importer/tidb_test.go +++ b/lightning/pkg/importer/tidb_test.go @@ -28,7 +28,6 @@ import ( "github.com/pingcap/tidb/pkg/lightning/checkpoints" "github.com/pingcap/tidb/pkg/lightning/metric" "github.com/pingcap/tidb/pkg/lightning/mydump" - "github.com/pingcap/tidb/pkg/parser" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/model" tmysql "github.com/pingcap/tidb/pkg/parser/mysql" @@ -61,109 +60,6 @@ func newTiDBSuite(t *testing.T) *tidbSuite { return &s } -func TestCreateTableIfNotExistsStmt(t *testing.T) { - dbName := "testdb" - p := parser.New() - createSQLIfNotExistsStmt := func(createTable, tableName string) []string { - res, err := createIfNotExistsStmt(p, createTable, dbName, tableName) - require.NoError(t, err) - return res - } - - require.Equal(t, []string{"CREATE DATABASE IF NOT EXISTS `testdb` CHARACTER SET = utf8 COLLATE = utf8_general_ci;"}, - createSQLIfNotExistsStmt("CREATE DATABASE `foo` CHARACTER SET = utf8 COLLATE = utf8_general_ci;", "")) - - require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` TINYINT(1));"}, - createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` TINYINT(1));", "foo")) - - require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` TINYINT(1));"}, - createSQLIfNotExistsStmt("CREATE TABLE IF NOT EXISTS `foo`(`bar` TINYINT(1));", "foo")) - - // case insensitive - require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`fOo` (`bar` TINYINT(1));"}, - createSQLIfNotExistsStmt("/* cOmmEnt */ creAte tablE `fOo`(`bar` TinyinT(1));", "fOo")) - - require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`FoO` (`bAR` TINYINT(1));"}, - createSQLIfNotExistsStmt("/* coMMenT */ crEatE tAble If not EXISts `FoO`(`bAR` tiNyInT(1));", "FoO")) - - // only one "CREATE TABLE" is replaced - require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE');"}, - createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE');", "foo")) - - // test clustered index consistency - require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![clustered_index] CLUSTERED */ COMMENT 'CREATE TABLE');"}, - createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY CLUSTERED COMMENT 'CREATE TABLE');", "foo")) - require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE',PRIMARY KEY(`bar`) /*T![clustered_index] NONCLUSTERED */);"}, - createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE', PRIMARY KEY (`bar`) NONCLUSTERED);", "foo")) - require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![clustered_index] NONCLUSTERED */ COMMENT 'CREATE TABLE');"}, - createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY /*T![clustered_index] NONCLUSTERED */ COMMENT 'CREATE TABLE');", "foo")) - require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE',PRIMARY KEY(`bar`) /*T![clustered_index] CLUSTERED */);"}, - createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE', PRIMARY KEY (`bar`) /*T![clustered_index] CLUSTERED */);", "foo")) - - require.Equal(t, []string{"CREATE TABLE 
IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![auto_rand] AUTO_RANDOM(2) */ COMMENT 'CREATE TABLE');"}, - createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY AUTO_RANDOM(2) COMMENT 'CREATE TABLE');", "foo")) - - // upper case becomes shorter - require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ſ` (`ı` TINYINT(1));"}, - createSQLIfNotExistsStmt("CREATE TABLE `ſ`(`ı` TINYINT(1));", "ſ")) - - // upper case becomes longer - require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ɑ` (`ȿ` TINYINT(1));"}, - createSQLIfNotExistsStmt("CREATE TABLE `ɑ`(`ȿ` TINYINT(1));", "ɑ")) - - // non-utf-8 - require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`\xcc\xcc\xcc` (`???` TINYINT(1));"}, - createSQLIfNotExistsStmt("CREATE TABLE `\xcc\xcc\xcc`(`\xdd\xdd\xdd` TINYINT(1));", "\xcc\xcc\xcc")) - - // renaming a table - require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ba``r` (`x` INT);"}, - createSQLIfNotExistsStmt("create table foo(x int);", "ba`r")) - - // conditional comments - require.Equal(t, []string{ - "SET NAMES 'binary';", - "SET @@SESSION.`FOREIGN_KEY_CHECKS`=0;", - "CREATE TABLE IF NOT EXISTS `testdb`.`m` (`z` DOUBLE) ENGINE = InnoDB AUTO_INCREMENT = 8343230 DEFAULT CHARACTER SET = UTF8;", - }, - createSQLIfNotExistsStmt(` - /*!40101 SET NAMES binary*/; - /*!40014 SET FOREIGN_KEY_CHECKS=0*/; - CREATE TABLE x.y (z double) ENGINE=InnoDB AUTO_INCREMENT=8343230 DEFAULT CHARSET=utf8; - `, "m")) - - // create view - require.Equal(t, []string{ - "SET NAMES 'binary';", - "DROP TABLE IF EXISTS `testdb`.`m`;", - "DROP VIEW IF EXISTS `testdb`.`m`;", - "SET @`PREV_CHARACTER_SET_CLIENT`=@@`character_set_client`;", - "SET @`PREV_CHARACTER_SET_RESULTS`=@@`character_set_results`;", - "SET @`PREV_COLLATION_CONNECTION`=@@`collation_connection`;", - "SET @@SESSION.`character_set_client`=`utf8`;", - "SET @@SESSION.`character_set_results`=`utf8`;", - "SET @@SESSION.`collation_connection`=`utf8_general_ci`;", - "CREATE ALGORITHM = UNDEFINED DEFINER = `root`@`192.168.198.178` SQL SECURITY DEFINER VIEW `testdb`.`m` (`s`) AS SELECT `s` FROM `db1`.`v1` WHERE `i`<2;", - "SET @@SESSION.`character_set_client`=@`PREV_CHARACTER_SET_CLIENT`;", - "SET @@SESSION.`character_set_results`=@`PREV_CHARACTER_SET_RESULTS`;", - "SET @@SESSION.`collation_connection`=@`PREV_COLLATION_CONNECTION`;", - }, - createSQLIfNotExistsStmt(` - /*!40101 SET NAMES binary*/; - DROP TABLE IF EXISTS v2; - DROP VIEW IF EXISTS v2; - SET @PREV_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT; - SET @PREV_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS; - SET @PREV_COLLATION_CONNECTION=@@COLLATION_CONNECTION; - SET character_set_client = utf8; - SET character_set_results = utf8; - SET collation_connection = utf8_general_ci; - CREATE ALGORITHM=UNDEFINED DEFINER=root@192.168.198.178 SQL SECURITY DEFINER VIEW v2 (s) AS SELECT s FROM db1.v1 WHERE i<2; - SET character_set_client = @PREV_CHARACTER_SET_CLIENT; - SET character_set_results = @PREV_CHARACTER_SET_RESULTS; - SET collation_connection = @PREV_COLLATION_CONNECTION; - `, "m")) -} - func TestDropTable(t *testing.T) { s := newTiDBSuite(t) ctx := context.Background() diff --git a/pkg/lightning/mydump/BUILD.bazel b/pkg/lightning/mydump/BUILD.bazel index 754f92ae65..381da027f3 100644 --- a/pkg/lightning/mydump/BUILD.bazel +++ b/pkg/lightning/mydump/BUILD.bazel @@ -13,6 +13,7 @@ go_library( "reader.go", "region.go", "router.go", + "schema_import.go", ], importpath = "github.com/pingcap/tidb/pkg/lightning/mydump", visibility = 
["//visibility:public"], @@ -24,11 +25,18 @@ go_library( "//pkg/lightning/log", "//pkg/lightning/metric", "//pkg/lightning/worker", + "//pkg/parser", + "//pkg/parser/ast", + "//pkg/parser/format", + "//pkg/parser/model", "//pkg/parser/mysql", "//pkg/types", + "//pkg/util", "//pkg/util/filter", "//pkg/util/regexpr-router", + "//pkg/util/set", "//pkg/util/slice", + "//pkg/util/sqlescape", "//pkg/util/table-filter", "//pkg/util/zeropool", "@com_github_pingcap_errors//:errors", @@ -59,6 +67,7 @@ go_test( "reader_test.go", "region_test.go", "router_test.go", + "schema_import_test.go", ], data = glob([ "csv/*", @@ -75,12 +84,14 @@ go_test( "//pkg/lightning/config", "//pkg/lightning/log", "//pkg/lightning/worker", + "//pkg/parser", "//pkg/parser/mysql", "//pkg/testkit/testsetup", "//pkg/types", "//pkg/util/filter", "//pkg/util/table-filter", "//pkg/util/table-router", + "@com_github_data_dog_go_sqlmock//:go-sqlmock", "@com_github_pingcap_errors//:errors", "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", diff --git a/pkg/lightning/mydump/loader.go b/pkg/lightning/mydump/loader.go index 30c1e316d9..d2fc407ddc 100644 --- a/pkg/lightning/mydump/loader.go +++ b/pkg/lightning/mydump/loader.go @@ -130,7 +130,7 @@ func (m *MDTableMeta) GetSchema(ctx context.Context, store storage.ExternalStora zap.String("Path", m.SchemaFile.FileMeta.Path), log.ShortError(err), ) - return "", err + return "", errors.Trace(err) } return string(schema), nil } diff --git a/pkg/lightning/mydump/schema_import.go b/pkg/lightning/mydump/schema_import.go new file mode 100644 index 0000000000..32120f0ce4 --- /dev/null +++ b/pkg/lightning/mydump/schema_import.go @@ -0,0 +1,371 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mydump + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/format" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/set" + "github.com/pingcap/tidb/pkg/util/sqlescape" + "go.uber.org/zap" +) + +type schemaStmtType int + +// String implements fmt.Stringer interface. 
+func (stmtType schemaStmtType) String() string {
+	switch stmtType {
+	case schemaCreateDatabase:
+		return "restore database schema"
+	case schemaCreateTable:
+		return "restore table schema"
+	case schemaCreateView:
+		return "restore view schema"
+	}
+	return "unknown statement of schema"
+}
+
+const (
+	schemaCreateDatabase schemaStmtType = iota
+	schemaCreateTable
+	schemaCreateView
+)
+
+type schemaJob struct {
+	dbName   string
+	tblName  string // empty for create db jobs
+	stmtType schemaStmtType
+	sqlStr   string
+}
+
+// SchemaImporter is used to import schema from dump files.
+type SchemaImporter struct {
+	logger      log.Logger
+	db          *sql.DB
+	sqlMode     mysql.SQLMode
+	store       storage.ExternalStorage
+	concurrency int
+}
+
+// NewSchemaImporter creates a new SchemaImporter instance.
+func NewSchemaImporter(logger log.Logger, sqlMode mysql.SQLMode, db *sql.DB, store storage.ExternalStorage, concurrency int) *SchemaImporter {
+	return &SchemaImporter{
+		logger:      logger,
+		db:          db,
+		sqlMode:     sqlMode,
+		store:       store,
+		concurrency: concurrency,
+	}
+}
+
+// Run imports all schemas from the given database metas.
+func (si *SchemaImporter) Run(ctx context.Context, dbMetas []*MDDatabaseMeta) (err error) {
+	logTask := si.logger.Begin(zap.InfoLevel, "restore all schema")
+	defer func() {
+		logTask.End(zap.ErrorLevel, err)
+	}()
+
+	if len(dbMetas) == 0 {
+		return nil
+	}
+
+	if err = si.importDatabases(ctx, dbMetas); err != nil {
+		return errors.Trace(err)
+	}
+	if err = si.importTables(ctx, dbMetas); err != nil {
+		return errors.Trace(err)
+	}
+	return errors.Trace(si.importViews(ctx, dbMetas))
+}
+
+func (si *SchemaImporter) importDatabases(ctx context.Context, dbMetas []*MDDatabaseMeta) error {
+	existingSchemas, err := si.getExistingDatabases(ctx)
+	if err != nil {
+		return err
+	}
+
+	ch := make(chan *MDDatabaseMeta)
+	eg, egCtx := util.NewErrorGroupWithRecoverWithCtx(ctx)
+	for i := 0; i < si.concurrency; i++ {
+		eg.Go(func() error {
+			p := parser.New()
+			p.SetSQLMode(si.sqlMode)
+			for dbMeta := range ch {
+				sqlStr := dbMeta.GetSchema(egCtx, si.store)
+				if err2 := si.runJob(egCtx, p, &schemaJob{
+					dbName:   dbMeta.Name,
+					stmtType: schemaCreateDatabase,
+					sqlStr:   sqlStr,
+				}); err2 != nil {
+					return err2
+				}
+			}
+			return nil
+		})
+	}
+	eg.Go(func() error {
+		defer close(ch)
+		for i := range dbMetas {
+			dbMeta := dbMetas[i]
+			// if downstream already has this database, we can skip the DDL job
+			if existingSchemas.Exist(strings.ToLower(dbMeta.Name)) {
+				si.logger.Info("database already exists in downstream, skip",
+					zap.String("db", dbMeta.Name),
+				)
+				continue
+			}
+			select {
+			case ch <- dbMeta:
+			case <-egCtx.Done():
+				return egCtx.Err()
+			}
+		}
+		return nil
+	})
+
+	return eg.Wait()
+}
+
+func (si *SchemaImporter) importTables(ctx context.Context, dbMetas []*MDDatabaseMeta) error {
+	ch := make(chan *MDTableMeta)
+	eg, egCtx := util.NewErrorGroupWithRecoverWithCtx(ctx)
+	for i := 0; i < si.concurrency; i++ {
+		eg.Go(func() error {
+			p := parser.New()
+			p.SetSQLMode(si.sqlMode)
+			for tableMeta := range ch {
+				sqlStr, err := tableMeta.GetSchema(egCtx, si.store)
+				if err != nil {
+					return err
+				}
+				if err = si.runJob(egCtx, p, &schemaJob{
+					dbName:   tableMeta.DB,
+					tblName:  tableMeta.Name,
+					stmtType: schemaCreateTable,
+					sqlStr:   sqlStr,
+				}); err != nil {
+					return err
+				}
+			}
+			return nil
+		})
+	}
+	eg.Go(func() error {
+		defer close(ch)
+		for _, dbMeta := range dbMetas {
+			if len(dbMeta.Tables) == 0 {
+				continue
+			}
+			tables, err := si.getExistingTables(egCtx, dbMeta.Name)
+			if err != nil {
+				return err
+			}
+			for i := range dbMeta.Tables {
+				tblMeta := dbMeta.Tables[i]
+				if tables.Exist(strings.ToLower(tblMeta.Name)) {
+					// we already have this table in TiDB.
+					// skip the DDL job and let the SchemaValid check handle it.
+					si.logger.Info("table already exists in downstream, skip",
+						zap.String("db", dbMeta.Name),
+						zap.String("table", tblMeta.Name),
+					)
+					continue
+				} else if tblMeta.SchemaFile.FileMeta.Path == "" {
+					return common.ErrSchemaNotExists.GenWithStackByArgs(dbMeta.Name, tblMeta.Name)
+				}
+
+				select {
+				case ch <- tblMeta:
+				case <-egCtx.Done():
+					return egCtx.Err()
+				}
+			}
+		}
+		return nil
+	})
+
+	return eg.Wait()
+}
+
+// dumpling dumps a view as a table-schema SQL file that creates a table of the same name
+// as the view, plus a view-schema SQL file that drops that table and creates the view.
+func (si *SchemaImporter) importViews(ctx context.Context, dbMetas []*MDDatabaseMeta) error {
+	// restore views after all table schemas are restored, since views may reference tables across databases.
+	// we don't restore views concurrently, as that may raise errors.
+	p := parser.New()
+	p.SetSQLMode(si.sqlMode)
+	for _, dbMeta := range dbMetas {
+		if len(dbMeta.Views) == 0 {
+			continue
+		}
+		existingViews, err := si.getExistingViews(ctx, dbMeta.Name)
+		if err != nil {
+			return err
+		}
+		for _, viewMeta := range dbMeta.Views {
+			if existingViews.Exist(strings.ToLower(viewMeta.Name)) {
+				si.logger.Info("view already exists in downstream, skip",
+					zap.String("db", dbMeta.Name),
+					zap.String("view-name", viewMeta.Name))
+				continue
+			}
+			sqlStr, err := viewMeta.GetSchema(ctx, si.store)
+			if err != nil {
+				return err
+			}
+			if strings.TrimSpace(sqlStr) == "" {
+				si.logger.Info("view schema is empty, skip",
+					zap.String("db", dbMeta.Name),
+					zap.String("view-name", viewMeta.Name))
+				continue
+			}
+			if err = si.runJob(ctx, p, &schemaJob{
+				dbName:   dbMeta.Name,
+				tblName:  viewMeta.Name,
+				stmtType: schemaCreateView,
+				sqlStr:   sqlStr,
+			}); err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+func (si *SchemaImporter) runJob(ctx context.Context, p *parser.Parser, job *schemaJob) error {
+	stmts, err := createIfNotExistsStmt(p, job.sqlStr, job.dbName, job.tblName)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	conn, err := si.db.Conn(ctx)
+	if err != nil {
+		return err
+	}
+	defer func() {
+		_ = conn.Close()
+	}()
+
+	logger := si.logger.With(zap.String("db", job.dbName), zap.String("table", job.tblName))
+	sqlWithRetry := common.SQLWithRetry{
+		Logger: logger,
+		DB:     conn,
+	}
+	for _, stmt := range stmts {
+		task := logger.Begin(zap.DebugLevel, fmt.Sprintf("execute SQL: %s", stmt))
+		err = sqlWithRetry.Exec(ctx, "run create schema job", stmt)
+		task.End(zap.ErrorLevel, err)
+
+		if err != nil {
+			return common.ErrCreateSchema.Wrap(err).GenWithStackByArgs(common.UniqueTable(job.dbName, job.tblName), job.stmtType.String())
+		}
+	}
+	return nil
+}
+
+func (si *SchemaImporter) getExistingDatabases(ctx context.Context) (set.StringSet, error) {
+	return si.getExistingSchemas(ctx, `SELECT SCHEMA_NAME FROM information_schema.SCHEMATA`)
+}
+
+// the result contains views too, but since tables and views share the same namespace, that's fine.
+func (si *SchemaImporter) getExistingTables(ctx context.Context, dbName string) (set.StringSet, error) {
+	sb := new(strings.Builder)
+	sqlescape.MustFormatSQL(sb, `SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = %?`, dbName)
+	return si.getExistingSchemas(ctx, sb.String())
+}
+
+func (si *SchemaImporter) getExistingViews(ctx context.Context, dbName string) (set.StringSet, error) {
+	sb := new(strings.Builder)
+	sqlescape.MustFormatSQL(sb, `SELECT TABLE_NAME FROM information_schema.VIEWS WHERE TABLE_SCHEMA = %?`, dbName)
+	return si.getExistingSchemas(ctx, sb.String())
+}
+
+// get existing databases/tables/views using the given query; the first column of
+// the query result should be the name.
+// The returned names are converted to lower case.
+func (si *SchemaImporter) getExistingSchemas(ctx context.Context, query string) (set.StringSet, error) {
+	conn, err := si.db.Conn(ctx)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	defer func() {
+		_ = conn.Close()
+	}()
+	sqlWithRetry := common.SQLWithRetry{
+		Logger: si.logger,
+		DB:     conn,
+	}
+	stringRows, err := sqlWithRetry.QueryStringRows(ctx, "get existing schemas", query)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	res := make(set.StringSet, len(stringRows))
+	for _, row := range stringRows {
+		res.Insert(strings.ToLower(row[0]))
+	}
+	return res, nil
+}
+
+func createIfNotExistsStmt(p *parser.Parser, createTable, dbName, tblName string) ([]string, error) {
+	stmts, _, err := p.ParseSQL(createTable)
+	if err != nil {
+		return []string{}, common.ErrInvalidSchemaStmt.Wrap(err).GenWithStackByArgs(createTable)
+	}
+
+	var res strings.Builder
+	ctx := format.NewRestoreCtx(format.DefaultRestoreFlags|format.RestoreTiDBSpecialComment|format.RestoreWithTTLEnableOff, &res)
+
+	retStmts := make([]string, 0, len(stmts))
+	for _, stmt := range stmts {
+		switch node := stmt.(type) {
+		case *ast.CreateDatabaseStmt:
+			node.Name = model.NewCIStr(dbName)
+			node.IfNotExists = true
+		case *ast.DropDatabaseStmt:
+			node.Name = model.NewCIStr(dbName)
+			node.IfExists = true
+		case *ast.CreateTableStmt:
+			node.Table.Schema = model.NewCIStr(dbName)
+			node.Table.Name = model.NewCIStr(tblName)
+			node.IfNotExists = true
+		case *ast.CreateViewStmt:
+			node.ViewName.Schema = model.NewCIStr(dbName)
+			node.ViewName.Name = model.NewCIStr(tblName)
+		case *ast.DropTableStmt:
+			node.Tables[0].Schema = model.NewCIStr(dbName)
+			node.Tables[0].Name = model.NewCIStr(tblName)
+			node.IfExists = true
+		}
+		if err := stmt.Restore(ctx); err != nil {
+			return []string{}, common.ErrInvalidSchemaStmt.Wrap(err).GenWithStackByArgs(createTable)
+		}
+		ctx.WritePlain(";")
+		retStmts = append(retStmts, res.String())
+		res.Reset()
+	}
+
+	return retStmts, nil
+}
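Note: importDatabases and importTables above share one bounded producer/consumer pattern: a single goroutine feeds metas into an unbuffered channel, `concurrency` workers drain it, and the first error cancels the shared context so both sides stop. Below is a minimal, self-contained sketch of that pattern. It substitutes the public golang.org/x/sync/errgroup for TiDB's internal util.NewErrorGroupWithRecoverWithCtx, and all names in it are illustrative, not part of this change.

```go
package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

// runJobs mirrors the SchemaImporter layout: one producer feeds jobs into an
// unbuffered channel while `concurrency` consumers drain it. The first error
// cancels egCtx, which stops the producer and unblocks any pending send.
func runJobs(ctx context.Context, jobs []string, concurrency int, run func(string) error) error {
	ch := make(chan string)
	eg, egCtx := errgroup.WithContext(ctx)
	for i := 0; i < concurrency; i++ {
		eg.Go(func() error {
			for job := range ch {
				if err := run(job); err != nil {
					return err // first error cancels egCtx
				}
			}
			return nil
		})
	}
	eg.Go(func() error {
		defer close(ch) // closing the channel lets idle consumers exit
		for _, job := range jobs {
			select {
			case ch <- job:
			case <-egCtx.Done():
				return egCtx.Err()
			}
		}
		return nil
	})
	return eg.Wait()
}

func main() {
	err := runJobs(context.Background(), []string{"db1", "db2", "db3"}, 2,
		func(job string) error { fmt.Println("create", job); return nil })
	fmt.Println("err:", err)
}
```

The unbuffered channel keeps at most `concurrency` jobs in flight, and the select on egCtx.Done() keeps the producer from blocking forever once a worker has failed.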
diff --git a/pkg/lightning/mydump/schema_import_test.go b/pkg/lightning/mydump/schema_import_test.go
new file mode 100644
index 0000000000..024aea9484
--- /dev/null
+++ b/pkg/lightning/mydump/schema_import_test.go
@@ -0,0 +1,370 @@
+// Copyright 2024 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mydump
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"path"
+	"testing"
+
+	"github.com/DATA-DOG/go-sqlmock"
+	"github.com/pingcap/errors"
+	"github.com/pingcap/tidb/br/pkg/storage"
+	"github.com/pingcap/tidb/pkg/lightning/common"
+	"github.com/pingcap/tidb/pkg/lightning/log"
+	"github.com/pingcap/tidb/pkg/parser"
+	"github.com/pingcap/tidb/pkg/parser/mysql"
+	"github.com/stretchr/testify/require"
+	"go.uber.org/zap"
+)
+
+func TestSchemaImporter(t *testing.T) {
+	db, mock, err := sqlmock.New()
+	require.NoError(t, err)
+	mock.MatchExpectationsInOrder(false)
+	t.Cleanup(func() {
+		require.NoError(t, mock.ExpectationsWereMet())
+		// have to ignore the error here, as sqlmock doesn't allow setting the number of
+		// expectations, and each opened connection requires a Close() call.
+		_ = db.Close()
+	})
+	ctx := context.Background()
+	tempDir := t.TempDir()
+	store, err := storage.NewLocalStorage(tempDir)
+	require.NoError(t, err)
+	logger := log.Logger{Logger: zap.NewExample()}
+	importer := NewSchemaImporter(logger, mysql.SQLMode(0), db, store, 4)
+	require.NoError(t, importer.Run(ctx, nil))
+
+	t.Run("get existing schema err", func(t *testing.T) {
+		mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnError(errors.New("non retryable error"))
+		require.ErrorContains(t, importer.Run(ctx, []*MDDatabaseMeta{{Name: "test"}}), "non retryable error")
+		require.NoError(t, mock.ExpectationsWereMet())
+	})
+
+	t.Run("database already exists", func(t *testing.T) {
+		mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
+			sqlmock.NewRows([]string{"SCHEMA_NAME"}).AddRow("test"))
+		require.NoError(t, importer.Run(ctx, []*MDDatabaseMeta{{Name: "test"}}))
+		require.NoError(t, mock.ExpectationsWereMet())
+	})
+
+	t.Run("create non exist database", func(t *testing.T) {
+		mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
+			sqlmock.NewRows([]string{"SCHEMA_NAME"}))
+		dbMetas := make([]*MDDatabaseMeta, 0, 10)
+		for i := 0; i < 10; i++ {
+			mock.ExpectExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `test%02d`", i)).
+				WillReturnResult(sqlmock.NewResult(0, 0))
+			dbMetas = append(dbMetas, &MDDatabaseMeta{Name: fmt.Sprintf("test%02d", i)})
+		}
+		require.NoError(t, importer.Run(ctx, dbMetas))
+		require.NoError(t, mock.ExpectationsWereMet())
+	})
+
+	t.Run("break on database error", func(t *testing.T) {
+		importer2 := NewSchemaImporter(logger, mysql.SQLMode(0), db, store, 1)
+		mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
+			sqlmock.NewRows([]string{"SCHEMA_NAME"}))
+		fileName := "invalid-schema.sql"
+		require.NoError(t, os.WriteFile(path.Join(tempDir, fileName), []byte("CREATE invalid;"), 0o644))
+		dbMetas := []*MDDatabaseMeta{
+			{Name: "test", charSet: "auto", SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileName}}},
+			{Name: "test2"}, // no chance to run
+		}
+		require.ErrorContains(t, importer2.Run(ctx, dbMetas), "invalid schema statement")
+		require.NoError(t, mock.ExpectationsWereMet())
+		require.NoError(t, os.Remove(path.Join(tempDir, fileName)))
+
+		dbMetas = append([]*MDDatabaseMeta{{Name: "ttt"}}, dbMetas...)
+		mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
+			sqlmock.NewRows([]string{"SCHEMA_NAME"}))
+		mock.ExpectExec("CREATE DATABASE IF NOT EXISTS `ttt`").
+			WillReturnError(errors.New("non retryable error"))
+		err2 := importer2.Run(ctx, dbMetas)
+		require.ErrorIs(t, err2, common.ErrCreateSchema)
+		require.ErrorContains(t, err2, "non retryable error")
+		require.NoError(t, mock.ExpectationsWereMet())
+	})
+
+	t.Run("table: get existing schema err", func(t *testing.T) {
+		mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
+			sqlmock.NewRows([]string{"SCHEMA_NAME"}).
+				AddRow("test01").AddRow("test02").AddRow("test03").
+				AddRow("test04").AddRow("test05"))
+		mock.ExpectQuery("TABLES WHERE TABLE_SCHEMA = 'test02'").
+			WillReturnError(errors.New("non retryable error"))
+		dbMetas := []*MDDatabaseMeta{
+			{Name: "test01"},
+			{Name: "test02", Tables: []*MDTableMeta{{DB: "test02", Name: "t"}}},
+		}
+		require.ErrorContains(t, importer.Run(ctx, dbMetas), "non retryable error")
+		require.NoError(t, mock.ExpectationsWereMet())
+	})
+
+	t.Run("table: invalid schema file", func(t *testing.T) {
+		mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
+			sqlmock.NewRows([]string{"SCHEMA_NAME"}).
+				AddRow("test01").AddRow("test02").AddRow("test03").
+				AddRow("test04").AddRow("test05"))
+		mock.ExpectQuery("TABLES WHERE TABLE_SCHEMA = 'test01'").
+			WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}).AddRow("t1"))
+		fileName := "t2-invalid-schema.sql"
+		require.NoError(t, os.WriteFile(path.Join(tempDir, fileName), []byte("CREATE table t2 whatever;"), 0o644))
+		dbMetas := []*MDDatabaseMeta{
+			{Name: "test01", Tables: []*MDTableMeta{
+				{DB: "test01", Name: "t1"},
+				{DB: "test01", Name: "T2", charSet: "auto",
+					SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileName}}},
+			}},
+		}
+		require.ErrorContains(t, importer.Run(ctx, dbMetas), "line 1 column 24 near")
+		require.NoError(t, mock.ExpectationsWereMet())
+
+		// create table t2 downstream manually as a workaround
+		mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
+			sqlmock.NewRows([]string{"SCHEMA_NAME"}).
+				AddRow("test01").AddRow("test02").AddRow("test03").
+				AddRow("test04").AddRow("test05"))
+		mock.ExpectQuery("TABLES WHERE TABLE_SCHEMA = 'test01'").
+			WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}).AddRow("t1").AddRow("t2"))
+		require.NoError(t, importer.Run(ctx, dbMetas))
+		require.NoError(t, mock.ExpectationsWereMet())
+		require.NoError(t, os.Remove(path.Join(tempDir, fileName)))
+	})
+
+	t.Run("table: break on error", func(t *testing.T) {
+		importer2 := NewSchemaImporter(logger, mysql.SQLMode(0), db, store, 1)
+		mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
+			sqlmock.NewRows([]string{"SCHEMA_NAME"}).
+				AddRow("test01").AddRow("test02").AddRow("test03").
+				AddRow("test04").AddRow("test05"))
+		fileNameT1 := "test01.t1-schema.sql"
+		fileNameT2 := "test01.t2-schema.sql"
+		require.NoError(t, os.WriteFile(path.Join(tempDir, fileNameT1), []byte("CREATE table t1(a int);"), 0o644))
+		require.NoError(t, os.WriteFile(path.Join(tempDir, fileNameT2), []byte("CREATE table t2(a int);"), 0o644))
+		dbMetas := []*MDDatabaseMeta{
+			{Name: "test01", Tables: []*MDTableMeta{
+				{DB: "test01", Name: "t1", charSet: "auto", SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileNameT1}}},
+				{DB: "test01", Name: "t2", charSet: "auto", SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileNameT2}}},
+			}},
+		}
+		mock.ExpectQuery("TABLES WHERE TABLE_SCHEMA = 'test01'").
+			WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}))
+		mock.ExpectExec("CREATE TABLE IF NOT EXISTS `test01`.`t1`").
+			WillReturnError(errors.New("non retryable create table error"))
+		require.ErrorContains(t, importer2.Run(ctx, dbMetas), "non retryable create table error")
+		require.NoError(t, mock.ExpectationsWereMet())
+		require.NoError(t, os.Remove(path.Join(tempDir, fileNameT1)))
+		require.NoError(t, os.Remove(path.Join(tempDir, fileNameT2)))
+	})
+
+	t.Run("view: get existing schema err", func(t *testing.T) {
+		mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
+			sqlmock.NewRows([]string{"SCHEMA_NAME"}).AddRow("test01").AddRow("test02"))
+		mock.ExpectQuery("VIEWS WHERE TABLE_SCHEMA = 'test02'").
+			WillReturnError(errors.New("non retryable error"))
+		dbMetas := []*MDDatabaseMeta{
+			{Name: "test01"},
+			{Name: "test02", Views: []*MDTableMeta{{DB: "test02", Name: "v"}}},
+		}
+		require.ErrorContains(t, importer.Run(ctx, dbMetas), "non retryable error")
+		require.NoError(t, mock.ExpectationsWereMet())
+	})
+
+	t.Run("view: fail on create", func(t *testing.T) {
+		mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
+			sqlmock.NewRows([]string{"SCHEMA_NAME"}).AddRow("test01").AddRow("test02"))
+		mock.ExpectQuery("VIEWS WHERE TABLE_SCHEMA = 'test02'").
+			WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}))
+		fileNameV0 := "empty-file.sql"
+		fileNameV1 := "invalid-schema.sql"
+		fileNameV2 := "test02.v2-schema-view.sql"
+		require.NoError(t, os.WriteFile(path.Join(tempDir, fileNameV0), []byte(""), 0o644))
+		require.NoError(t, os.WriteFile(path.Join(tempDir, fileNameV1), []byte("xxxx;"), 0o644))
+		require.NoError(t, os.WriteFile(path.Join(tempDir, fileNameV2), []byte("create view v2 as select * from t;"), 0o644))
+		dbMetas := []*MDDatabaseMeta{
+			{Name: "test01"},
+			{Name: "test02", Views: []*MDTableMeta{
+				{DB: "test02", Name: "V0", charSet: "auto", SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileNameV0}}},
+				{DB: "test02", Name: "v1", charSet: "auto", SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileNameV1}}},
+				{DB: "test02", Name: "V2", charSet: "auto", SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileNameV2}}}}},
+		}
+		require.ErrorContains(t, importer.Run(ctx, dbMetas), `line 1 column 4 near "xxxx;"`)
+		require.NoError(t, mock.ExpectationsWereMet())
+
+		// create view v2 downstream manually as a workaround
+		mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows(
+			sqlmock.NewRows([]string{"SCHEMA_NAME"}).AddRow("test01").AddRow("test02"))
+		mock.ExpectQuery("VIEWS WHERE TABLE_SCHEMA = 'test02'").
+			WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}).AddRow("V1"))
+		mock.ExpectExec("VIEW `test02`.`V2` AS SELECT").
+			WillReturnResult(sqlmock.NewResult(0, 0))
+		require.NoError(t, importer.Run(ctx, dbMetas))
+		require.NoError(t, mock.ExpectationsWereMet())
+		require.NoError(t, os.Remove(path.Join(tempDir, fileNameV0)))
+		require.NoError(t, os.Remove(path.Join(tempDir, fileNameV1)))
+		require.NoError(t, os.Remove(path.Join(tempDir, fileNameV2)))
+	})
+}
+
+func TestSchemaImporterManyTables(t *testing.T) {
+	db, mock, err := sqlmock.New()
+	require.NoError(t, err)
+	mock.MatchExpectationsInOrder(false)
+	t.Cleanup(func() {
+		require.NoError(t, mock.ExpectationsWereMet())
+		// have to ignore the error here, as sqlmock doesn't allow setting the number of
+		// expectations, and each opened connection requires a Close() call.
+ _ = db.Close() + }) + ctx := context.Background() + tempDir := t.TempDir() + store, err := storage.NewLocalStorage(tempDir) + require.NoError(t, err) + logger := log.Logger{Logger: zap.NewExample()} + importer := NewSchemaImporter(logger, mysql.SQLMode(0), db, store, 8) + + mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows( + sqlmock.NewRows([]string{"SCHEMA_NAME"})) + dbMetas := make([]*MDDatabaseMeta, 0, 30) + for i := 0; i < 30; i++ { + dbName := fmt.Sprintf("test%02d", i) + dbMeta := &MDDatabaseMeta{Name: dbName, Tables: make([]*MDTableMeta, 0, 100)} + mock.ExpectExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", dbName)). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectQuery(fmt.Sprintf("TABLES WHERE TABLE_SCHEMA = '%s'", dbName)). + WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"})) + for j := 0; j < 50; j++ { + tblName := fmt.Sprintf("t%03d", j) + fileName := fmt.Sprintf("%s.%s-schema.sql", dbName, tblName) + require.NoError(t, os.WriteFile(path.Join(tempDir, fileName), []byte(fmt.Sprintf("CREATE TABLE %s(a int);", tblName)), 0o644)) + mock.ExpectExec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s`.`%s`", dbName, tblName)). + WillReturnResult(sqlmock.NewResult(0, 0)) + dbMeta.Tables = append(dbMeta.Tables, &MDTableMeta{ + DB: dbName, Name: tblName, charSet: "auto", + SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileName}}, + }) + } + dbMetas = append(dbMetas, dbMeta) + } + require.NoError(t, importer.Run(ctx, dbMetas)) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestCreateTableIfNotExistsStmt(t *testing.T) { + dbName := "testdb" + p := parser.New() + createSQLIfNotExistsStmt := func(createTable, tableName string) []string { + res, err := createIfNotExistsStmt(p, createTable, dbName, tableName) + require.NoError(t, err) + return res + } + + require.Equal(t, []string{"CREATE DATABASE IF NOT EXISTS `testdb` CHARACTER SET = utf8 COLLATE = utf8_general_ci;"}, + createSQLIfNotExistsStmt("CREATE DATABASE `foo` CHARACTER SET = utf8 COLLATE = utf8_general_ci;", "")) + + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` TINYINT(1));"}, + createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` TINYINT(1));", "foo")) + + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` TINYINT(1));"}, + createSQLIfNotExistsStmt("CREATE TABLE IF NOT EXISTS `foo`(`bar` TINYINT(1));", "foo")) + + // case insensitive + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`fOo` (`bar` TINYINT(1));"}, + createSQLIfNotExistsStmt("/* cOmmEnt */ creAte tablE `fOo`(`bar` TinyinT(1));", "fOo")) + + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`FoO` (`bAR` TINYINT(1));"}, + createSQLIfNotExistsStmt("/* coMMenT */ crEatE tAble If not EXISts `FoO`(`bAR` tiNyInT(1));", "FoO")) + + // only one "CREATE TABLE" is replaced + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE');"}, + createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE');", "foo")) + + // test clustered index consistency + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![clustered_index] CLUSTERED */ COMMENT 'CREATE TABLE');"}, + createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY CLUSTERED COMMENT 'CREATE TABLE');", "foo")) + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE',PRIMARY KEY(`bar`) /*T![clustered_index] NONCLUSTERED 
*/);"}, + createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE', PRIMARY KEY (`bar`) NONCLUSTERED);", "foo")) + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![clustered_index] NONCLUSTERED */ COMMENT 'CREATE TABLE');"}, + createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY /*T![clustered_index] NONCLUSTERED */ COMMENT 'CREATE TABLE');", "foo")) + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE',PRIMARY KEY(`bar`) /*T![clustered_index] CLUSTERED */);"}, + createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE', PRIMARY KEY (`bar`) /*T![clustered_index] CLUSTERED */);", "foo")) + + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![auto_rand] AUTO_RANDOM(2) */ COMMENT 'CREATE TABLE');"}, + createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY AUTO_RANDOM(2) COMMENT 'CREATE TABLE');", "foo")) + + // upper case becomes shorter + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ſ` (`ı` TINYINT(1));"}, + createSQLIfNotExistsStmt("CREATE TABLE `ſ`(`ı` TINYINT(1));", "ſ")) + + // upper case becomes longer + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ɑ` (`ȿ` TINYINT(1));"}, + createSQLIfNotExistsStmt("CREATE TABLE `ɑ`(`ȿ` TINYINT(1));", "ɑ")) + + // non-utf-8 + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`\xcc\xcc\xcc` (`???` TINYINT(1));"}, + createSQLIfNotExistsStmt("CREATE TABLE `\xcc\xcc\xcc`(`\xdd\xdd\xdd` TINYINT(1));", "\xcc\xcc\xcc")) + + // renaming a table + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ba``r` (`x` INT);"}, + createSQLIfNotExistsStmt("create table foo(x int);", "ba`r")) + + // conditional comments + require.Equal(t, []string{ + "SET NAMES 'binary';", + "SET @@SESSION.`FOREIGN_KEY_CHECKS`=0;", + "CREATE TABLE IF NOT EXISTS `testdb`.`m` (`z` DOUBLE) ENGINE = InnoDB AUTO_INCREMENT = 8343230 DEFAULT CHARACTER SET = UTF8;", + }, + createSQLIfNotExistsStmt(` + /*!40101 SET NAMES binary*/; + /*!40014 SET FOREIGN_KEY_CHECKS=0*/; + CREATE TABLE x.y (z double) ENGINE=InnoDB AUTO_INCREMENT=8343230 DEFAULT CHARSET=utf8; + `, "m")) + + // create view + require.Equal(t, []string{ + "SET NAMES 'binary';", + "DROP TABLE IF EXISTS `testdb`.`m`;", + "DROP VIEW IF EXISTS `testdb`.`m`;", + "SET @`PREV_CHARACTER_SET_CLIENT`=@@`character_set_client`;", + "SET @`PREV_CHARACTER_SET_RESULTS`=@@`character_set_results`;", + "SET @`PREV_COLLATION_CONNECTION`=@@`collation_connection`;", + "SET @@SESSION.`character_set_client`=`utf8`;", + "SET @@SESSION.`character_set_results`=`utf8`;", + "SET @@SESSION.`collation_connection`=`utf8_general_ci`;", + "CREATE ALGORITHM = UNDEFINED DEFINER = `root`@`192.168.198.178` SQL SECURITY DEFINER VIEW `testdb`.`m` (`s`) AS SELECT `s` FROM `db1`.`v1` WHERE `i`<2;", + "SET @@SESSION.`character_set_client`=@`PREV_CHARACTER_SET_CLIENT`;", + "SET @@SESSION.`character_set_results`=@`PREV_CHARACTER_SET_RESULTS`;", + "SET @@SESSION.`collation_connection`=@`PREV_COLLATION_CONNECTION`;", + }, + createSQLIfNotExistsStmt(` + /*!40101 SET NAMES binary*/; + DROP TABLE IF EXISTS v2; + DROP VIEW IF EXISTS v2; + SET @PREV_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT; + SET @PREV_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS; + SET @PREV_COLLATION_CONNECTION=@@COLLATION_CONNECTION; + SET character_set_client = utf8; + SET character_set_results = utf8; + SET 
collation_connection = utf8_general_ci; + CREATE ALGORITHM=UNDEFINED DEFINER=root@192.168.198.178 SQL SECURITY DEFINER VIEW v2 (s) AS SELECT s FROM db1.v1 WHERE i<2; + SET character_set_client = @PREV_CHARACTER_SET_CLIENT; + SET character_set_results = @PREV_CHARACTER_SET_RESULTS; + SET collation_connection = @PREV_COLLATION_CONNECTION; + `, "m")) +}
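
Note on the restoreSchema hunk in lightning/pkg/importer/import.go: database/sql keeps only MaxIdleConns (default 2) connections alive in its pool, so with `concurrency` workers each doing db.Conn/conn.Close, most connections would be torn down and re-dialed. Raising the limit for the duration of the schema restore and restoring the default afterwards avoids that churn. A minimal sketch of the idea; the driver import and DSN are placeholders, not part of this change:

```go
package main

import (
	"database/sql"

	_ "github.com/go-sql-driver/mysql" // assumed driver, for illustration only
)

// withIdleConns temporarily raises the idle-connection limit so that n
// concurrent workers can return connections to the pool and reuse them,
// then restores the stdlib default of 2. database/sql exposes no getter
// for the current MaxIdleConns, hence the hard-coded default.
func withIdleConns(db *sql.DB, n int, fn func() error) error {
	db.SetMaxIdleConns(n + 1) // +1 for the job-generator goroutine
	defer db.SetMaxIdleConns(2)
	return fn()
}

func main() {
	db, err := sql.Open("mysql", "user:pass@tcp(127.0.0.1:4000)/test") // placeholder DSN
	if err != nil {
		panic(err)
	}
	defer db.Close()
	_ = withIdleConns(db, 8, func() error {
		// schema-restore work would run here
		return nil
	})
}
```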
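Note on the existence checks in schema_import.go: getExistingTables and getExistingViews build their information_schema queries with sqlescape.MustFormatSQL, whose %? placeholder inserts the argument as an escaped, quoted SQL literal rather than relying on a prepared statement. A tiny sketch of that helper as used above; the sample value is made up:

```go
package main

import (
	"fmt"
	"strings"

	"github.com/pingcap/tidb/pkg/util/sqlescape"
)

func main() {
	// %? quotes and escapes the value, so a hostile schema name cannot
	// break out of the string literal in the comparison.
	sb := new(strings.Builder)
	sqlescape.MustFormatSQL(sb,
		`SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = %?`,
		"weird'db") // illustrative value only
	fmt.Println(sb.String())
}
```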