diff --git a/Makefile b/Makefile index cc9ac5ac1e..a01a96db41 100644 --- a/Makefile +++ b/Makefile @@ -546,6 +546,7 @@ mock_import: mockgen tools/bin/mockgen -package mock github.com/pingcap/tidb/pkg/disttask/framework/planner LogicalPlan,PipelineSpec > pkg/disttask/framework/mock/plan_mock.go tools/bin/mockgen -package mock github.com/pingcap/tidb/pkg/disttask/framework/storage Manager > pkg/disttask/framework/mock/storage_manager_mock.go tools/bin/mockgen -package mock github.com/pingcap/tidb/pkg/ingestor/ingestcli Client,WriteClient > pkg/ingestor/ingestcli/mock/client_mock.go + tools/bin/mockgen -package mock github.com/pingcap/tidb/pkg/importsdk FileScanner,JobManager,SQLGenerator,SDK > pkg/importsdk/mock/sdk_mock.go .PHONY: gen_mock gen_mock: mockgen diff --git a/pkg/importsdk/BUILD.bazel b/pkg/importsdk/BUILD.bazel index fa6759a8d5..dc570d1d36 100644 --- a/pkg/importsdk/BUILD.bazel +++ b/pkg/importsdk/BUILD.bazel @@ -2,26 +2,45 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "importsdk", - srcs = ["sdk.go"], + srcs = [ + "config.go", + "error.go", + "file_scanner.go", + "job_manager.go", + "model.go", + "pattern.go", + "sdk.go", + "sql_generator.go", + ], importpath = "github.com/pingcap/tidb/pkg/importsdk", visibility = ["//visibility:public"], deps = [ "//br/pkg/storage", + "//pkg/lightning/common", "//pkg/lightning/config", "//pkg/lightning/log", "//pkg/lightning/mydump", "//pkg/parser/mysql", "@com_github_pingcap_errors//:errors", + "@org_uber_go_zap//:zap", ], ) go_test( name = "importsdk_test", timeout = "short", - srcs = ["sdk_test.go"], + srcs = [ + "config_test.go", + "file_scanner_test.go", + "job_manager_test.go", + "model_test.go", + "pattern_test.go", + "sdk_test.go", + "sql_generator_test.go", + ], embed = [":importsdk"], flaky = True, - shard_count = 9, + shard_count = 20, deps = [ "//pkg/lightning/config", "//pkg/lightning/log", diff --git a/pkg/importsdk/config.go b/pkg/importsdk/config.go new file mode 100644 index 0000000000..4235eccb7b --- /dev/null +++ b/pkg/importsdk/config.go @@ -0,0 +1,118 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importsdk + +import ( + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/parser/mysql" +) + +// SDKOption customizes the SDK configuration +type SDKOption func(*SDKConfig) + +// SDKConfig is the configuration for the SDK +type SDKConfig struct { + // Loader options + concurrency int + sqlMode mysql.SQLMode + fileRouteRules []*config.FileRouteRule + routes config.Routes + filter []string + charset string + maxScanFiles *int + skipInvalidFiles bool + + // General options + logger log.Logger +} + +func defaultSDKConfig() *SDKConfig { + return &SDKConfig{ + concurrency: 4, + filter: config.GetDefaultFilter(), + logger: log.L(), + charset: "auto", + } +} + +// WithConcurrency sets the number of concurrent DB/Table creation workers. 
+func WithConcurrency(n int) SDKOption { + return func(cfg *SDKConfig) { + if n > 0 { + cfg.concurrency = n + } + } +} + +// WithLogger specifies a custom logger +func WithLogger(logger log.Logger) SDKOption { + return func(cfg *SDKConfig) { + cfg.logger = logger + } +} + +// WithSQLMode specifies the SQL mode for schema parsing +func WithSQLMode(mode mysql.SQLMode) SDKOption { + return func(cfg *SDKConfig) { + cfg.sqlMode = mode + } +} + +// WithFilter specifies a filter for the loader +func WithFilter(filter []string) SDKOption { + return func(cfg *SDKConfig) { + cfg.filter = filter + } +} + +// WithFileRouters sets the file routing rules. +func WithFileRouters(rules []*config.FileRouteRule) SDKOption { + return func(c *SDKConfig) { + c.fileRouteRules = rules + } +} + +// WithRoutes sets the table routing rules. +func WithRoutes(routes config.Routes) SDKOption { + return func(c *SDKConfig) { + c.routes = routes + } +} + +// WithCharset specifies the character set for import (default "auto"). +func WithCharset(cs string) SDKOption { + return func(cfg *SDKConfig) { + if cs != "" { + cfg.charset = cs + } + } +} + +// WithMaxScanFiles specifies custom file scan limitation +func WithMaxScanFiles(limit int) SDKOption { + return func(cfg *SDKConfig) { + if limit > 0 { + cfg.maxScanFiles = &limit + } + } +} + +// WithSkipInvalidFiles specifies whether sdk need raise error on found invalid files +func WithSkipInvalidFiles(skip bool) SDKOption { + return func(cfg *SDKConfig) { + cfg.skipInvalidFiles = skip + } +} diff --git a/pkg/importsdk/config_test.go b/pkg/importsdk/config_test.go new file mode 100644 index 0000000000..7a628912da --- /dev/null +++ b/pkg/importsdk/config_test.go @@ -0,0 +1,89 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
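The options above follow the standard functional-options pattern. A minimal sketch of how they compose; the `applyOptions` helper and `exampleCfg` are illustrative only, since the exported constructor wiring lives in sdk.go, which is not part of this hunk:

```go
package importsdk

// applyOptions is an illustrative helper (not part of this diff) showing how
// the functional options are expected to be folded into a config.
func applyOptions(opts ...SDKOption) *SDKConfig {
	cfg := defaultSDKConfig() // concurrency=4, charset="auto", default filter and logger
	for _, opt := range opts {
		opt(cfg)
	}
	return cfg
}

// Example composition: more workers, explicit charset, tolerate unresolvable files.
var exampleCfg = applyOptions(
	WithConcurrency(8),
	WithCharset("utf8mb4"),
	WithSkipInvalidFiles(true),
)
```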
+ +package importsdk + +import ( + "testing" + + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/stretchr/testify/require" +) + +func TestDefaultSDKConfig(t *testing.T) { + cfg := defaultSDKConfig() + require.Equal(t, 4, cfg.concurrency) + require.Equal(t, config.GetDefaultFilter(), cfg.filter) + require.Equal(t, log.L(), cfg.logger) + require.Equal(t, "auto", cfg.charset) +} + +func TestSDKOptions(t *testing.T) { + cfg := defaultSDKConfig() + + // Test WithConcurrency + WithConcurrency(10)(cfg) + require.Equal(t, 10, cfg.concurrency) + WithConcurrency(-1)(cfg) // Should ignore invalid value + require.Equal(t, 10, cfg.concurrency) + + // Test WithLogger + logger := log.L() + WithLogger(logger)(cfg) + require.Equal(t, logger, cfg.logger) + + // Test WithSQLMode + mode := mysql.ModeStrictTransTables + WithSQLMode(mode)(cfg) + require.Equal(t, mode, cfg.sqlMode) + + // Test WithFilter + filter := []string{"*.*"} + WithFilter(filter)(cfg) + require.Equal(t, filter, cfg.filter) + + // Test WithFileRouters + routers := []*config.FileRouteRule{{Schema: "test"}} + WithFileRouters(routers)(cfg) + require.Equal(t, routers, cfg.fileRouteRules) + + // Test WithCharset + WithCharset("utf8mb4")(cfg) + require.Equal(t, "utf8mb4", cfg.charset) + WithCharset("")(cfg) // Should ignore empty value + require.Equal(t, "utf8mb4", cfg.charset) + + // Test WithMaxScanFiles + WithMaxScanFiles(100)(cfg) + require.NotNil(t, cfg.maxScanFiles) + require.Equal(t, 100, *cfg.maxScanFiles) + + // Test WithSkipInvalidFiles + WithSkipInvalidFiles(true)(cfg) + require.True(t, cfg.skipInvalidFiles) + + // Test WithRoutes + routes := config.Routes{ + { + SchemaPattern: "source_db", + TablePattern: "source_table", + TargetSchema: "target_db", + TargetTable: "target_table", + }, + } + WithRoutes(routes)(cfg) + require.Equal(t, routes, cfg.routes) +} diff --git a/pkg/importsdk/error.go b/pkg/importsdk/error.go new file mode 100644 index 0000000000..eb88c1bd90 --- /dev/null +++ b/pkg/importsdk/error.go @@ -0,0 +1,46 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importsdk + +import "github.com/pingcap/errors" + +var ( + // ErrNoDatabasesFound indicates that the dump source contains no recognizable databases. + ErrNoDatabasesFound = errors.New("no databases found in the source path") + // ErrSchemaNotFound indicates the target schema doesn't exist in the dump source. + ErrSchemaNotFound = errors.New("schema not found") + // ErrTableNotFound indicates the target table doesn't exist in the dump source. + ErrTableNotFound = errors.New("table not found") + // ErrNoTableDataFiles indicates a table has zero data files and thus cannot proceed. + ErrNoTableDataFiles = errors.New("no data files for table") + // ErrWildcardNotSpecific indicates a wildcard cannot uniquely match the table's files. 
+ ErrWildcardNotSpecific = errors.New("cannot generate a unique wildcard pattern for the table's data files") + // ErrJobNotFound indicates the job is not found. + ErrJobNotFound = errors.New("job not found") + // ErrNoJobIDReturned indicates that the submit job query did not return a job ID. + ErrNoJobIDReturned = errors.New("no job id returned") + // ErrInvalidOptions indicates the options are invalid. + ErrInvalidOptions = errors.New("invalid options") + // ErrMultipleFieldsDefinedNullBy indicates that multiple FIELDS_DEFINED_NULL_BY values are defined, which is not supported. + ErrMultipleFieldsDefinedNullBy = errors.New("IMPORT INTO only supports one FIELDS_DEFINED_NULL_BY value") + // ErrParseStorageURL indicates that the storage backend URL is invalid. + ErrParseStorageURL = errors.New("failed to parse storage backend URL") + // ErrCreateExternalStorage indicates that the external storage cannot be created. + ErrCreateExternalStorage = errors.New("failed to create external storage") + // ErrCreateLoader indicates that the MyDump loader cannot be created. + ErrCreateLoader = errors.New("failed to create MyDump loader") + // ErrCreateSchema indicates that creating schemas and tables failed. + ErrCreateSchema = errors.New("failed to create schemas and tables") +) diff --git a/pkg/importsdk/file_scanner.go b/pkg/importsdk/file_scanner.go new file mode 100644 index 0000000000..14b42dab8a --- /dev/null +++ b/pkg/importsdk/file_scanner.go @@ -0,0 +1,271 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
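The SDK wraps these sentinels with `errors.Annotatef` (as the file_scanner code below does), which preserves the cause, so callers should match on the cause rather than on the annotated error. A small sketch using the same `pingcap/errors` helpers the package itself relies on; `isTableMissing` is an illustrative name, not part of the diff:

```go
package importsdk

import "github.com/pingcap/errors"

// isTableMissing reports whether err (possibly annotated with extra context)
// originates from ErrTableNotFound. Annotatef keeps the original cause, so
// both ErrorEqual and Cause work here.
func isTableMissing(err error) bool {
	return errors.ErrorEqual(err, ErrTableNotFound) || errors.Cause(err) == ErrTableNotFound
}
```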
+ +package importsdk + +import ( + "context" + "database/sql" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/lightning/mydump" + "go.uber.org/zap" +) + +// FileScanner defines the interface for scanning files +type FileScanner interface { + CreateSchemasAndTables(ctx context.Context) error + CreateSchemaAndTableByName(ctx context.Context, schema, table string) error + GetTableMetas(ctx context.Context) ([]*TableMeta, error) + GetTableMetaByName(ctx context.Context, db, table string) (*TableMeta, error) + GetTotalSize(ctx context.Context) int64 + Close() error +} + +type fileScanner struct { + sourcePath string + db *sql.DB + store storage.ExternalStorage + loader *mydump.MDLoader + logger log.Logger + config *SDKConfig +} + +// NewFileScanner creates a new FileScanner +func NewFileScanner(ctx context.Context, sourcePath string, db *sql.DB, cfg *SDKConfig) (FileScanner, error) { + u, err := storage.ParseBackend(sourcePath, nil) + if err != nil { + return nil, errors.Annotatef(ErrParseStorageURL, "source=%s, err=%v", sourcePath, err) + } + store, err := storage.New(ctx, u, &storage.ExternalStorageOptions{}) + if err != nil { + return nil, errors.Annotatef(ErrCreateExternalStorage, "source=%s, err=%v", sourcePath, err) + } + + ldrCfg := mydump.LoaderConfig{ + SourceURL: sourcePath, + Filter: cfg.filter, + FileRouters: cfg.fileRouteRules, + DefaultFileRules: len(cfg.fileRouteRules) == 0, + CharacterSet: cfg.charset, + Routes: cfg.routes, + } + + var loaderOptions []mydump.MDLoaderSetupOption + if cfg.maxScanFiles != nil && *cfg.maxScanFiles > 0 { + loaderOptions = append(loaderOptions, mydump.WithMaxScanFiles(*cfg.maxScanFiles)) + } + if cfg.concurrency > 0 { + loaderOptions = append(loaderOptions, mydump.WithScanFileConcurrency(cfg.concurrency)) + } + + // TODO: we can skip some time-consuming operation in constructFileInfo (like get real size of compressed file). + loader, err := mydump.NewLoaderWithStore(ctx, ldrCfg, store, loaderOptions...) 
+ if err != nil { + if loader == nil || !errors.ErrorEqual(err, common.ErrTooManySourceFiles) { + return nil, errors.Annotatef(ErrCreateLoader, "source=%s, charset=%s, err=%v", sourcePath, cfg.charset, err) + } + } + + return &fileScanner{ + sourcePath: sourcePath, + db: db, + store: store, + loader: loader, + logger: cfg.logger, + config: cfg, + }, nil +} + +func (s *fileScanner) CreateSchemasAndTables(ctx context.Context) error { + dbMetas := s.loader.GetDatabases() + if len(dbMetas) == 0 { + return errors.Annotatef(ErrNoDatabasesFound, "source=%s", s.sourcePath) + } + + // Create all schemas and tables + importer := mydump.NewSchemaImporter( + s.logger, + s.config.sqlMode, + s.db, + s.store, + s.config.concurrency, + ) + + err := importer.Run(ctx, dbMetas) + if err != nil { + return errors.Annotatef(ErrCreateSchema, "source=%s, db_count=%d, err=%v", s.sourcePath, len(dbMetas), err) + } + + return nil +} + +// CreateSchemaAndTableByName creates specific table and database schema from source +func (s *fileScanner) CreateSchemaAndTableByName(ctx context.Context, schema, table string) error { + dbMetas := s.loader.GetDatabases() + // Find the specific table + for _, dbMeta := range dbMetas { + if dbMeta.Name != schema { + continue + } + + for _, tblMeta := range dbMeta.Tables { + if tblMeta.Name != table { + continue + } + + importer := mydump.NewSchemaImporter( + s.logger, + s.config.sqlMode, + s.db, + s.store, + s.config.concurrency, + ) + + err := importer.Run(ctx, []*mydump.MDDatabaseMeta{{ + Name: dbMeta.Name, + SchemaFile: dbMeta.SchemaFile, + Tables: []*mydump.MDTableMeta{tblMeta}, + }}) + if err != nil { + return errors.Annotatef(ErrCreateSchema, "source=%s, schema=%s, table=%s, err=%v", s.sourcePath, schema, table, err) + } + + return nil + } + + return errors.Annotatef(ErrTableNotFound, "schema=%s, table=%s", schema, table) + } + + return errors.Annotatef(ErrSchemaNotFound, "schema=%s", schema) +} + +func (s *fileScanner) GetTableMetas(context.Context) ([]*TableMeta, error) { + dbMetas := s.loader.GetDatabases() + allFiles := s.loader.GetAllFiles() + var results []*TableMeta + for _, dbMeta := range dbMetas { + for _, tblMeta := range dbMeta.Tables { + tableMeta, err := s.buildTableMeta(dbMeta, tblMeta, allFiles) + if err != nil { + if s.config.skipInvalidFiles { + s.logger.Warn("skipping table due to invalid files", zap.String("database", dbMeta.Name), zap.String("table", tblMeta.Name), zap.Error(err)) + continue + } + return nil, err + } + results = append(results, tableMeta) + } + } + + return results, nil +} + +func (s *fileScanner) GetTotalSize(ctx context.Context) int64 { + var total int64 + dbMetas := s.loader.GetDatabases() + for _, dbMeta := range dbMetas { + for _, tblMeta := range dbMeta.Tables { + total += tblMeta.TotalSize + } + } + return total +} + +func (s *fileScanner) Close() error { + if s.store != nil { + s.store.Close() + } + return nil +} + +func (s *fileScanner) buildTableMeta( + dbMeta *mydump.MDDatabaseMeta, + tblMeta *mydump.MDTableMeta, + allDataFiles map[string]mydump.FileInfo, +) (*TableMeta, error) { + tableMeta := &TableMeta{ + Database: dbMeta.Name, + Table: tblMeta.Name, + DataFiles: make([]DataFileMeta, 0, len(tblMeta.DataFiles)), + SchemaFile: tblMeta.SchemaFile.FileMeta.Path, + } + + // Process data files + dataFiles, totalSize := processDataFiles(tblMeta.DataFiles) + tableMeta.DataFiles = dataFiles + tableMeta.TotalSize = totalSize + + if len(tblMeta.DataFiles) == 0 { + s.logger.Warn("table has no data files", zap.String("database", 
dbMeta.Name), zap.String("table", tblMeta.Name)) + return tableMeta, nil + } + + wildcard, err := generateWildcardPath(tblMeta.DataFiles, allDataFiles) + if err != nil { + return nil, errors.Trace(err) + } + uri := s.store.URI() + // import into only support absolute path + uri = strings.TrimPrefix(uri, "file://") + tableMeta.WildcardPath = strings.TrimSuffix(uri, "/") + "/" + wildcard + + return tableMeta, nil +} + +// processDataFiles converts mydump data files to DataFileMeta and calculates total size +func processDataFiles(files []mydump.FileInfo) ([]DataFileMeta, int64) { + dataFiles := make([]DataFileMeta, 0, len(files)) + var totalSize int64 + + for _, dataFile := range files { + fileMeta := createDataFileMeta(dataFile) + dataFiles = append(dataFiles, fileMeta) + totalSize += dataFile.FileMeta.RealSize + } + + return dataFiles, totalSize +} + +// createDataFileMeta creates a DataFileMeta from a mydump.DataFile +func createDataFileMeta(file mydump.FileInfo) DataFileMeta { + return DataFileMeta{ + Path: file.FileMeta.Path, + Size: file.FileMeta.RealSize, + Format: file.FileMeta.Type, + Compression: file.FileMeta.Compression, + } +} + +func (s *fileScanner) GetTableMetaByName(ctx context.Context, db, table string) (*TableMeta, error) { + dbMetas := s.loader.GetDatabases() + allFiles := s.loader.GetAllFiles() + + for _, dbMeta := range dbMetas { + if dbMeta.Name != db { + continue + } + for _, tblMeta := range dbMeta.Tables { + if tblMeta.Name != table { + continue + } + return s.buildTableMeta(dbMeta, tblMeta, allFiles) + } + } + return nil, errors.Annotatef(ErrTableNotFound, "table %s.%s not found", db, table) +} diff --git a/pkg/importsdk/file_scanner_test.go b/pkg/importsdk/file_scanner_test.go new file mode 100644 index 0000000000..170f9f4298 --- /dev/null +++ b/pkg/importsdk/file_scanner_test.go @@ -0,0 +1,177 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
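Putting the scanner above together, a sketch of a typical caller, written as if inside the package because the exported wiring in sdk.go is outside this hunk; `listTables` and the example source paths are illustrative:

```go
package importsdk

import (
	"context"
	"database/sql"
	"fmt"

	"github.com/pingcap/errors"
)

// listTables is an illustrative sketch, not part of the diff: scan sourcePath
// (for example "s3://bucket/dump" or "file:///data/dump") and summarise each table.
func listTables(ctx context.Context, db *sql.DB, sourcePath string) error {
	cfg := defaultSDKConfig()
	WithSkipInvalidFiles(true)(cfg) // tolerate tables whose files cannot be resolved

	scanner, err := NewFileScanner(ctx, sourcePath, db, cfg)
	if err != nil {
		return errors.Trace(err)
	}
	defer scanner.Close()

	metas, err := scanner.GetTableMetas(ctx)
	if err != nil {
		return errors.Trace(err)
	}
	for _, m := range metas {
		fmt.Printf("%s.%s: %d data files, %d bytes, wildcard=%s\n",
			m.Database, m.Table, len(m.DataFiles), m.TotalSize, m.WildcardPath)
	}
	return nil
}
```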
+ +package importsdk + +import ( + "context" + "os" + "path/filepath" + "regexp" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/lightning/mydump" + filter "github.com/pingcap/tidb/pkg/util/table-filter" + "github.com/stretchr/testify/require" +) + +func TestCreateDataFileMeta(t *testing.T) { + fi := mydump.FileInfo{ + TableName: filter.Table{ + Schema: "db", + Name: "table", + }, + FileMeta: mydump.SourceFileMeta{ + Path: "s3://bucket/path/to/f", + FileSize: 123, + Type: mydump.SourceTypeCSV, + Compression: mydump.CompressionGZ, + RealSize: 456, + }, + } + df := createDataFileMeta(fi) + require.Equal(t, "s3://bucket/path/to/f", df.Path) + require.Equal(t, int64(456), df.Size) + require.Equal(t, mydump.SourceTypeCSV, df.Format) + require.Equal(t, mydump.CompressionGZ, df.Compression) +} + +func TestProcessDataFiles(t *testing.T) { + files := []mydump.FileInfo{ + {FileMeta: mydump.SourceFileMeta{Path: "s3://bucket/a", RealSize: 10}}, + {FileMeta: mydump.SourceFileMeta{Path: "s3://bucket/b", RealSize: 20}}, + } + dfm, total := processDataFiles(files) + require.Len(t, dfm, 2) + require.Equal(t, int64(30), total) + require.Equal(t, "s3://bucket/a", dfm[0].Path) + require.Equal(t, "s3://bucket/b", dfm[1].Path) +} + +func TestFileScanner(t *testing.T) { + tmpDir := t.TempDir() + ctx := context.Background() + + require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "db1-schema-create.sql"), []byte("CREATE DATABASE db1;"), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "db1.t1-schema.sql"), []byte("CREATE TABLE t1 (id INT);"), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "db1.t1.001.csv"), []byte("1\n2"), 0644)) + + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + cfg := defaultSDKConfig() + scanner, err := NewFileScanner(ctx, "file://"+tmpDir, db, cfg) + require.NoError(t, err) + defer scanner.Close() + + t.Run("GetTotalSize", func(t *testing.T) { + size := scanner.GetTotalSize(ctx) + require.Equal(t, int64(3), size) + }) + + t.Run("GetTableMetas", func(t *testing.T) { + metas, err := scanner.GetTableMetas(ctx) + require.NoError(t, err) + require.Len(t, metas, 1) + require.Equal(t, "db1", metas[0].Database) + require.Equal(t, "t1", metas[0].Table) + require.Equal(t, int64(3), metas[0].TotalSize) + require.Len(t, metas[0].DataFiles, 1) + }) + + t.Run("GetTableMetaByName", func(t *testing.T) { + meta, err := scanner.GetTableMetaByName(ctx, "db1", "t1") + require.NoError(t, err) + require.Equal(t, "db1", meta.Database) + require.Equal(t, "t1", meta.Table) + + _, err = scanner.GetTableMetaByName(ctx, "db1", "nonexistent") + require.Error(t, err) + }) + + t.Run("CreateSchemasAndTables", func(t *testing.T) { + mock.ExpectQuery("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA.*").WillReturnRows(sqlmock.NewRows([]string{"SCHEMA_NAME"})) + mock.ExpectExec(regexp.QuoteMeta("CREATE DATABASE IF NOT EXISTS `db1`")).WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec(regexp.QuoteMeta("CREATE TABLE IF NOT EXISTS `db1`.`t1`")).WillReturnResult(sqlmock.NewResult(0, 0)) + + err := scanner.CreateSchemasAndTables(ctx) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("CreateSchemaAndTableByName", func(t *testing.T) { + mock.ExpectQuery("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA.*").WillReturnRows(sqlmock.NewRows([]string{"SCHEMA_NAME"})) + mock.ExpectExec(regexp.QuoteMeta("CREATE DATABASE IF NOT EXISTS 
`db1`")).WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec(regexp.QuoteMeta("CREATE TABLE IF NOT EXISTS `db1`.`t1`")).WillReturnResult(sqlmock.NewResult(0, 0)) + + err := scanner.CreateSchemaAndTableByName(ctx, "db1", "t1") + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) + + err = scanner.CreateSchemaAndTableByName(ctx, "db1", "nonexistent") + require.Error(t, err) + }) +} + +func TestFileScannerWithSkipInvalidFiles(t *testing.T) { + tmpDir := t.TempDir() + ctx := context.Background() + + require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "data1.csv"), []byte("1"), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "data2.csv"), []byte("1"), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "data3.csv"), []byte("1"), 0644)) + + db, _, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + rules := []*config.FileRouteRule{ + { + Pattern: "data[1-2].csv", + Schema: "db1", + Table: "t1", + Type: "csv", + }, + { + Pattern: "data3.csv", + Schema: "db1", + Table: "t2", + Type: "csv", + }, + } + + cfg := defaultSDKConfig() + WithFileRouters(rules)(cfg) + + scanner, err := NewFileScanner(ctx, "file://"+tmpDir, db, cfg) + require.NoError(t, err) + defer scanner.Close() + + metas, err := scanner.GetTableMetas(ctx) + require.Error(t, err) + require.Nil(t, metas) + + cfg.skipInvalidFiles = true + scanner2, err := NewFileScanner(ctx, "file://"+tmpDir, db, cfg) + require.NoError(t, err) + defer scanner2.Close() + + metas2, err := scanner2.GetTableMetas(ctx) + require.NoError(t, err) + require.Len(t, metas2, 1) + require.Equal(t, "t2", metas2[0].Table) +} diff --git a/pkg/importsdk/job_manager.go b/pkg/importsdk/job_manager.go new file mode 100644 index 0000000000..c2ab3a0c3b --- /dev/null +++ b/pkg/importsdk/job_manager.go @@ -0,0 +1,258 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
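For dumps that do not follow the mydumper naming convention, files can be routed explicitly, as exercised in the test above. A hedged sketch of such a configuration; the schema, table, and pattern values are placeholders:

```go
package importsdk

import "github.com/pingcap/tidb/pkg/lightning/config"

// customLayout is an illustrative option set, not part of the diff: route flat
// CSV files onto tables explicitly instead of relying on the default
// mydumper file-name rules.
var customLayout = []SDKOption{
	WithFileRouters([]*config.FileRouteRule{
		{Pattern: `orders.*\.csv`, Schema: "shop", Table: "orders", Type: "csv"},
		{Pattern: `users.*\.csv`, Schema: "shop", Table: "users", Type: "csv"},
	}),
	// With this set, GetTableMetas skips tables whose files cannot be resolved
	// (for example when no unique wildcard exists) instead of returning an error.
	WithSkipInvalidFiles(true),
}
```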
+ +package importsdk + +import ( + "context" + "database/sql" + "fmt" + "strings" + "time" + + "github.com/pingcap/errors" +) + +// JobManager defines the interface for managing import jobs +type JobManager interface { + SubmitJob(ctx context.Context, query string) (int64, error) + GetJobStatus(ctx context.Context, jobID int64) (*JobStatus, error) + CancelJob(ctx context.Context, jobID int64) error + GetGroupSummary(ctx context.Context, groupKey string) (*GroupStatus, error) + GetJobsByGroup(ctx context.Context, groupKey string) ([]*JobStatus, error) +} + +const timeLayout = "2006-01-02 15:04:05" + +type jobManager struct { + db *sql.DB +} + +// NewJobManager creates a new JobManager +func NewJobManager(db *sql.DB) JobManager { + return &jobManager{ + db: db, + } +} + +// SubmitJob submits an import job and returns the job ID +func (m *jobManager) SubmitJob(ctx context.Context, query string) (int64, error) { + rows, err := m.db.QueryContext(ctx, query) + if err != nil { + return 0, errors.Trace(err) + } + defer rows.Close() + + if rows.Next() { + status, err := scanJobStatus(rows) + if err != nil { + return 0, errors.Trace(err) + } + return status.JobID, nil + } + + if err := rows.Err(); err != nil { + return 0, errors.Trace(err) + } + + return 0, ErrNoJobIDReturned +} + +// GetJobStatus gets the status of an import job +func (m *jobManager) GetJobStatus(ctx context.Context, jobID int64) (*JobStatus, error) { + query := fmt.Sprintf("SHOW IMPORT JOB %d", jobID) + rows, err := m.db.QueryContext(ctx, query) + if err != nil { + return nil, errors.Trace(err) + } + defer rows.Close() + + if rows.Next() { + return scanJobStatus(rows) + } + + if err := rows.Err(); err != nil { + return nil, errors.Trace(err) + } + + return nil, ErrJobNotFound +} + +// GetGroupSummary returns aggregated information for the specified group key. +func (m *jobManager) GetGroupSummary(ctx context.Context, groupKey string) (*GroupStatus, error) { + if groupKey == "" { + return nil, ErrInvalidOptions + } + query := fmt.Sprintf("SHOW IMPORT GROUP '%s'", strings.ReplaceAll(groupKey, "'", "''")) + rows, err := m.db.QueryContext(ctx, query) + if err != nil { + return nil, errors.Trace(err) + } + defer rows.Close() + + if rows.Next() { + status, err := scanGroupStatus(rows) + if err != nil { + return nil, errors.Trace(err) + } + return status, nil + } + + if err := rows.Err(); err != nil { + return nil, errors.Trace(err) + } + return nil, ErrJobNotFound +} + +// GetJobsByGroup returns all jobs for the specified group key. 
+func (m *jobManager) GetJobsByGroup(ctx context.Context, groupKey string) ([]*JobStatus, error) { + if groupKey == "" { + return nil, ErrInvalidOptions + } + query := fmt.Sprintf("SHOW IMPORT JOBS WHERE GROUP_KEY = '%s'", strings.ReplaceAll(groupKey, "'", "''")) + rows, err := m.db.QueryContext(ctx, query) + if err != nil { + return nil, errors.Trace(err) + } + defer rows.Close() + + var jobs []*JobStatus + for rows.Next() { + status, err := scanJobStatus(rows) + if err != nil { + return nil, errors.Trace(err) + } + jobs = append(jobs, status) + } + + if err := rows.Err(); err != nil { + return nil, errors.Trace(err) + } + return jobs, nil +} + +func scanJobStatus(rows *sql.Rows) (*JobStatus, error) { + var ( + id int64 + groupKey sql.NullString + dataSource string + targetTable string + tableID int64 + phase string + status string + sourceFileSize string + importedRows sql.NullInt64 + resultMessage sql.NullString + createTimeStr string + startTimeStr sql.NullString + endTimeStr sql.NullString + createdBy string + updateTimeStr sql.NullString + step sql.NullString + processedSize sql.NullString + totalSize sql.NullString + percent sql.NullString + speed sql.NullString + eta sql.NullString + ) + + err := rows.Scan( + &id, &groupKey, &dataSource, &targetTable, &tableID, + &phase, &status, &sourceFileSize, &importedRows, &resultMessage, + &createTimeStr, &startTimeStr, &endTimeStr, &createdBy, &updateTimeStr, + &step, &processedSize, &totalSize, &percent, &speed, &eta, + ) + if err != nil { + return nil, errors.Trace(err) + } + + // Parse times + createTime := parseTime(createTimeStr) + startTime := parseNullTime(startTimeStr) + endTime := parseNullTime(endTimeStr) + updateTime := parseNullTime(updateTimeStr) + + return &JobStatus{ + JobID: id, + GroupKey: groupKey.String, + DataSource: dataSource, + TargetTable: targetTable, + TableID: tableID, + Phase: phase, + Status: status, + SourceFileSize: sourceFileSize, + ImportedRows: importedRows.Int64, + ResultMessage: resultMessage.String, + CreateTime: createTime, + StartTime: startTime, + EndTime: endTime, + CreatedBy: createdBy, + UpdateTime: updateTime, + Step: step.String, + ProcessedSize: processedSize.String, + TotalSize: totalSize.String, + Percent: percent.String, + Speed: speed.String, + ETA: eta.String, + }, nil +} + +func scanGroupStatus(rows *sql.Rows) (*GroupStatus, error) { + var ( + groupKey string + totalJobs int64 + pending int64 + running int64 + completed int64 + failed int64 + cancelled int64 + firstCreateTime sql.NullString + lastUpdateTime sql.NullString + ) + + if err := rows.Scan(&groupKey, &totalJobs, &pending, &running, &completed, &failed, &cancelled, &firstCreateTime, &lastUpdateTime); err != nil { + return nil, errors.Trace(err) + } + + return &GroupStatus{ + GroupKey: groupKey, + TotalJobs: totalJobs, + Pending: pending, + Running: running, + Completed: completed, + Failed: failed, + Cancelled: cancelled, + FirstJobCreateTime: parseNullTime(firstCreateTime), + LastJobUpdateTime: parseNullTime(lastUpdateTime), + }, nil +} + +// CancelJob cancels an import job +func (m *jobManager) CancelJob(ctx context.Context, jobID int64) error { + query := fmt.Sprintf("CANCEL IMPORT JOB %d", jobID) + _, err := m.db.ExecContext(ctx, query) + return errors.Trace(err) +} + +func parseTime(s string) time.Time { + t, _ := time.Parse(timeLayout, s) + return t +} + +func parseNullTime(ns sql.NullString) time.Time { + if !ns.Valid { + return time.Time{} + } + return parseTime(ns.String) +} diff --git a/pkg/importsdk/job_manager_test.go 
b/pkg/importsdk/job_manager_test.go new file mode 100644 index 0000000000..317a182f06 --- /dev/null +++ b/pkg/importsdk/job_manager_test.go @@ -0,0 +1,242 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importsdk + +import ( + "context" + "database/sql" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/require" +) + +func TestSubmitJob(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + manager := NewJobManager(db) + ctx := context.Background() + sqlQuery := "IMPORT INTO ..." + + // Columns expected by scanJobStatus + cols := []string{ + "Job_ID", "Group_Key", "Data_Source", "Target_Table", "Table_ID", + "Phase", "Status", "Source_File_Size", "Imported_Rows", "Result_Message", + "Create_Time", "Start_Time", "End_Time", "Created_By", "Update_Time", + "Step", "Processed_Size", "Total_Size", "Percent", "Speed", "ETA", + } + + // Case 1: Success + rows := sqlmock.NewRows(cols).AddRow( + int64(123), "", "s3://bucket/file.csv", "db.table", int64(1), + "import", "finished", "100MB", int64(1000), "success", + "2023-01-01 10:00:00", "2023-01-01 10:00:01", "2023-01-01 10:00:02", "user", "2023-01-01 10:00:02", + "", "100MB", "100MB", "100%", "10MB/s", "0s", + ) + mock.ExpectQuery(sqlQuery).WillReturnRows(rows) + + jobID, err := manager.SubmitJob(ctx, sqlQuery) + require.NoError(t, err) + require.Equal(t, int64(123), jobID) + + // Case 2: No rows returned + mock.ExpectQuery(sqlQuery).WillReturnRows(sqlmock.NewRows(cols)) + _, err = manager.SubmitJob(ctx, sqlQuery) + require.Error(t, err) + require.Contains(t, err.Error(), "no job id returned") + + // Case 3: Error + mock.ExpectQuery(sqlQuery).WillReturnError(sql.ErrConnDone) + _, err = manager.SubmitJob(ctx, sqlQuery) + require.Error(t, err) + + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestGetJobStatus(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + manager := NewJobManager(db) + ctx := context.Background() + jobID := int64(123) + + // Columns expected by GetJobStatus (same as scanJobStatus) + cols := []string{ + "Job_ID", "Group_Key", "Data_Source", "Target_Table", "Table_ID", + "Phase", "Status", "Source_File_Size", "Imported_Rows", "Result_Message", + "Create_Time", "Start_Time", "End_Time", "Created_By", "Update_Time", + "Step", "Processed_Size", "Total_Size", "Percent", "Speed", "ETA", + } + + // Case 1: Success + rows := sqlmock.NewRows(cols).AddRow( + jobID, "", "s3://bucket/file.csv", "db.table", int64(1), + "import", "finished", "100MB", int64(1000), "success", + "2023-01-01 10:00:00", "2023-01-01 10:00:01", "2023-01-01 10:00:02", "user", "2023-01-01 10:00:02", + "", "100MB", "100MB", "100%", "10MB/s", "0s", + ) + mock.ExpectQuery("SHOW IMPORT JOB 123").WillReturnRows(rows) + + status, err := manager.GetJobStatus(ctx, jobID) + require.NoError(t, err) + require.Equal(t, jobID, status.JobID) + require.Equal(t, "finished", status.Status) + + // Case 
2: Job not found + mock.ExpectQuery("SHOW IMPORT JOB 123").WillReturnRows(sqlmock.NewRows(cols)) + _, err = manager.GetJobStatus(ctx, jobID) + require.ErrorIs(t, err, ErrJobNotFound) + + // Case 3: Error + mock.ExpectQuery("SHOW IMPORT JOB 123").WillReturnError(sql.ErrConnDone) + _, err = manager.GetJobStatus(ctx, jobID) + require.Error(t, err) + + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestCancelJob(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + manager := NewJobManager(db) + ctx := context.Background() + jobID := int64(123) + + // Case 1: Success + mock.ExpectExec("CANCEL IMPORT JOB 123").WillReturnResult(sqlmock.NewResult(0, 0)) + err = manager.CancelJob(ctx, jobID) + require.NoError(t, err) + + // Case 2: Error + mock.ExpectExec("CANCEL IMPORT JOB 123").WillReturnError(sql.ErrConnDone) + err = manager.CancelJob(ctx, jobID) + require.Error(t, err) + + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestGetGroupSummary(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + manager := NewJobManager(db) + ctx := context.Background() + groupKey := "test_group" + + // Columns expected by scanGroupStatus + cols := []string{ + "Group_Key", "Total_Jobs", "Pending", "Running", "Completed", "Failed", "Cancelled", + "First_Job_Create_Time", "Last_Job_Update_Time", + } + + // Case 1: Success + rows := sqlmock.NewRows(cols).AddRow( + groupKey, int64(10), int64(1), int64(2), int64(3), int64(2), int64(2), + "2023-01-01 10:00:00", "2023-01-01 12:00:00", + ) + mock.ExpectQuery("SHOW IMPORT GROUP 'test_group'").WillReturnRows(rows) + + summary, err := manager.GetGroupSummary(ctx, groupKey) + require.NoError(t, err) + require.Equal(t, groupKey, summary.GroupKey) + require.Equal(t, int64(10), summary.TotalJobs) + require.Equal(t, int64(1), summary.Pending) + require.Equal(t, int64(2), summary.Running) + require.Equal(t, int64(3), summary.Completed) + require.Equal(t, int64(2), summary.Failed) + require.Equal(t, int64(2), summary.Cancelled) + + // Case 2: Empty Group Key + _, err = manager.GetGroupSummary(ctx, "") + require.ErrorIs(t, err, ErrInvalidOptions) + + // Case 3: Group not found + mock.ExpectQuery("SHOW IMPORT GROUP 'test_group'").WillReturnRows(sqlmock.NewRows(cols)) + _, err = manager.GetGroupSummary(ctx, groupKey) + require.ErrorIs(t, err, ErrJobNotFound) + + // Case 4: Error + mock.ExpectQuery("SHOW IMPORT GROUP 'test_group'").WillReturnError(sql.ErrConnDone) + _, err = manager.GetGroupSummary(ctx, groupKey) + require.Error(t, err) + + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestGetJobsByGroup(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + manager := NewJobManager(db) + ctx := context.Background() + groupKey := "test_group" + + // Columns expected by scanJobStatus + cols := []string{ + "Job_ID", "Group_Key", "Data_Source", "Target_Table", "Table_ID", + "Phase", "Status", "Source_File_Size", "Imported_Rows", "Result_Message", + "Create_Time", "Start_Time", "End_Time", "Created_By", "Update_Time", + "Step", "Processed_Size", "Total_Size", "Percent", "Speed", "ETA", + } + + // Case 1: Success + rows := sqlmock.NewRows(cols). + AddRow( + int64(1), groupKey, "s3://bucket/file1.csv", "db.t1", int64(1), + "import", "finished", "100MB", int64(1000), "success", + "2023-01-01 10:00:00", "2023-01-01 10:00:01", "2023-01-01 10:00:02", "user", "2023-01-01 10:00:02", + "", "100MB", "100MB", "100%", "10MB/s", "0s", + ). 
+ AddRow( + int64(2), groupKey, "s3://bucket/file2.csv", "db.t2", int64(2), + "import", "running", "200MB", int64(500), "", + "2023-01-01 10:00:00", "2023-01-01 10:00:01", "", "user", "2023-01-01 10:00:02", + "", "100MB", "200MB", "50%", "10MB/s", "10s", + ) + mock.ExpectQuery("SHOW IMPORT JOBS WHERE GROUP_KEY = 'test_group'").WillReturnRows(rows) + + jobs, err := manager.GetJobsByGroup(ctx, groupKey) + require.NoError(t, err) + require.Len(t, jobs, 2) + require.Equal(t, int64(1), jobs[0].JobID) + require.Equal(t, "finished", jobs[0].Status) + require.Equal(t, int64(2), jobs[1].JobID) + require.Equal(t, "running", jobs[1].Status) + + // Case 2: Empty Group Key + _, err = manager.GetJobsByGroup(ctx, "") + require.ErrorIs(t, err, ErrInvalidOptions) + + // Case 3: No jobs found (empty list) + mock.ExpectQuery("SHOW IMPORT JOBS WHERE GROUP_KEY = 'test_group'").WillReturnRows(sqlmock.NewRows(cols)) + jobs, err = manager.GetJobsByGroup(ctx, groupKey) + require.NoError(t, err) + require.Empty(t, jobs) + + // Case 4: Error + mock.ExpectQuery("SHOW IMPORT JOBS WHERE GROUP_KEY = 'test_group'").WillReturnError(sql.ErrConnDone) + _, err = manager.GetJobsByGroup(ctx, groupKey) + require.Error(t, err) + + require.NoError(t, mock.ExpectationsWereMet()) +} diff --git a/pkg/importsdk/mock/BUILD.bazel b/pkg/importsdk/mock/BUILD.bazel new file mode 100644 index 0000000000..0e9df33acb --- /dev/null +++ b/pkg/importsdk/mock/BUILD.bazel @@ -0,0 +1,12 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "mock", + srcs = ["sdk_mock.go"], + importpath = "github.com/pingcap/tidb/pkg/importsdk/mock", + visibility = ["//visibility:public"], + deps = [ + "//pkg/importsdk", + "@org_uber_go_mock//gomock", + ], +) diff --git a/pkg/importsdk/mock/sdk_mock.go b/pkg/importsdk/mock/sdk_mock.go new file mode 100644 index 0000000000..44c108d984 --- /dev/null +++ b/pkg/importsdk/mock/sdk_mock.go @@ -0,0 +1,480 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/pingcap/tidb/pkg/importsdk (interfaces: FileScanner,JobManager,SQLGenerator,SDK) +// +// Generated by this command: +// +// mockgen -package mock github.com/pingcap/tidb/pkg/importsdk FileScanner,JobManager,SQLGenerator,SDK +// + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + reflect "reflect" + + importsdk "github.com/pingcap/tidb/pkg/importsdk" + gomock "go.uber.org/mock/gomock" +) + +// MockFileScanner is a mock of FileScanner interface. +type MockFileScanner struct { + ctrl *gomock.Controller + recorder *MockFileScannerMockRecorder +} + +// MockFileScannerMockRecorder is the mock recorder for MockFileScanner. +type MockFileScannerMockRecorder struct { + mock *MockFileScanner +} + +// NewMockFileScanner creates a new mock instance. +func NewMockFileScanner(ctrl *gomock.Controller) *MockFileScanner { + mock := &MockFileScanner{ctrl: ctrl} + mock.recorder = &MockFileScannerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockFileScanner) EXPECT() *MockFileScannerMockRecorder { + return m.recorder +} + +// ISGOMOCK indicates that this struct is a gomock mock. +func (m *MockFileScanner) ISGOMOCK() struct{} { + return struct{}{} +} + +// Close mocks base method. +func (m *MockFileScanner) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. 
+func (mr *MockFileScannerMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockFileScanner)(nil).Close)) +} + +// CreateSchemaAndTableByName mocks base method. +func (m *MockFileScanner) CreateSchemaAndTableByName(arg0 context.Context, arg1, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSchemaAndTableByName", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateSchemaAndTableByName indicates an expected call of CreateSchemaAndTableByName. +func (mr *MockFileScannerMockRecorder) CreateSchemaAndTableByName(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaAndTableByName", reflect.TypeOf((*MockFileScanner)(nil).CreateSchemaAndTableByName), arg0, arg1, arg2) +} + +// CreateSchemasAndTables mocks base method. +func (m *MockFileScanner) CreateSchemasAndTables(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSchemasAndTables", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateSchemasAndTables indicates an expected call of CreateSchemasAndTables. +func (mr *MockFileScannerMockRecorder) CreateSchemasAndTables(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemasAndTables", reflect.TypeOf((*MockFileScanner)(nil).CreateSchemasAndTables), arg0) +} + +// GetTableMetaByName mocks base method. +func (m *MockFileScanner) GetTableMetaByName(arg0 context.Context, arg1, arg2 string) (*importsdk.TableMeta, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTableMetaByName", arg0, arg1, arg2) + ret0, _ := ret[0].(*importsdk.TableMeta) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTableMetaByName indicates an expected call of GetTableMetaByName. +func (mr *MockFileScannerMockRecorder) GetTableMetaByName(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTableMetaByName", reflect.TypeOf((*MockFileScanner)(nil).GetTableMetaByName), arg0, arg1, arg2) +} + +// GetTableMetas mocks base method. +func (m *MockFileScanner) GetTableMetas(arg0 context.Context) ([]*importsdk.TableMeta, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTableMetas", arg0) + ret0, _ := ret[0].([]*importsdk.TableMeta) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTableMetas indicates an expected call of GetTableMetas. +func (mr *MockFileScannerMockRecorder) GetTableMetas(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTableMetas", reflect.TypeOf((*MockFileScanner)(nil).GetTableMetas), arg0) +} + +// GetTotalSize mocks base method. +func (m *MockFileScanner) GetTotalSize(arg0 context.Context) int64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTotalSize", arg0) + ret0, _ := ret[0].(int64) + return ret0 +} + +// GetTotalSize indicates an expected call of GetTotalSize. +func (mr *MockFileScannerMockRecorder) GetTotalSize(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTotalSize", reflect.TypeOf((*MockFileScanner)(nil).GetTotalSize), arg0) +} + +// MockJobManager is a mock of JobManager interface. +type MockJobManager struct { + ctrl *gomock.Controller + recorder *MockJobManagerMockRecorder +} + +// MockJobManagerMockRecorder is the mock recorder for MockJobManager. 
+type MockJobManagerMockRecorder struct { + mock *MockJobManager +} + +// NewMockJobManager creates a new mock instance. +func NewMockJobManager(ctrl *gomock.Controller) *MockJobManager { + mock := &MockJobManager{ctrl: ctrl} + mock.recorder = &MockJobManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockJobManager) EXPECT() *MockJobManagerMockRecorder { + return m.recorder +} + +// ISGOMOCK indicates that this struct is a gomock mock. +func (m *MockJobManager) ISGOMOCK() struct{} { + return struct{}{} +} + +// CancelJob mocks base method. +func (m *MockJobManager) CancelJob(arg0 context.Context, arg1 int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CancelJob", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// CancelJob indicates an expected call of CancelJob. +func (mr *MockJobManagerMockRecorder) CancelJob(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelJob", reflect.TypeOf((*MockJobManager)(nil).CancelJob), arg0, arg1) +} + +// GetGroupSummary mocks base method. +func (m *MockJobManager) GetGroupSummary(arg0 context.Context, arg1 string) (*importsdk.GroupStatus, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetGroupSummary", arg0, arg1) + ret0, _ := ret[0].(*importsdk.GroupStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetGroupSummary indicates an expected call of GetGroupSummary. +func (mr *MockJobManagerMockRecorder) GetGroupSummary(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupSummary", reflect.TypeOf((*MockJobManager)(nil).GetGroupSummary), arg0, arg1) +} + +// GetJobStatus mocks base method. +func (m *MockJobManager) GetJobStatus(arg0 context.Context, arg1 int64) (*importsdk.JobStatus, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetJobStatus", arg0, arg1) + ret0, _ := ret[0].(*importsdk.JobStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetJobStatus indicates an expected call of GetJobStatus. +func (mr *MockJobManagerMockRecorder) GetJobStatus(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetJobStatus", reflect.TypeOf((*MockJobManager)(nil).GetJobStatus), arg0, arg1) +} + +// GetJobsByGroup mocks base method. +func (m *MockJobManager) GetJobsByGroup(arg0 context.Context, arg1 string) ([]*importsdk.JobStatus, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetJobsByGroup", arg0, arg1) + ret0, _ := ret[0].([]*importsdk.JobStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetJobsByGroup indicates an expected call of GetJobsByGroup. +func (mr *MockJobManagerMockRecorder) GetJobsByGroup(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetJobsByGroup", reflect.TypeOf((*MockJobManager)(nil).GetJobsByGroup), arg0, arg1) +} + +// SubmitJob mocks base method. +func (m *MockJobManager) SubmitJob(arg0 context.Context, arg1 string) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SubmitJob", arg0, arg1) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SubmitJob indicates an expected call of SubmitJob. 
+func (mr *MockJobManagerMockRecorder) SubmitJob(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubmitJob", reflect.TypeOf((*MockJobManager)(nil).SubmitJob), arg0, arg1) +} + +// MockSQLGenerator is a mock of SQLGenerator interface. +type MockSQLGenerator struct { + ctrl *gomock.Controller + recorder *MockSQLGeneratorMockRecorder +} + +// MockSQLGeneratorMockRecorder is the mock recorder for MockSQLGenerator. +type MockSQLGeneratorMockRecorder struct { + mock *MockSQLGenerator +} + +// NewMockSQLGenerator creates a new mock instance. +func NewMockSQLGenerator(ctrl *gomock.Controller) *MockSQLGenerator { + mock := &MockSQLGenerator{ctrl: ctrl} + mock.recorder = &MockSQLGeneratorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSQLGenerator) EXPECT() *MockSQLGeneratorMockRecorder { + return m.recorder +} + +// ISGOMOCK indicates that this struct is a gomock mock. +func (m *MockSQLGenerator) ISGOMOCK() struct{} { + return struct{}{} +} + +// GenerateImportSQL mocks base method. +func (m *MockSQLGenerator) GenerateImportSQL(arg0 *importsdk.TableMeta, arg1 *importsdk.ImportOptions) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GenerateImportSQL", arg0, arg1) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GenerateImportSQL indicates an expected call of GenerateImportSQL. +func (mr *MockSQLGeneratorMockRecorder) GenerateImportSQL(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateImportSQL", reflect.TypeOf((*MockSQLGenerator)(nil).GenerateImportSQL), arg0, arg1) +} + +// MockSDK is a mock of SDK interface. +type MockSDK struct { + ctrl *gomock.Controller + recorder *MockSDKMockRecorder +} + +// MockSDKMockRecorder is the mock recorder for MockSDK. +type MockSDKMockRecorder struct { + mock *MockSDK +} + +// NewMockSDK creates a new mock instance. +func NewMockSDK(ctrl *gomock.Controller) *MockSDK { + mock := &MockSDK{ctrl: ctrl} + mock.recorder = &MockSDKMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSDK) EXPECT() *MockSDKMockRecorder { + return m.recorder +} + +// ISGOMOCK indicates that this struct is a gomock mock. +func (m *MockSDK) ISGOMOCK() struct{} { + return struct{}{} +} + +// CancelJob mocks base method. +func (m *MockSDK) CancelJob(arg0 context.Context, arg1 int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CancelJob", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// CancelJob indicates an expected call of CancelJob. +func (mr *MockSDKMockRecorder) CancelJob(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelJob", reflect.TypeOf((*MockSDK)(nil).CancelJob), arg0, arg1) +} + +// Close mocks base method. +func (m *MockSDK) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockSDKMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSDK)(nil).Close)) +} + +// CreateSchemaAndTableByName mocks base method. 
+func (m *MockSDK) CreateSchemaAndTableByName(arg0 context.Context, arg1, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSchemaAndTableByName", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateSchemaAndTableByName indicates an expected call of CreateSchemaAndTableByName. +func (mr *MockSDKMockRecorder) CreateSchemaAndTableByName(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaAndTableByName", reflect.TypeOf((*MockSDK)(nil).CreateSchemaAndTableByName), arg0, arg1, arg2) +} + +// CreateSchemasAndTables mocks base method. +func (m *MockSDK) CreateSchemasAndTables(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSchemasAndTables", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateSchemasAndTables indicates an expected call of CreateSchemasAndTables. +func (mr *MockSDKMockRecorder) CreateSchemasAndTables(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemasAndTables", reflect.TypeOf((*MockSDK)(nil).CreateSchemasAndTables), arg0) +} + +// GenerateImportSQL mocks base method. +func (m *MockSDK) GenerateImportSQL(arg0 *importsdk.TableMeta, arg1 *importsdk.ImportOptions) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GenerateImportSQL", arg0, arg1) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GenerateImportSQL indicates an expected call of GenerateImportSQL. +func (mr *MockSDKMockRecorder) GenerateImportSQL(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateImportSQL", reflect.TypeOf((*MockSDK)(nil).GenerateImportSQL), arg0, arg1) +} + +// GetGroupSummary mocks base method. +func (m *MockSDK) GetGroupSummary(arg0 context.Context, arg1 string) (*importsdk.GroupStatus, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetGroupSummary", arg0, arg1) + ret0, _ := ret[0].(*importsdk.GroupStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetGroupSummary indicates an expected call of GetGroupSummary. +func (mr *MockSDKMockRecorder) GetGroupSummary(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupSummary", reflect.TypeOf((*MockSDK)(nil).GetGroupSummary), arg0, arg1) +} + +// GetJobStatus mocks base method. +func (m *MockSDK) GetJobStatus(arg0 context.Context, arg1 int64) (*importsdk.JobStatus, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetJobStatus", arg0, arg1) + ret0, _ := ret[0].(*importsdk.JobStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetJobStatus indicates an expected call of GetJobStatus. +func (mr *MockSDKMockRecorder) GetJobStatus(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetJobStatus", reflect.TypeOf((*MockSDK)(nil).GetJobStatus), arg0, arg1) +} + +// GetJobsByGroup mocks base method. +func (m *MockSDK) GetJobsByGroup(arg0 context.Context, arg1 string) ([]*importsdk.JobStatus, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetJobsByGroup", arg0, arg1) + ret0, _ := ret[0].([]*importsdk.JobStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetJobsByGroup indicates an expected call of GetJobsByGroup. 
+func (mr *MockSDKMockRecorder) GetJobsByGroup(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetJobsByGroup", reflect.TypeOf((*MockSDK)(nil).GetJobsByGroup), arg0, arg1) +} + +// GetTableMetaByName mocks base method. +func (m *MockSDK) GetTableMetaByName(arg0 context.Context, arg1, arg2 string) (*importsdk.TableMeta, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTableMetaByName", arg0, arg1, arg2) + ret0, _ := ret[0].(*importsdk.TableMeta) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTableMetaByName indicates an expected call of GetTableMetaByName. +func (mr *MockSDKMockRecorder) GetTableMetaByName(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTableMetaByName", reflect.TypeOf((*MockSDK)(nil).GetTableMetaByName), arg0, arg1, arg2) +} + +// GetTableMetas mocks base method. +func (m *MockSDK) GetTableMetas(arg0 context.Context) ([]*importsdk.TableMeta, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTableMetas", arg0) + ret0, _ := ret[0].([]*importsdk.TableMeta) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTableMetas indicates an expected call of GetTableMetas. +func (mr *MockSDKMockRecorder) GetTableMetas(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTableMetas", reflect.TypeOf((*MockSDK)(nil).GetTableMetas), arg0) +} + +// GetTotalSize mocks base method. +func (m *MockSDK) GetTotalSize(arg0 context.Context) int64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTotalSize", arg0) + ret0, _ := ret[0].(int64) + return ret0 +} + +// GetTotalSize indicates an expected call of GetTotalSize. +func (mr *MockSDKMockRecorder) GetTotalSize(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTotalSize", reflect.TypeOf((*MockSDK)(nil).GetTotalSize), arg0) +} + +// SubmitJob mocks base method. +func (m *MockSDK) SubmitJob(arg0 context.Context, arg1 string) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SubmitJob", arg0, arg1) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SubmitJob indicates an expected call of SubmitJob. +func (mr *MockSDKMockRecorder) SubmitJob(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubmitJob", reflect.TypeOf((*MockSDK)(nil).SubmitJob), arg0, arg1) +} diff --git a/pkg/importsdk/model.go b/pkg/importsdk/model.go new file mode 100644 index 0000000000..6f23772004 --- /dev/null +++ b/pkg/importsdk/model.go @@ -0,0 +1,119 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
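The generated mocks above plug into gomock in the usual way. A brief sketch of stubbing the SDK interface in a test of some hypothetical caller; the package name, test name, and stubbed values are illustrative:

```go
package caller_test

import (
	"context"
	"testing"

	"github.com/pingcap/tidb/pkg/importsdk/mock"
	"github.com/stretchr/testify/require"
	"go.uber.org/mock/gomock"
)

func TestUsesImportSDK(t *testing.T) {
	ctrl := gomock.NewController(t)
	sdk := mock.NewMockSDK(ctrl)

	// Stub only the calls the code under test is expected to make.
	sdk.EXPECT().GetTotalSize(gomock.Any()).Return(int64(1 << 20))
	sdk.EXPECT().SubmitJob(gomock.Any(), gomock.Any()).Return(int64(42), nil)

	ctx := context.Background()
	require.Equal(t, int64(1<<20), sdk.GetTotalSize(ctx))

	jobID, err := sdk.SubmitJob(ctx, "IMPORT INTO `db`.`t` FROM 's3://bucket/db.t.*.csv'")
	require.NoError(t, err)
	require.Equal(t, int64(42), jobID)
}
```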
+ +package importsdk + +import ( + "time" + + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/lightning/mydump" +) + +// TableMeta contains metadata for a table to be imported +type TableMeta struct { + Database string + Table string + DataFiles []DataFileMeta + TotalSize int64 // In bytes + WildcardPath string // Wildcard pattern that matches only this table's data files + SchemaFile string // Path to the table schema file, if available +} + +// DataFileMeta contains metadata for a data file +type DataFileMeta struct { + Path string + Size int64 + Format mydump.SourceType + Compression mydump.Compression +} + +// ImportOptions wraps the options for IMPORT INTO statement. +// It reuses structures from executor/importer where possible. +type ImportOptions struct { + Format string + CSVConfig *config.CSVConfig + Thread int + DiskQuota string + MaxWriteSpeed string + SplitFile bool + RecordErrors int64 + Detached bool + CloudStorageURI string + GroupKey string + SkipRows int + CharacterSet string + ChecksumTable string + DisableTiKVImportMode bool + DisablePrecheck bool + ResourceParameters string +} + +// GroupStatus represents the aggregated status for a group of import jobs. +type GroupStatus struct { + GroupKey string + TotalJobs int64 + Pending int64 + Running int64 + Completed int64 + Failed int64 + Cancelled int64 + FirstJobCreateTime time.Time + LastJobUpdateTime time.Time +} + +// JobStatus represents the status of an import job. +type JobStatus struct { + JobID int64 + GroupKey string + DataSource string + TargetTable string + TableID int64 + Phase string + Status string + SourceFileSize string + ImportedRows int64 + ResultMessage string + CreateTime time.Time + StartTime time.Time + EndTime time.Time + CreatedBy string + UpdateTime time.Time + Step string + ProcessedSize string + TotalSize string + Percent string + Speed string + ETA string +} + +// IsFinished returns true if the job is finished successfully. +func (s *JobStatus) IsFinished() bool { + return s.Status == "finished" +} + +// IsFailed returns true if the job failed. +func (s *JobStatus) IsFailed() bool { + return s.Status == "failed" +} + +// IsCancelled returns true if the job was cancelled. +func (s *JobStatus) IsCancelled() bool { + return s.Status == "cancelled" +} + +// IsCompleted returns true if the job is in a terminal state. +func (s *JobStatus) IsCompleted() bool { + return s.IsFinished() || s.IsFailed() || s.IsCancelled() +} diff --git a/pkg/importsdk/model_test.go b/pkg/importsdk/model_test.go new file mode 100644 index 0000000000..9759060c89 --- /dev/null +++ b/pkg/importsdk/model_test.go @@ -0,0 +1,82 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package importsdk + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestJobStatus(t *testing.T) { + tests := []struct { + status string + isFinished bool + isFailed bool + isCancelled bool + isCompleted bool + }{ + { + status: "finished", + isFinished: true, + isFailed: false, + isCancelled: false, + isCompleted: true, + }, + { + status: "failed", + isFinished: false, + isFailed: true, + isCancelled: false, + isCompleted: true, + }, + { + status: "cancelled", + isFinished: false, + isFailed: false, + isCancelled: true, + isCompleted: true, + }, + { + status: "running", + isFinished: false, + isFailed: false, + isCancelled: false, + isCompleted: false, + }, + { + status: "pending", + isFinished: false, + isFailed: false, + isCancelled: false, + isCompleted: false, + }, + { + status: "unknown", + isFinished: false, + isFailed: false, + isCancelled: false, + isCompleted: false, + }, + } + + for _, tt := range tests { + job := &JobStatus{Status: tt.status} + require.Equal(t, tt.isFinished, job.IsFinished(), "status: %s", tt.status) + require.Equal(t, tt.isFailed, job.IsFailed(), "status: %s", tt.status) + require.Equal(t, tt.isCancelled, job.IsCancelled(), "status: %s", tt.status) + require.Equal(t, tt.isCompleted, job.IsCompleted(), "status: %s", tt.status) + } +} diff --git a/pkg/importsdk/pattern.go b/pkg/importsdk/pattern.go new file mode 100644 index 0000000000..5299be9dc7 --- /dev/null +++ b/pkg/importsdk/pattern.go @@ -0,0 +1,175 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
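+// pattern.go derives a wildcard path that matches all of a table's data files
+// and nothing else: it tries the Mydumper naming convention first, then falls
+// back to a generic longest-common-prefix/suffix pattern, validating each
+// candidate against the full set of scanned files.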
+ +package importsdk + +import ( + "path/filepath" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/lightning/mydump" +) + +// generateWildcardPath creates a wildcard pattern path that matches only this table's files +func generateWildcardPath( + files []mydump.FileInfo, + allFiles map[string]mydump.FileInfo, +) (string, error) { + tableFiles := make(map[string]struct{}, len(files)) + for _, df := range files { + tableFiles[df.FileMeta.Path] = struct{}{} + } + + if len(files) == 0 { + return "", errors.Annotate(ErrNoTableDataFiles, "cannot generate wildcard pattern because the table has no data files") + } + + // If there's only one file, we can just return its path + if len(files) == 1 { + return files[0].FileMeta.Path, nil + } + + // Try Mydumper-specific pattern first + p := generateMydumperPattern(files[0]) + if p != "" && isValidPattern(p, tableFiles, allFiles) { + return p, nil + } + + // Fallback to generic prefix/suffix pattern + paths := make([]string, 0, len(files)) + for _, file := range files { + paths = append(paths, file.FileMeta.Path) + } + p = generatePrefixSuffixPattern(paths) + if p != "" && isValidPattern(p, tableFiles, allFiles) { + return p, nil + } + return "", errors.Annotatef(ErrWildcardNotSpecific, "failed to find a wildcard that matches all and only the table's files.") +} + +// isValidPattern checks if a wildcard pattern matches only the table's files +func isValidPattern(pattern string, tableFiles map[string]struct{}, allFiles map[string]mydump.FileInfo) bool { + if pattern == "" { + return false + } + + for path := range allFiles { + isMatch, err := filepath.Match(pattern, path) + if err != nil { + return false // Invalid pattern + } + _, isTableFile := tableFiles[path] + + // If pattern matches a file that's not from our table, it's invalid + if isMatch && !isTableFile { + return false + } + + // If pattern doesn't match our table's file, it's also invalid + if !isMatch && isTableFile { + return false + } + } + + return true +} + +// generateMydumperPattern generates a wildcard pattern for Mydumper-formatted data files +// belonging to a specific table, based on their naming convention. +// It returns a pattern string that matches all data files for the table, or an empty string if not applicable. +func generateMydumperPattern(file mydump.FileInfo) string { + dbName, tableName := file.TableName.Schema, file.TableName.Name + if dbName == "" || tableName == "" { + return "" + } + + // compute dirPrefix and basename + full := file.FileMeta.Path + dirPrefix, name := "", full + if idx := strings.LastIndex(full, "/"); idx >= 0 { + dirPrefix = full[:idx+1] + name = full[idx+1:] + } + + // compression ext from filename when compression exists (last suffix like .gz/.zst) + compExt := "" + if file.FileMeta.Compression != mydump.CompressionNone { + compExt = filepath.Ext(name) + } + + // data ext after stripping compression ext + base := strings.TrimSuffix(name, compExt) + dataExt := filepath.Ext(base) + return dirPrefix + dbName + "." 
+ tableName + ".*" + dataExt + compExt +} + +// longestCommonPrefix finds the longest string that is a prefix of all strings in the slice +func longestCommonPrefix(strs []string) string { + if len(strs) == 0 { + return "" + } + + prefix := strs[0] + for _, s := range strs[1:] { + i := 0 + for i < len(prefix) && i < len(s) && prefix[i] == s[i] { + i++ + } + prefix = prefix[:i] + if prefix == "" { + break + } + } + + return prefix +} + +// longestCommonSuffix finds the longest string that is a suffix of all strings in the slice, starting after the given prefix length +func longestCommonSuffix(strs []string, prefixLen int) string { + if len(strs) == 0 { + return "" + } + + suffix := strs[0][prefixLen:] + for _, s := range strs[1:] { + remaining := s[prefixLen:] + i := 0 + for i < len(suffix) && i < len(remaining) && suffix[len(suffix)-i-1] == remaining[len(remaining)-i-1] { + i++ + } + suffix = suffix[len(suffix)-i:] + if suffix == "" { + break + } + } + + return suffix +} + +// generatePrefixSuffixPattern returns a wildcard pattern that matches all and only the given paths +// by finding the longest common prefix and suffix among them, and placing a '*' wildcard in between. +func generatePrefixSuffixPattern(paths []string) string { + if len(paths) == 0 { + return "" + } + if len(paths) == 1 { + return paths[0] + } + + prefix := longestCommonPrefix(paths) + suffix := longestCommonSuffix(paths, len(prefix)) + + return prefix + "*" + suffix +} diff --git a/pkg/importsdk/pattern_test.go b/pkg/importsdk/pattern_test.go new file mode 100644 index 0000000000..22cabeaa08 --- /dev/null +++ b/pkg/importsdk/pattern_test.go @@ -0,0 +1,208 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package importsdk + +import ( + "testing" + + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/lightning/mydump" + filter "github.com/pingcap/tidb/pkg/util/table-filter" + "github.com/stretchr/testify/require" +) + +func TestLongestCommonPrefix(t *testing.T) { + strs := []string{"s3://bucket/foo/bar/baz1", "s3://bucket/foo/bar/baz2", "s3://bucket/foo/bar/baz"} + p := longestCommonPrefix(strs) + require.Equal(t, "s3://bucket/foo/bar/baz", p) + + // no common prefix + require.Equal(t, "", longestCommonPrefix([]string{"a", "b"})) + + // empty inputs + require.Equal(t, "", longestCommonPrefix(nil)) + require.Equal(t, "", longestCommonPrefix([]string{})) +} + +func TestLongestCommonSuffix(t *testing.T) { + strs := []string{"abcXYZ", "defXYZ", "XYZ"} + s := longestCommonSuffix(strs, 0) + require.Equal(t, "XYZ", s) + + // no common suffix + require.Equal(t, "", longestCommonSuffix([]string{"a", "b"}, 0)) + + // empty inputs + require.Equal(t, "", longestCommonSuffix(nil, 0)) + require.Equal(t, "", longestCommonSuffix([]string{}, 0)) + + // same prefix + require.Equal(t, "", longestCommonSuffix([]string{"abc", "abc"}, 3)) + require.Equal(t, "f", longestCommonSuffix([]string{"abcdf", "abcef"}, 3)) +} + +func TestGeneratePrefixSuffixPattern(t *testing.T) { + paths := []string{"pre_middle_suf", "pre_most_suf"} + pattern := generatePrefixSuffixPattern(paths) + // common prefix "pre_m", suffix "_suf" + require.Equal(t, "pre_m*_suf", pattern) + + // empty inputs + require.Equal(t, "", generatePrefixSuffixPattern(nil)) + require.Equal(t, "", generatePrefixSuffixPattern([]string{})) + + // only one file + require.Equal(t, "pre_middle_suf", generatePrefixSuffixPattern([]string{"pre_middle_suf"})) + + // no common prefix/suffix + paths2 := []string{"foo", "bar"} + require.Equal(t, "*", generatePrefixSuffixPattern(paths2)) + + // overlapping prefix/suffix + paths3 := []string{"aaabaaa", "aaa"} + require.Equal(t, "aaa*", generatePrefixSuffixPattern(paths3)) +} + +func generateFileMetas(t *testing.T, paths []string) []mydump.FileInfo { + t.Helper() + + files := make([]mydump.FileInfo, 0, len(paths)) + fileRouter, err := mydump.NewDefaultFileRouter(log.L()) + require.NoError(t, err) + for _, p := range paths { + res, err := fileRouter.Route(p) + require.NoError(t, err) + files = append(files, mydump.FileInfo{ + TableName: res.Table, + FileMeta: mydump.SourceFileMeta{ + Path: p, + Type: res.Type, + Compression: res.Compression, + SortKey: res.Key, + }, + }) + } + return files +} + +func TestGenerateMydumperPattern(t *testing.T) { + paths := []string{"db.tb.0001.sql", "db.tb.0002.sql"} + p := generateMydumperPattern(generateFileMetas(t, paths)[0]) + require.Equal(t, "db.tb.*.sql", p) + + paths2 := []string{"s3://bucket/dir/db.tb.0001.sql", "s3://bucket/dir/db.tb.0002.sql"} + p2 := generateMydumperPattern(generateFileMetas(t, paths2)[0]) + require.Equal(t, "s3://bucket/dir/db.tb.*.sql", p2) + + // not mydumper pattern + require.Equal(t, "", generateMydumperPattern(mydump.FileInfo{ + TableName: filter.Table{}, + })) +} + +func TestValidatePattern(t *testing.T) { + tableFiles := map[string]struct{}{ + "a.txt": {}, "b.txt": {}, + } + // only table files in allFiles + smallAll := map[string]mydump.FileInfo{ + "a.txt": {}, "b.txt": {}, + } + require.True(t, isValidPattern("*.txt", tableFiles, smallAll)) + + // allFiles includes an extra file => invalid + fullAll := map[string]mydump.FileInfo{ + "a.txt": {}, "b.txt": {}, "c.txt": {}, + } + require.False(t, isValidPattern("*.txt", tableFiles, 
fullAll)) + + // If pattern doesn't match our table's file, it's also invalid + require.False(t, isValidPattern("*.csv", tableFiles, smallAll)) + + // empty pattern => invalid + require.False(t, isValidPattern("", tableFiles, smallAll)) +} + +func TestGenerateWildcardPath(t *testing.T) { + // Helper to create allFiles map + createAllFiles := func(paths []string) map[string]mydump.FileInfo { + allFiles := make(map[string]mydump.FileInfo) + for _, p := range paths { + allFiles[p] = mydump.FileInfo{ + FileMeta: mydump.SourceFileMeta{Path: p}, + } + } + return allFiles + } + + // No files + files1 := []mydump.FileInfo{} + allFiles1 := createAllFiles([]string{}) + _, err := generateWildcardPath(files1, allFiles1) + require.Error(t, err) + require.Contains(t, err.Error(), "no data files for table") + + // Single file + files2 := generateFileMetas(t, []string{"db.tb.0001.sql"}) + allFiles2 := createAllFiles([]string{"db.tb.0001.sql"}) + path2, err := generateWildcardPath(files2, allFiles2) + require.NoError(t, err) + require.Equal(t, "db.tb.0001.sql", path2) + + // Mydumper pattern succeeds + files3 := generateFileMetas(t, []string{"db.tb.0001.sql.gz", "db.tb.0002.sql.gz"}) + allFiles3 := createAllFiles([]string{"db.tb.0001.sql.gz", "db.tb.0002.sql.gz"}) + path3, err := generateWildcardPath(files3, allFiles3) + require.NoError(t, err) + require.Equal(t, "db.tb.*.sql.gz", path3) + + // Mydumper pattern fails, fallback to prefix/suffix succeeds + files4 := []mydump.FileInfo{ + { + TableName: filter.Table{Schema: "db", Name: "tb"}, + FileMeta: mydump.SourceFileMeta{ + Path: "a.sql", + Type: mydump.SourceTypeSQL, + Compression: mydump.CompressionNone, + }, + }, + { + TableName: filter.Table{Schema: "db", Name: "tb"}, + FileMeta: mydump.SourceFileMeta{ + Path: "b.sql", + Type: mydump.SourceTypeSQL, + Compression: mydump.CompressionNone, + }, + }, + } + allFiles4 := map[string]mydump.FileInfo{ + files4[0].FileMeta.Path: files4[0], + files4[1].FileMeta.Path: files4[1], + } + path4, err := generateWildcardPath(files4, allFiles4) + require.NoError(t, err) + require.Equal(t, "*.sql", path4) + + allFiles4["db-schema.sql"] = mydump.FileInfo{ + FileMeta: mydump.SourceFileMeta{ + Path: "db-schema.sql", + Type: mydump.SourceTypeSQL, + Compression: mydump.CompressionNone, + }, + } + _, err = generateWildcardPath(files4, allFiles4) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot generate a unique wildcard pattern") +} diff --git a/pkg/importsdk/sdk.go b/pkg/importsdk/sdk.go index b14a5e7247..b68a7ad584 100644 --- a/pkg/importsdk/sdk.go +++ b/pkg/importsdk/sdk.go @@ -17,61 +17,21 @@ package importsdk import ( "context" "database/sql" - "path/filepath" - "strings" - - "github.com/pingcap/errors" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/lightning/mydump" - "github.com/pingcap/tidb/pkg/parser/mysql" ) // SDK defines the interface for cloud import services type SDK interface { - // CreateSchemasAndTables creates all database schemas and tables from source path - CreateSchemasAndTables(ctx context.Context) error - - // GetTableMetas returns metadata for all tables in the source path - GetTableMetas(ctx context.Context) ([]*TableMeta, error) - - // GetTableMetaByName returns metadata for a specific table - GetTableMetaByName(ctx context.Context, schema, table string) (*TableMeta, error) - - // GetTotalSize returns the cumulative size (in bytes) of all data files under the 
source path - GetTotalSize(ctx context.Context) int64 - - // Close releases resources used by the SDK + FileScanner + JobManager + SQLGenerator Close() error } -// TableMeta contains metadata for a table to be imported -type TableMeta struct { - Database string - Table string - DataFiles []DataFileMeta - TotalSize int64 // In bytes - WildcardPath string // Wildcard pattern that matches only this table's data files - SchemaFile string // Path to the table schema file, if available -} - -// DataFileMeta contains metadata for a data file -type DataFileMeta struct { - Path string - Size int64 - Format mydump.SourceType - Compression mydump.Compression -} - -// ImportSDK implements SDK interface -type ImportSDK struct { - sourcePath string - db *sql.DB - store storage.ExternalStorage - loader *mydump.MDLoader - logger log.Logger - config *sdkConfig +// importSDK implements SDK interface +type importSDK struct { + FileScanner + JobManager + SQLGenerator } // NewImportSDK creates a new CloudImportSDK instance @@ -81,410 +41,22 @@ func NewImportSDK(ctx context.Context, sourcePath string, db *sql.DB, options .. opt(cfg) } - u, err := storage.ParseBackend(sourcePath, nil) + scanner, err := NewFileScanner(ctx, sourcePath, db, cfg) if err != nil { - return nil, errors.Annotatef(err, "failed to parse storage backend URL (source=%s). Please verify the URL format and credentials", sourcePath) - } - store, err := storage.New(ctx, u, &storage.ExternalStorageOptions{}) - if err != nil { - return nil, errors.Annotatef(err, "failed to create external storage (source=%s). Check network/connectivity and permissions", sourcePath) + return nil, err } - ldrCfg := mydump.LoaderConfig{ - SourceURL: sourcePath, - Filter: cfg.filter, - FileRouters: cfg.fileRouteRules, - DefaultFileRules: len(cfg.fileRouteRules) == 0, - CharacterSet: cfg.charset, - } + jobManager := NewJobManager(db) + sqlGenerator := NewSQLGenerator() - loader, err := mydump.NewLoaderWithStore(ctx, ldrCfg, store) - if err != nil { - return nil, errors.Annotatef(err, "failed to create MyDump loader (source=%s, charset=%s, filter=%v). Please check dump layout and router rules", sourcePath, cfg.charset, cfg.filter) - } - - return &ImportSDK{ - sourcePath: sourcePath, - db: db, - store: store, - loader: loader, - logger: cfg.logger, - config: cfg, + return &importSDK{ + FileScanner: scanner, + JobManager: jobManager, + SQLGenerator: sqlGenerator, }, nil } -// SDKOption customizes the SDK configuration -type SDKOption func(*sdkConfig) - -type sdkConfig struct { - // Loader options - concurrency int - sqlMode mysql.SQLMode - fileRouteRules []*config.FileRouteRule - filter []string - charset string - - // General options - logger log.Logger +// Close releases resources used by the SDK +func (sdk *importSDK) Close() error { + return sdk.FileScanner.Close() } - -func defaultSDKConfig() *sdkConfig { - return &sdkConfig{ - concurrency: 4, - filter: config.GetDefaultFilter(), - logger: log.L(), - charset: "auto", - } -} - -// WithConcurrency sets the number of concurrent DB/Table creation workers. 
-func WithConcurrency(n int) SDKOption { - return func(cfg *sdkConfig) { - if n > 0 { - cfg.concurrency = n - } - } -} - -// WithLogger specifies a custom logger -func WithLogger(logger log.Logger) SDKOption { - return func(cfg *sdkConfig) { - cfg.logger = logger - } -} - -// WithSQLMode specifies the SQL mode for schema parsing -func WithSQLMode(mode mysql.SQLMode) SDKOption { - return func(cfg *sdkConfig) { - cfg.sqlMode = mode - } -} - -// WithFilter specifies a filter for the loader -func WithFilter(filter []string) SDKOption { - return func(cfg *sdkConfig) { - cfg.filter = filter - } -} - -// WithFileRouters specifies custom file routing rules -func WithFileRouters(routers []*config.FileRouteRule) SDKOption { - return func(cfg *sdkConfig) { - cfg.fileRouteRules = routers - } -} - -// WithCharset specifies the character set for import (default "auto"). -func WithCharset(cs string) SDKOption { - return func(cfg *sdkConfig) { - if cs != "" { - cfg.charset = cs - } - } -} - -// CreateSchemasAndTables implements the CloudImportSDK interface -func (sdk *ImportSDK) CreateSchemasAndTables(ctx context.Context) error { - dbMetas := sdk.loader.GetDatabases() - if len(dbMetas) == 0 { - return errors.Annotatef(ErrNoDatabasesFound, "source=%s. Ensure the path contains valid dump files (*.sql, *.csv, *.parquet, etc.) and filter rules are correct", sdk.sourcePath) - } - - // Create all schemas and tables - importer := mydump.NewSchemaImporter( - sdk.logger, - sdk.config.sqlMode, - sdk.db, - sdk.store, - sdk.config.concurrency, - ) - - err := importer.Run(ctx, dbMetas) - if err != nil { - return errors.Annotatef(err, "creating schemas and tables failed (source=%s, db_count=%d, concurrency=%d)", sdk.sourcePath, len(dbMetas), sdk.config.concurrency) - } - - return nil -} - -// GetTableMetas implements the CloudImportSDK interface -func (sdk *ImportSDK) GetTableMetas(context.Context) ([]*TableMeta, error) { - dbMetas := sdk.loader.GetDatabases() - allFiles := sdk.loader.GetAllFiles() - var results []*TableMeta - for _, dbMeta := range dbMetas { - for _, tblMeta := range dbMeta.Tables { - tableMeta, err := sdk.buildTableMeta(dbMeta, tblMeta, allFiles) - if err != nil { - return nil, errors.Wrapf(err, "failed to build metadata for table %s.%s", - dbMeta.Name, tblMeta.Name) - } - results = append(results, tableMeta) - } - } - - return results, nil -} - -// GetTableMetaByName implements CloudImportSDK interface -func (sdk *ImportSDK) GetTableMetaByName(_ context.Context, schema, table string) (*TableMeta, error) { - dbMetas := sdk.loader.GetDatabases() - allFiles := sdk.loader.GetAllFiles() - // Find the specific table - for _, dbMeta := range dbMetas { - if dbMeta.Name != schema { - continue - } - - for _, tblMeta := range dbMeta.Tables { - if tblMeta.Name != table { - continue - } - - return sdk.buildTableMeta(dbMeta, tblMeta, allFiles) - } - - return nil, errors.Annotatef(ErrTableNotFound, "schema=%s, table=%s", schema, table) - } - - return nil, errors.Annotatef(ErrSchemaNotFound, "schema=%s", schema) -} - -// GetTotalSize implements CloudImportSDK interface -func (sdk *ImportSDK) GetTotalSize(ctx context.Context) int64 { - var total int64 - dbMetas := sdk.loader.GetDatabases() - for _, dbMeta := range dbMetas { - for _, tblMeta := range dbMeta.Tables { - total += tblMeta.TotalSize - } - } - return total -} - -// buildTableMeta creates a TableMeta from database and table metadata -func (sdk *ImportSDK) buildTableMeta( - dbMeta *mydump.MDDatabaseMeta, - tblMeta *mydump.MDTableMeta, - allDataFiles 
map[string]mydump.FileInfo, -) (*TableMeta, error) { - tableMeta := &TableMeta{ - Database: dbMeta.Name, - Table: tblMeta.Name, - DataFiles: make([]DataFileMeta, 0, len(tblMeta.DataFiles)), - SchemaFile: tblMeta.SchemaFile.FileMeta.Path, - } - - // Process data files - dataFiles, totalSize := processDataFiles(tblMeta.DataFiles) - tableMeta.DataFiles = dataFiles - tableMeta.TotalSize = totalSize - - wildcard, err := generateWildcardPath(tblMeta.DataFiles, allDataFiles) - if err != nil { - return nil, errors.Annotatef(err, "failed to build wildcard for table=%s.%s", dbMeta.Name, tblMeta.Name) - } - tableMeta.WildcardPath = strings.TrimSuffix(sdk.store.URI(), "/") + "/" + wildcard - - return tableMeta, nil -} - -// Close implements CloudImportSDK interface -func (sdk *ImportSDK) Close() error { - // close external storage - if sdk.store != nil { - sdk.store.Close() - } - return nil -} - -// processDataFiles converts mydump data files to DataFileMeta and calculates total size -func processDataFiles(files []mydump.FileInfo) ([]DataFileMeta, int64) { - dataFiles := make([]DataFileMeta, 0, len(files)) - var totalSize int64 - - for _, dataFile := range files { - fileMeta := createDataFileMeta(dataFile) - dataFiles = append(dataFiles, fileMeta) - totalSize += dataFile.FileMeta.RealSize - } - - return dataFiles, totalSize -} - -// createDataFileMeta creates a DataFileMeta from a mydump.DataFile -func createDataFileMeta(file mydump.FileInfo) DataFileMeta { - return DataFileMeta{ - Path: file.FileMeta.Path, - Size: file.FileMeta.RealSize, - Format: file.FileMeta.Type, - Compression: file.FileMeta.Compression, - } -} - -// generateWildcardPath creates a wildcard pattern path that matches only this table's files -func generateWildcardPath( - files []mydump.FileInfo, - allFiles map[string]mydump.FileInfo, -) (string, error) { - tableFiles := make(map[string]struct{}, len(files)) - for _, df := range files { - tableFiles[df.FileMeta.Path] = struct{}{} - } - - if len(files) == 0 { - return "", errors.Annotate(ErrNoTableDataFiles, "cannot generate wildcard pattern because the table has no data files") - } - - // If there's only one file, we can just return its path - if len(files) == 1 { - return files[0].FileMeta.Path, nil - } - - // Try Mydumper-specific pattern first - p := generateMydumperPattern(files[0]) - if p != "" && isValidPattern(p, tableFiles, allFiles) { - return p, nil - } - - // Fallback to generic prefix/suffix pattern - paths := make([]string, 0, len(files)) - for _, file := range files { - paths = append(paths, file.FileMeta.Path) - } - p = generatePrefixSuffixPattern(paths) - if p != "" && isValidPattern(p, tableFiles, allFiles) { - return p, nil - } - return "", errors.Annotatef(ErrWildcardNotSpecific, "failed to find a wildcard that matches all and only the table's files.") -} - -// isValidPattern checks if a wildcard pattern matches only the table's files -func isValidPattern(pattern string, tableFiles map[string]struct{}, allFiles map[string]mydump.FileInfo) bool { - if pattern == "" { - return false - } - - for path := range allFiles { - isMatch, err := filepath.Match(pattern, path) - if err != nil { - return false // Invalid pattern - } - _, isTableFile := tableFiles[path] - - // If pattern matches a file that's not from our table, it's invalid - if isMatch && !isTableFile { - return false - } - - // If pattern doesn't match our table's file, it's also invalid - if !isMatch && isTableFile { - return false - } - } - - return true -} - -// generateMydumperPattern generates a wildcard 
pattern for Mydumper-formatted data files -// belonging to a specific table, based on their naming convention. -// It returns a pattern string that matches all data files for the table, or an empty string if not applicable. -func generateMydumperPattern(file mydump.FileInfo) string { - dbName, tableName := file.TableName.Schema, file.TableName.Name - if dbName == "" || tableName == "" { - return "" - } - - // compute dirPrefix and basename - full := file.FileMeta.Path - dirPrefix, name := "", full - if idx := strings.LastIndex(full, "/"); idx >= 0 { - dirPrefix = full[:idx+1] - name = full[idx+1:] - } - - // compression ext from filename when compression exists (last suffix like .gz/.zst) - compExt := "" - if file.FileMeta.Compression != mydump.CompressionNone { - compExt = filepath.Ext(name) - } - - // data ext after stripping compression ext - base := strings.TrimSuffix(name, compExt) - dataExt := filepath.Ext(base) - return dirPrefix + dbName + "." + tableName + ".*" + dataExt + compExt -} - -// longestCommonPrefix finds the longest string that is a prefix of all strings in the slice -func longestCommonPrefix(strs []string) string { - if len(strs) == 0 { - return "" - } - - prefix := strs[0] - for _, s := range strs[1:] { - i := 0 - for i < len(prefix) && i < len(s) && prefix[i] == s[i] { - i++ - } - prefix = prefix[:i] - if prefix == "" { - break - } - } - - return prefix -} - -// longestCommonSuffix finds the longest string that is a suffix of all strings in the slice, starting after the given prefix length -func longestCommonSuffix(strs []string, prefixLen int) string { - if len(strs) == 0 { - return "" - } - - suffix := strs[0][prefixLen:] - for _, s := range strs[1:] { - remaining := s[prefixLen:] - i := 0 - for i < len(suffix) && i < len(remaining) && suffix[len(suffix)-i-1] == remaining[len(remaining)-i-1] { - i++ - } - suffix = suffix[len(suffix)-i:] - if suffix == "" { - break - } - } - - return suffix -} - -// generatePrefixSuffixPattern returns a wildcard pattern that matches all and only the given paths -// by finding the longest common prefix and suffix among them, and placing a '*' wildcard in between. -func generatePrefixSuffixPattern(paths []string) string { - if len(paths) == 0 { - return "" - } - if len(paths) == 1 { - return paths[0] - } - - prefix := longestCommonPrefix(paths) - suffix := longestCommonSuffix(paths, len(prefix)) - - return prefix + "*" + suffix -} - -// TODO: add error code and doc for cloud sdk -// Sentinel errors to categorize common failure scenarios for clearer user messages. -var ( - // ErrNoDatabasesFound indicates that the dump source contains no recognizable databases. - ErrNoDatabasesFound = errors.New("no databases found in the source path") - // ErrSchemaNotFound indicates the target schema doesn't exist in the dump source. - ErrSchemaNotFound = errors.New("schema not found") - // ErrTableNotFound indicates the target table doesn't exist in the dump source. - ErrTableNotFound = errors.New("table not found") - // ErrNoTableDataFiles indicates a table has zero data files and thus cannot proceed. - ErrNoTableDataFiles = errors.New("no data files for table") - // ErrWildcardNotSpecific indicates a wildcard cannot uniquely match the table's files. 
- ErrWildcardNotSpecific = errors.New("cannot generate a unique wildcard pattern for the table's data files") -) diff --git a/pkg/importsdk/sdk_test.go b/pkg/importsdk/sdk_test.go index aee9ef1642..b7bfef106a 100644 --- a/pkg/importsdk/sdk_test.go +++ b/pkg/importsdk/sdk_test.go @@ -27,8 +27,6 @@ import ( "github.com/pingcap/tidb/pkg/lightning/log" "github.com/pingcap/tidb/pkg/lightning/mydump" "github.com/pingcap/tidb/pkg/parser/mysql" - filter "github.com/pingcap/tidb/pkg/util/table-filter" - "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -355,219 +353,82 @@ func (s *mockGCSSuite) TestOnlyDataFiles() { s.NoError(err) } -func TestLongestCommonPrefix(t *testing.T) { - strs := []string{"s3://bucket/foo/bar/baz1", "s3://bucket/foo/bar/baz2", "s3://bucket/foo/bar/baz"} - p := longestCommonPrefix(strs) - require.Equal(t, "s3://bucket/foo/bar/baz", p) - - // no common prefix - require.Equal(t, "", longestCommonPrefix([]string{"a", "b"})) - - // empty inputs - require.Equal(t, "", longestCommonPrefix(nil)) - require.Equal(t, "", longestCommonPrefix([]string{})) +func (s *mockGCSSuite) TestScanLimitation() { + s.server.CreateBucketWithOpts(fakestorage.CreateBucketOpts{Name: "limitation"}) + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "limitation", Name: "db2.tb2.001.csv"}, + Content: []byte("5,e\n6,f\n"), + }) + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "limitation", Name: "db2.tb2.002.csv"}, + Content: []byte("7,g\n8,h\n"), + }) + db, _, err := sqlmock.New() + s.NoError(err) + defer db.Close() + importSDK, err := NewImportSDK( + context.Background(), + fmt.Sprintf("gs://limitation?endpoint=%s&access-key=aaaaaa&secret-access-key=bbbbbb", gcsEndpoint), + db, + WithCharset("utf8"), + WithConcurrency(8), + WithFilter([]string{"*.*"}), + WithSQLMode(mysql.ModeANSIQuotes), + WithLogger(log.L()), + WithSkipInvalidFiles(true), + WithMaxScanFiles(1), + ) + s.NoError(err) + defer importSDK.Close() + metas, err := importSDK.GetTableMetas(context.Background()) + s.NoError(err) + s.Len(metas, 1) + s.Len(metas[0].DataFiles, 1) } -func TestLongestCommonSuffix(t *testing.T) { - strs := []string{"abcXYZ", "defXYZ", "XYZ"} - s := longestCommonSuffix(strs, 0) - require.Equal(t, "XYZ", s) - - // no common suffix - require.Equal(t, "", longestCommonSuffix([]string{"a", "b"}, 0)) - - // empty inputs - require.Equal(t, "", longestCommonSuffix(nil, 0)) - require.Equal(t, "", longestCommonSuffix([]string{}, 0)) - - // same prefix - require.Equal(t, "", longestCommonSuffix([]string{"abc", "abc"}, 3)) - require.Equal(t, "f", longestCommonSuffix([]string{"abcdf", "abcef"}, 3)) -} - -func TestGeneratePrefixSuffixPattern(t *testing.T) { - paths := []string{"pre_middle_suf", "pre_most_suf"} - pattern := generatePrefixSuffixPattern(paths) - // common prefix "pre_m", suffix "_suf" - require.Equal(t, "pre_m*_suf", pattern) - - // empty inputs - require.Equal(t, "", generatePrefixSuffixPattern(nil)) - require.Equal(t, "", generatePrefixSuffixPattern([]string{})) - - // only one file - require.Equal(t, "pre_middle_suf", generatePrefixSuffixPattern([]string{"pre_middle_suf"})) - - // no common prefix/suffix - paths2 := []string{"foo", "bar"} - require.Equal(t, "*", generatePrefixSuffixPattern(paths2)) - - // overlapping prefix/suffix - paths3 := []string{"aaabaaa", "aaa"} - require.Equal(t, "aaa*", generatePrefixSuffixPattern(paths3)) -} - -func generateFileMetas(t *testing.T, paths []string) 
[]mydump.FileInfo { - t.Helper() - - files := make([]mydump.FileInfo, 0, len(paths)) - fileRouter, err := mydump.NewDefaultFileRouter(log.L()) - require.NoError(t, err) - for _, p := range paths { - res, err := fileRouter.Route(p) - require.NoError(t, err) - files = append(files, mydump.FileInfo{ - TableName: res.Table, - FileMeta: mydump.SourceFileMeta{ - Path: p, - Type: res.Type, - Compression: res.Compression, - SortKey: res.Key, - }, +func (s *mockGCSSuite) TestCreateTableMetaByName() { + for i := 0; i < 2; i++ { + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "specific-table-test", Name: fmt.Sprintf("db%d-schema-create.sql", i)}, + Content: []byte(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS db%d;\n", i)), }) - } - return files -} - -func TestGenerateMydumperPattern(t *testing.T) { - paths := []string{"db.tb.0001.sql", "db.tb.0002.sql"} - p := generateMydumperPattern(generateFileMetas(t, paths)[0]) - require.Equal(t, "db.tb.*.sql", p) - - paths2 := []string{"s3://bucket/dir/db.tb.0001.sql", "s3://bucket/dir/db.tb.0002.sql"} - p2 := generateMydumperPattern(generateFileMetas(t, paths2)[0]) - require.Equal(t, "s3://bucket/dir/db.tb.*.sql", p2) - - // not mydumper pattern - require.Equal(t, "", generateMydumperPattern(mydump.FileInfo{ - TableName: filter.Table{}, - })) -} - -func TestCreateDataFileMeta(t *testing.T) { - fi := mydump.FileInfo{ - TableName: filter.Table{ - Schema: "db", - Name: "table", - }, - FileMeta: mydump.SourceFileMeta{ - Path: "s3://bucket/path/to/f", - FileSize: 123, - Type: mydump.SourceTypeCSV, - Compression: mydump.CompressionGZ, - RealSize: 456, - }, - } - df := createDataFileMeta(fi) - require.Equal(t, "s3://bucket/path/to/f", df.Path) - require.Equal(t, int64(456), df.Size) - require.Equal(t, mydump.SourceTypeCSV, df.Format) - require.Equal(t, mydump.CompressionGZ, df.Compression) -} - -func TestProcessDataFiles(t *testing.T) { - files := []mydump.FileInfo{ - {FileMeta: mydump.SourceFileMeta{Path: "s3://bucket/a", RealSize: 10}}, - {FileMeta: mydump.SourceFileMeta{Path: "s3://bucket/b", RealSize: 20}}, - } - dfm, total := processDataFiles(files) - require.Len(t, dfm, 2) - require.Equal(t, int64(30), total) - require.Equal(t, "s3://bucket/a", dfm[0].Path) - require.Equal(t, "s3://bucket/b", dfm[1].Path) -} - -func TestValidatePattern(t *testing.T) { - tableFiles := map[string]struct{}{ - "a.txt": {}, "b.txt": {}, - } - // only table files in allFiles - smallAll := map[string]mydump.FileInfo{ - "a.txt": {}, "b.txt": {}, - } - require.True(t, isValidPattern("*.txt", tableFiles, smallAll)) - - // allFiles includes an extra file => invalid - fullAll := map[string]mydump.FileInfo{ - "a.txt": {}, "b.txt": {}, "c.txt": {}, - } - require.False(t, isValidPattern("*.txt", tableFiles, fullAll)) - - // If pattern doesn't match our table's file, it's also invalid - require.False(t, isValidPattern("*.csv", tableFiles, smallAll)) - - // empty pattern => invalid - require.False(t, isValidPattern("", tableFiles, smallAll)) -} - -func TestGenerateWildcardPath(t *testing.T) { - // Helper to create allFiles map - createAllFiles := func(paths []string) map[string]mydump.FileInfo { - allFiles := make(map[string]mydump.FileInfo) - for _, p := range paths { - allFiles[p] = mydump.FileInfo{ - FileMeta: mydump.SourceFileMeta{Path: p}, - } + for j := 0; j != 2; j++ { + tableName := fmt.Sprintf("db%d.tb%d", i, j) + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "specific-table-test", Name: 
fmt.Sprintf("%s-schema.sql", tableName)}, + Content: []byte(fmt.Sprintf("CREATE TABLE IF NOT EXISTS db%d.tb%d (a INT, b VARCHAR(10));\n", i, j)), + }) + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "specific-table-test", Name: fmt.Sprintf("%s.001.sql", tableName)}, + Content: []byte(fmt.Sprintf("INSERT INTO db%d.tb%d VALUES (1,'a'),(2,'b');\n", i, j)), + }) + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "specific-table-test", Name: fmt.Sprintf("%s.002.sql", tableName)}, + Content: []byte(fmt.Sprintf("INSERT INTO db%d.tb%d VALUES (3,'c'),(4,'d');\n", i, j)), + }) } - return allFiles } - // No files - files1 := []mydump.FileInfo{} - allFiles1 := createAllFiles([]string{}) - _, err := generateWildcardPath(files1, allFiles1) - require.Error(t, err) - require.Contains(t, err.Error(), "no data files for table") + db, mock, err := sqlmock.New() + s.NoError(err) + defer db.Close() - // Single file - files2 := generateFileMetas(t, []string{"db.tb.0001.sql"}) - allFiles2 := createAllFiles([]string{"db.tb.0001.sql"}) - path2, err := generateWildcardPath(files2, allFiles2) - require.NoError(t, err) - require.Equal(t, "db.tb.0001.sql", path2) + mock.ExpectQuery(`SELECT SCHEMA_NAME FROM information_schema.SCHEMATA`). + WillReturnRows(sqlmock.NewRows([]string{"SCHEMA_NAME"})) + mock.ExpectExec("CREATE DATABASE IF NOT EXISTS `db1`;"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec(regexp.QuoteMeta("CREATE TABLE IF NOT EXISTS `db1`.`tb1` (`a` INT,`b` VARCHAR(10));")). + WillReturnResult(sqlmock.NewResult(0, 1)) - // Mydumper pattern succeeds - files3 := generateFileMetas(t, []string{"db.tb.0001.sql.gz", "db.tb.0002.sql.gz"}) - allFiles3 := createAllFiles([]string{"db.tb.0001.sql.gz", "db.tb.0002.sql.gz"}) - path3, err := generateWildcardPath(files3, allFiles3) - require.NoError(t, err) - require.Equal(t, "db.tb.*.sql.gz", path3) + importSDK, err := NewImportSDK( + context.Background(), + fmt.Sprintf("gs://specific-table-test?endpoint=%s&access-key=aaaaaa&secret-access-key=bbbbbb", gcsEndpoint), + db, + WithConcurrency(1), + ) + s.NoError(err) + defer importSDK.Close() - // Mydumper pattern fails, fallback to prefix/suffix succeeds - files4 := []mydump.FileInfo{ - { - TableName: filter.Table{Schema: "db", Name: "tb"}, - FileMeta: mydump.SourceFileMeta{ - Path: "a.sql", - Type: mydump.SourceTypeSQL, - Compression: mydump.CompressionNone, - }, - }, - { - TableName: filter.Table{Schema: "db", Name: "tb"}, - FileMeta: mydump.SourceFileMeta{ - Path: "b.sql", - Type: mydump.SourceTypeSQL, - Compression: mydump.CompressionNone, - }, - }, - } - allFiles4 := map[string]mydump.FileInfo{ - files4[0].FileMeta.Path: files4[0], - files4[1].FileMeta.Path: files4[1], - } - path4, err := generateWildcardPath(files4, allFiles4) - require.NoError(t, err) - require.Equal(t, "*.sql", path4) - - allFiles4["db-schema.sql"] = mydump.FileInfo{ - FileMeta: mydump.SourceFileMeta{ - Path: "db-schema.sql", - Type: mydump.SourceTypeSQL, - Compression: mydump.CompressionNone, - }, - } - _, err = generateWildcardPath(files4, allFiles4) - require.Error(t, err) - require.Contains(t, err.Error(), "cannot generate a unique wildcard pattern") + err = importSDK.CreateSchemaAndTableByName(context.Background(), "db1", "tb1") + s.NoError(err) } diff --git a/pkg/importsdk/sql_generator.go b/pkg/importsdk/sql_generator.go new file mode 100644 index 0000000000..1b27223ac1 --- /dev/null +++ b/pkg/importsdk/sql_generator.go @@ -0,0 +1,161 @@ 
+// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importsdk + +import ( + "fmt" + "net/url" + "strings" + + "github.com/pingcap/tidb/pkg/lightning/config" +) + +// SQLGenerator defines the interface for generating IMPORT INTO SQL +type SQLGenerator interface { + GenerateImportSQL(tableMeta *TableMeta, options *ImportOptions) (string, error) +} + +type sqlGenerator struct{} + +// NewSQLGenerator creates a new SQLGenerator +func NewSQLGenerator() SQLGenerator { + return &sqlGenerator{} +} + +// GenerateImportSQL generates the IMPORT INTO SQL statement +func (g *sqlGenerator) GenerateImportSQL(tableMeta *TableMeta, options *ImportOptions) (string, error) { + var sb strings.Builder + sb.WriteString("IMPORT INTO ") + sb.WriteString(escapeIdentifier(tableMeta.Database)) + sb.WriteString(".") + sb.WriteString(escapeIdentifier(tableMeta.Table)) + + path := tableMeta.WildcardPath + if options.ResourceParameters != "" { + u, err := url.Parse(path) + if err == nil { + if u.RawQuery != "" { + u.RawQuery += "&" + options.ResourceParameters + } else { + u.RawQuery = options.ResourceParameters + } + path = u.String() + } + } + + sb.WriteString(" FROM '") + sb.WriteString(path) + sb.WriteString("'") + + if options.Format != "" { + sb.WriteString(" FORMAT '") + sb.WriteString(options.Format) + sb.WriteString("'") + } + + opts, err := g.buildOptions(options) + if err != nil { + return "", err + } + if len(opts) > 0 { + sb.WriteString(" WITH ") + sb.WriteString(strings.Join(opts, ", ")) + } + + return sb.String(), nil +} + +func (g *sqlGenerator) buildOptions(options *ImportOptions) ([]string, error) { + var opts []string + if options.Thread > 0 { + opts = append(opts, fmt.Sprintf("THREAD=%d", options.Thread)) + } + if options.DiskQuota != "" { + opts = append(opts, fmt.Sprintf("DISK_QUOTA='%s'", options.DiskQuota)) + } + if options.MaxWriteSpeed != "" { + opts = append(opts, fmt.Sprintf("MAX_WRITE_SPEED='%s'", options.MaxWriteSpeed)) + } + if options.SplitFile { + opts = append(opts, "SPLIT_FILE") + } + if options.RecordErrors > 0 { + opts = append(opts, fmt.Sprintf("RECORD_ERRORS=%d", options.RecordErrors)) + } + if options.Detached { + opts = append(opts, "DETACHED") + } + if options.CloudStorageURI != "" { + opts = append(opts, fmt.Sprintf("CLOUD_STORAGE_URI='%s'", options.CloudStorageURI)) + } + if options.GroupKey != "" { + opts = append(opts, fmt.Sprintf("GROUP_KEY='%s'", escapeString(options.GroupKey))) + } + if options.SkipRows > 0 { + opts = append(opts, fmt.Sprintf("SKIP_ROWS=%d", options.SkipRows)) + } + if options.CharacterSet != "" { + opts = append(opts, fmt.Sprintf("CHARACTER_SET='%s'", escapeString(options.CharacterSet))) + } + if options.ChecksumTable != "" { + opts = append(opts, fmt.Sprintf("CHECKSUM_TABLE='%s'", escapeString(options.ChecksumTable))) + } + if options.DisableTiKVImportMode { + opts = append(opts, "DISABLE_TIKV_IMPORT_MODE") + } + if options.DisablePrecheck { + opts = append(opts, "DISABLE_PRECHECK") + } + + if 
options.CSVConfig != nil && options.Format == "csv" { + csvOpts, err := g.buildCSVOptions(options.CSVConfig) + if err != nil { + return nil, err + } + opts = append(opts, csvOpts...) + } + return opts, nil +} + +func (g *sqlGenerator) buildCSVOptions(csvConfig *config.CSVConfig) ([]string, error) { + var opts []string + if csvConfig.FieldsTerminatedBy != "" { + opts = append(opts, fmt.Sprintf("FIELDS_TERMINATED_BY='%s'", escapeString(csvConfig.FieldsTerminatedBy))) + } + if csvConfig.FieldsEnclosedBy != "" { + opts = append(opts, fmt.Sprintf("FIELDS_ENCLOSED_BY='%s'", escapeString(csvConfig.FieldsEnclosedBy))) + } + if csvConfig.FieldsEscapedBy != "" { + opts = append(opts, fmt.Sprintf("FIELDS_ESCAPED_BY='%s'", escapeString(csvConfig.FieldsEscapedBy))) + } + if csvConfig.LinesTerminatedBy != "" { + opts = append(opts, fmt.Sprintf("LINES_TERMINATED_BY='%s'", escapeString(csvConfig.LinesTerminatedBy))) + } + if len(csvConfig.FieldNullDefinedBy) > 0 { + if len(csvConfig.FieldNullDefinedBy) > 1 { + return nil, ErrMultipleFieldsDefinedNullBy + } + opts = append(opts, fmt.Sprintf("FIELDS_DEFINED_NULL_BY='%s'", escapeString(csvConfig.FieldNullDefinedBy[0]))) + } + return opts, nil +} + +func escapeIdentifier(s string) string { + return "`" + strings.ReplaceAll(s, "`", "``") + "`" +} + +func escapeString(s string) string { + return strings.ReplaceAll(strings.ReplaceAll(s, "\\", "\\\\"), "'", "''") +} diff --git a/pkg/importsdk/sql_generator_test.go b/pkg/importsdk/sql_generator_test.go new file mode 100644 index 0000000000..5c4a193b55 --- /dev/null +++ b/pkg/importsdk/sql_generator_test.go @@ -0,0 +1,173 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package importsdk + +import ( + "testing" + + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/stretchr/testify/require" +) + +func TestGenerateImportSQL(t *testing.T) { + gen := NewSQLGenerator() + + defaultTableMeta := &TableMeta{ + Database: "test_db", + Table: "test_table", + WildcardPath: "s3://bucket/path/*.csv", + } + + tests := []struct { + name string + tableMeta *TableMeta + options *ImportOptions + expected string + expectedError string + }{ + { + name: "Basic", + options: &ImportOptions{ + Format: "csv", + }, + expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv' FORMAT 'csv'", + }, + { + name: "With S3 Credentials", + options: &ImportOptions{ + Format: "csv", + ResourceParameters: "access-key=ak&endpoint=http%3A%2F%2Fminio%3A9000&secret-access-key=sk", + }, + expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv?access-key=ak&endpoint=http%3A%2F%2Fminio%3A9000&secret-access-key=sk' FORMAT 'csv'", + }, + { + name: "With Options", + options: &ImportOptions{ + Format: "csv", + Thread: 4, + Detached: true, + MaxWriteSpeed: "100MiB", + }, + expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv' FORMAT 'csv' WITH THREAD=4, MAX_WRITE_SPEED='100MiB', DETACHED", + }, + { + name: "All Options", + options: &ImportOptions{ + Format: "csv", + Thread: 8, + DiskQuota: "100GiB", + MaxWriteSpeed: "200MiB", + SplitFile: true, + RecordErrors: 100, + Detached: true, + CloudStorageURI: "s3://bucket/storage", + GroupKey: "group1", + SkipRows: 1, + CharacterSet: "utf8mb4", + ChecksumTable: "test_db.checksum_table", + DisableTiKVImportMode: true, + DisablePrecheck: true, + }, + expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv' FORMAT 'csv' WITH THREAD=8, DISK_QUOTA='100GiB', MAX_WRITE_SPEED='200MiB', SPLIT_FILE, RECORD_ERRORS=100, DETACHED, CLOUD_STORAGE_URI='s3://bucket/storage', GROUP_KEY='group1', SKIP_ROWS=1, CHARACTER_SET='utf8mb4', CHECKSUM_TABLE='test_db.checksum_table', DISABLE_TIKV_IMPORT_MODE, DISABLE_PRECHECK", + }, + { + name: "With CSV Config", + options: &ImportOptions{ + Format: "csv", + CSVConfig: &config.CSVConfig{ + FieldsTerminatedBy: ",", + FieldsEnclosedBy: "\"", + FieldsEscapedBy: "\\", + LinesTerminatedBy: "\n", + FieldNullDefinedBy: []string{"NULL"}, + }, + }, + expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv' FORMAT 'csv' WITH FIELDS_TERMINATED_BY=',', FIELDS_ENCLOSED_BY='\"', FIELDS_ESCAPED_BY='\\\\', LINES_TERMINATED_BY='\n', FIELDS_DEFINED_NULL_BY='NULL'", + }, + { + name: "With Cloud Storage URI", + options: &ImportOptions{ + Format: "parquet", + CloudStorageURI: "s3://bucket/storage", + }, + expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv' FORMAT 'parquet' WITH CLOUD_STORAGE_URI='s3://bucket/storage'", + }, + { + name: "Multiple Null Defined By", + options: &ImportOptions{ + Format: "csv", + CSVConfig: &config.CSVConfig{ + FieldNullDefinedBy: []string{"NULL", "\\N"}, + }, + }, + expectedError: "IMPORT INTO only supports one FIELDS_DEFINED_NULL_BY value", + }, + { + name: "Resource Parameters Append", + tableMeta: &TableMeta{ + Database: "test_db", + Table: "test_table", + WildcardPath: "s3://bucket/path/*.csv?foo=bar", + }, + options: &ImportOptions{ + Format: "csv", + ResourceParameters: "access-key=ak", + }, + expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv?foo=bar&access-key=ak' FORMAT 'csv'", + }, + { + name: "Escaping", + options: &ImportOptions{ + Format: "csv", + GroupKey: 
"group'1", + CharacterSet: "utf8'mb4", + CSVConfig: &config.CSVConfig{ + FieldsTerminatedBy: "'", + FieldsEnclosedBy: "\\", + }, + }, + expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv' FORMAT 'csv' WITH GROUP_KEY='group''1', CHARACTER_SET='utf8''mb4', FIELDS_TERMINATED_BY='''', FIELDS_ENCLOSED_BY='\\\\'", + }, + { + name: "Identifier Escaping", + tableMeta: &TableMeta{ + Database: "test`db", + Table: "test`table", + WildcardPath: "s3://bucket/path/*.csv", + }, + options: &ImportOptions{ + Format: "csv", + }, + expected: "IMPORT INTO `test``db`.`test``table` FROM 's3://bucket/path/*.csv' FORMAT 'csv'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tm := tt.tableMeta + if tm == nil { + tm = defaultTableMeta + } + sql, err := gen.GenerateImportSQL(tm, tt.options) + if tt.expectedError != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.expectedError) + } else { + require.NoError(t, err) + require.Equal(t, tt.expected, sql) + } + }) + } +} diff --git a/pkg/testkit/BUILD.bazel b/pkg/testkit/BUILD.bazel index a0fd0f8f1f..5748b1123b 100644 --- a/pkg/testkit/BUILD.bazel +++ b/pkg/testkit/BUILD.bazel @@ -4,6 +4,7 @@ go_library( name = "testkit", srcs = [ "asynctestkit.go", + "db_driver.go", "dbtestkit.go", "mocksessionmanager.go", "mockstore.go", @@ -71,8 +72,12 @@ go_library( go_test( name = "testkit_test", timeout = "short", - srcs = ["testkit_test.go"], + srcs = [ + "db_driver_test.go", + "testkit_test.go", + ], embed = [":testkit"], flaky = True, + shard_count = 2, deps = ["@com_github_stretchr_testify//require"], ) diff --git a/pkg/testkit/db_driver.go b/pkg/testkit/db_driver.go new file mode 100644 index 0000000000..a8ac755e19 --- /dev/null +++ b/pkg/testkit/db_driver.go @@ -0,0 +1,308 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testkit + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "io" + "strconv" + "sync" + "sync/atomic" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/session/sessionapi" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/sqlexec" +) + +var ( + tkMapMu sync.RWMutex + tkMap = make(map[string]*TestKit) + tkIDSeq int64 +) + +func init() { + sql.Register("testkit", &testKitDriver{}) +} + +// CreateMockDB creates a *sql.DB that uses the TestKit's store to create sessions. 
+func CreateMockDB(tk *TestKit) *sql.DB { + id := strconv.FormatInt(atomic.AddInt64(&tkIDSeq, 1), 10) + tkMapMu.Lock() + tkMap[id] = tk + tkMapMu.Unlock() + + db, err := sql.Open("testkit", id) + if err != nil { + panic(err) + } + return db +} + +type testKitDriver struct{} + +func (d *testKitDriver) Open(name string) (driver.Conn, error) { + tkMapMu.RLock() + tk, ok := tkMap[name] + tkMapMu.RUnlock() + if !ok { + return nil, fmt.Errorf("testkit not found for %s", name) + } + se := NewSession(tk.t, tk.store) + return &testKitConn{se: se}, nil +} + +type testKitConn struct { + se sessionapi.Session +} + +func (c *testKitConn) Prepare(query string) (driver.Stmt, error) { + return &testKitStmt{c: c, query: query}, nil +} + +func (c *testKitConn) Close() error { + c.se.Close() + return nil +} + +func (c *testKitConn) Begin() (driver.Tx, error) { + _, err := c.Exec("BEGIN", nil) + if err != nil { + return nil, err + } + return &testKitTxn{c: c}, nil +} + +func (c *testKitConn) Exec(query string, args []driver.Value) (driver.Result, error) { + return nil, driver.ErrSkip +} + +func (c *testKitConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + qArgs := make([]any, len(args)) + for i, a := range args { + qArgs[i] = a.Value + } + + rs, err := c.execute(ctx, query, qArgs) + if err != nil { + return nil, err + } + if rs != nil { + if err := rs.Close(); err != nil { + return nil, err + } + } + return &testKitResult{ + lastInsertID: int64(c.se.LastInsertID()), + rowsAffected: int64(c.se.AffectedRows()), + }, nil +} + +func (c *testKitConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + qArgs := make([]any, len(args)) + for i, a := range args { + qArgs[i] = a.Value + } + + rs, err := c.execute(ctx, query, qArgs) + if err != nil { + return nil, err + } + if rs == nil { + return nil, nil + } + return &testKitRows{rs: rs}, nil +} + +func (c *testKitConn) execute(ctx context.Context, sql string, args []any) (sqlexec.RecordSet, error) { + // Set the command value to ComQuery, so that the process info can be updated correctly + c.se.SetCommandValue(mysql.ComQuery) + defer c.se.SetCommandValue(mysql.ComSleep) + + if len(args) == 0 { + rss, err := c.se.Execute(ctx, sql) + if err != nil { + return nil, err + } + if len(rss) == 0 { + return nil, nil + } + return rss[0], nil + } + + stmtID, _, _, err := c.se.PrepareStmt(sql) + if err != nil { + return nil, errors.Trace(err) + } + params := expression.Args2Expressions4Test(args...) 
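+	// Execute the prepared statement, then drop it right away so the session
+	// does not accumulate prepared-statement handles across driver calls.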
+ rs, err := c.se.ExecutePreparedStmt(ctx, stmtID, params) + if err != nil { + return rs, errors.Trace(err) + } + err = c.se.DropPreparedStmt(stmtID) + if err != nil { + return rs, errors.Trace(err) + } + return rs, nil +} + +type testKitStmt struct { + c *testKitConn + query string +} + +func (s *testKitStmt) Close() error { + return nil +} + +func (s *testKitStmt) NumInput() int { + return -1 +} + +func (s *testKitStmt) Exec(args []driver.Value) (driver.Result, error) { + return nil, driver.ErrSkip +} + +func (s *testKitStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + return s.c.ExecContext(ctx, s.query, args) +} + +func (s *testKitStmt) Query(args []driver.Value) (driver.Rows, error) { + return nil, driver.ErrSkip +} + +func (s *testKitStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + return s.c.QueryContext(ctx, s.query, args) +} + +type testKitResult struct { + lastInsertID int64 + rowsAffected int64 +} + +func (r *testKitResult) LastInsertId() (int64, error) { + return r.lastInsertID, nil +} + +func (r *testKitResult) RowsAffected() (int64, error) { + return r.rowsAffected, nil +} + +type testKitRows struct { + rs sqlexec.RecordSet + chunk *chunk.Chunk + it *chunk.Iterator4Chunk +} + +func (r *testKitRows) Columns() []string { + fields := r.rs.Fields() + cols := make([]string, len(fields)) + for i, f := range fields { + cols[i] = f.Column.Name.O + } + return cols +} + +func (r *testKitRows) Close() error { + return r.rs.Close() +} + +func (r *testKitRows) Next(dest []driver.Value) error { + if r.chunk == nil { + r.chunk = r.rs.NewChunk(nil) + } + + var row chunk.Row + if r.it == nil { + err := r.rs.Next(context.Background(), r.chunk) + if err != nil { + return err + } + if r.chunk.NumRows() == 0 { + return io.EOF + } + r.it = chunk.NewIterator4Chunk(r.chunk) + row = r.it.Begin() + } else { + row = r.it.Next() + if row.IsEmpty() { + err := r.rs.Next(context.Background(), r.chunk) + if err != nil { + return err + } + if r.chunk.NumRows() == 0 { + return io.EOF + } + r.it = chunk.NewIterator4Chunk(r.chunk) + row = r.it.Begin() + } + } + + for i := range row.Len() { + d := row.GetDatum(i, &r.rs.Fields()[i].Column.FieldType) + // Handle NULL + if d.IsNull() { + dest[i] = nil + } else { + // Convert to appropriate type if needed, or just return string/bytes/int/float + // driver.Value allows int64, float64, bool, []byte, string, time.Time + // Datum.GetValue() returns interface{} which might be compatible. + v := d.GetValue() + switch x := v.(type) { + case []byte: + dest[i] = x + case string: + dest[i] = x + case int64: + dest[i] = x + case uint64: + dest[i] = x + case float64: + dest[i] = x + case float32: + dest[i] = x + case types.Time: + dest[i] = x.String() + case types.Duration: + dest[i] = x.String() + case *types.MyDecimal: + dest[i] = x.String() + default: + dest[i] = x + } + } + } + return nil +} + +type testKitTxn struct { + c *testKitConn +} + +func (t *testKitTxn) Commit() error { + _, err := t.c.Exec("COMMIT", nil) + return err +} + +func (t *testKitTxn) Rollback() error { + _, err := t.c.Exec("ROLLBACK", nil) + return err +} diff --git a/pkg/testkit/db_driver_test.go b/pkg/testkit/db_driver_test.go new file mode 100644 index 0000000000..96fa00a4c7 --- /dev/null +++ b/pkg/testkit/db_driver_test.go @@ -0,0 +1,79 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testkit_test + +import ( + "testing" + + "github.com/pingcap/tidb/pkg/testkit" + "github.com/stretchr/testify/require" +) + +func TestMockDB(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + + db := testkit.CreateMockDB(tk) + defer db.Close() + + var err error + _, err = db.Exec("use test") + require.NoError(t, err) + _, err = db.Exec("create table t (id int, v varchar(255))") + require.NoError(t, err) + _, err = db.Exec("insert into t values (1, 'a'), (2, 'b')") + require.NoError(t, err) + + // Test Query + rows, err := db.Query("select * from t order by id") + require.NoError(t, err) + defer rows.Close() + + var id int + var v string + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&id, &v)) + require.Equal(t, 1, id) + require.Equal(t, "a", v) + + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&id, &v)) + require.Equal(t, 2, id) + require.Equal(t, "b", v) + + require.False(t, rows.Next()) + require.NoError(t, rows.Err()) + + // Test Exec + res, err := db.Exec("insert into t values (3, 'c')") + require.NoError(t, err) + affected, err := res.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), affected) + + // Verify with TestKit + tk.MustExec("use test") + tk.MustQuery("select * from t where id = 3").Check(testkit.Rows("3 c")) + + // Test Prepare + stmt, err := db.Prepare("select v from t where id = ?") + require.NoError(t, err) + defer stmt.Close() + + row := stmt.QueryRow(2) + require.NoError(t, row.Scan(&v)) + require.Equal(t, "b", v) +} diff --git a/tests/realtikvtest/importintotest/BUILD.bazel b/tests/realtikvtest/importintotest/BUILD.bazel index a9a6dc41c2..a0e5c58102 100644 --- a/tests/realtikvtest/importintotest/BUILD.bazel +++ b/tests/realtikvtest/importintotest/BUILD.bazel @@ -11,6 +11,7 @@ go_test( "one_parquet_test.go", "parquet_test.go", "precheck_test.go", + "sdk_test.go", "util_test.go", ], embedsrcs = [ @@ -35,6 +36,7 @@ go_test( "//pkg/domain", "//pkg/domain/infosync", "//pkg/executor/importer", + "//pkg/importsdk", "//pkg/infoschema", "//pkg/kv", "//pkg/lightning/backend/local", diff --git a/tests/realtikvtest/importintotest/sdk_test.go b/tests/realtikvtest/importintotest/sdk_test.go new file mode 100644 index 0000000000..021b3d2dec --- /dev/null +++ b/tests/realtikvtest/importintotest/sdk_test.go @@ -0,0 +1,180 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
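+
+// TestImportSDK below exercises the import SDK end to end against a real TiKV
+// cluster: it scans a local mydumper-style dump, creates the target schema and
+// table from the *-schema*.sql files, generates IMPORT INTO statements, and
+// walks through job submission, status polling, cancellation via a failpoint,
+// and job-group summaries.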
+ +package importintotest + +import ( + "context" + "os" + "path/filepath" + "time" + + "github.com/pingcap/tidb/pkg/importsdk" + "github.com/pingcap/tidb/pkg/testkit" + "github.com/pingcap/tidb/pkg/testkit/testfailpoint" +) + +func (s *mockGCSSuite) TestImportSDK() { + ctx := context.Background() + + // 1. Prepare source data (local files) + tmpDir := s.T().TempDir() + + // Create schema files for CreateSchemasAndTables + // DB schema + err := os.WriteFile(filepath.Join(tmpDir, "importsdk_test-schema-create.sql"), []byte("CREATE DATABASE importsdk_test;"), 0644) + s.NoError(err) + // Table schema + err = os.WriteFile(filepath.Join(tmpDir, "importsdk_test.t-schema.sql"), []byte("CREATE TABLE t (id int, v varchar(255));"), 0644) + s.NoError(err) + + // Data file: mydumper format: {schema}.{table}.{seq}.csv + fileName := "importsdk_test.t.001.csv" + content := []byte("1,test1\n2,test2") + err = os.WriteFile(filepath.Join(tmpDir, fileName), content, 0644) + s.NoError(err) + + // 2. Prepare DB connection + // Ensure clean state + s.tk.MustExec("DROP DATABASE IF EXISTS importsdk_test") + + // Create a *sql.DB using the testkit driver + db := testkit.CreateMockDB(s.tk) + defer db.Close() + + // 3. Initialize SDK + // Use local file path. + sdk, err := importsdk.NewImportSDK(ctx, "file://"+tmpDir, db) + s.NoError(err) + defer sdk.Close() + + // 4. Test FileScanner.CreateSchemasAndTables + err = sdk.CreateSchemasAndTables(ctx) + s.NoError(err) + // Verify DB and Table exist + s.tk.MustExec("USE importsdk_test") + s.tk.MustExec("SHOW CREATE TABLE t") + + // 4.1 Test FileScanner.CreateSchemaAndTableByName + s.tk.MustExec("DROP TABLE t") + err = sdk.CreateSchemaAndTableByName(ctx, "importsdk_test", "t") + s.NoError(err) + s.tk.MustExec("SHOW CREATE TABLE t") + + // 5. Test FileScanner.GetTableMetas + metas, err := sdk.GetTableMetas(ctx) + s.NoError(err) + s.Len(metas, 1) + s.Equal("importsdk_test", metas[0].Database) + s.Equal("t", metas[0].Table) + s.Equal(int64(len(content)), metas[0].TotalSize) + + // 6. Test FileScanner.GetTableMetaByName + meta, err := sdk.GetTableMetaByName(ctx, "importsdk_test", "t") + s.NoError(err) + s.Equal("importsdk_test", meta.Database) + s.Equal("t", meta.Table) + + // 7. Test FileScanner.GetTotalSize + totalSize := sdk.GetTotalSize(ctx) + s.Equal(int64(len(content)), totalSize) + + // 8. Test SQLGenerator.GenerateImportSQL + opts := &importsdk.ImportOptions{ + Thread: 4, + Detached: true, + } + sql, err := sdk.GenerateImportSQL(meta, opts) + s.NoError(err) + s.Contains(sql, "IMPORT INTO `importsdk_test`.`t` FROM") + s.Contains(sql, "THREAD=4") + s.Contains(sql, "DETACHED") + + // 9. Test JobManager.SubmitJob + jobID, err := sdk.SubmitJob(ctx, sql) + s.NoError(err) + s.Greater(jobID, int64(0)) + + // 10. Test JobManager.GetJobStatus + status, err := sdk.GetJobStatus(ctx, jobID) + s.NoError(err) + s.Equal(jobID, status.JobID) + s.Equal("`importsdk_test`.`t`", status.TargetTable) + + // Wait for job to finish + s.Eventually(func() bool { + status, err := sdk.GetJobStatus(ctx, jobID) + s.NoError(err) + return status.Status == "finished" + }, 30*time.Second, 500*time.Millisecond) + + // Verify data + s.tk.MustQuery("SELECT * FROM importsdk_test.t").Check(testkit.Rows("1 test1", "2 test2")) + + // 11. 
Test JobManager.CancelJob with failpoint + s.tk.MustExec("TRUNCATE TABLE t") + testfailpoint.Enable(s.T(), "github.com/pingcap/tidb/pkg/disttask/importinto/syncBeforeJobStarted", "pause") + + jobID2, err := sdk.SubmitJob(ctx, sql) + s.NoError(err) + + // Cancel the job + err = sdk.CancelJob(ctx, jobID2) + s.NoError(err) + + // Unpause + testfailpoint.Disable(s.T(), "github.com/pingcap/tidb/pkg/disttask/importinto/syncBeforeJobStarted") + + // Verify status is cancelled + s.Eventually(func() bool { + status, err := sdk.GetJobStatus(ctx, jobID2) + s.NoError(err) + return status.Status == "cancelled" + }, 30*time.Second, 500*time.Millisecond) + + // 12. Test JobManager.GetGroupSummary + // Submit a job with GroupKey + groupKey := "test_group_key" + optsWithGroup := &importsdk.ImportOptions{ + Thread: 4, + Detached: true, + GroupKey: groupKey, + } + sqlWithGroup, err := sdk.GenerateImportSQL(meta, optsWithGroup) + s.NoError(err) + jobID3, err := sdk.SubmitJob(ctx, sqlWithGroup) + s.NoError(err) + s.Greater(jobID3, int64(0)) + + // Wait for job to finish + s.Eventually(func() bool { + status, err := sdk.GetJobStatus(ctx, jobID3) + s.NoError(err) + return status.Status == "finished" + }, 30*time.Second, 500*time.Millisecond) + + // Get group summary + groupSummary, err := sdk.GetGroupSummary(ctx, groupKey) + s.NoError(err) + s.Equal(groupKey, groupSummary.GroupKey) + s.Equal(int64(1), groupSummary.TotalJobs) + s.Equal(int64(1), groupSummary.Completed) + + // 13. Test JobManager.GetJobsByGroup + jobs, err := sdk.GetJobsByGroup(ctx, groupKey) + s.NoError(err) + s.Len(jobs, 1) + s.Equal(jobID3, jobs[0].JobID) + s.Equal("finished", jobs[0].Status) +}
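
For readers who want to drive the SDK outside the realtikvtest harness, the sketch below strings the same calls together from plain application code. It is a minimal sketch, not part of the patch: only the identifiers exercised in the test above are taken from this change; the go-sql-driver import, the DSN, and the dump path are placeholders, and error handling is reduced to log.Fatal.

// Usage sketch (assumptions: a reachable TiDB behind the placeholder DSN below
// and a mydumper-style dump under /data/dump).
package main

import (
	"context"
	"database/sql"
	"log"
	"time"

	_ "github.com/go-sql-driver/mysql" // placeholder driver choice; any *sql.DB works
	"github.com/pingcap/tidb/pkg/importsdk"
)

func main() {
	ctx := context.Background()

	db, err := sql.Open("mysql", "root@tcp(127.0.0.1:4000)/") // placeholder DSN
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Point the SDK at the dump directory; schema and data files are discovered
	// by their mydumper-style names, as in the test above.
	sdk, err := importsdk.NewImportSDK(ctx, "file:///data/dump", db)
	if err != nil {
		log.Fatal(err)
	}
	defer sdk.Close()

	// Create target databases and tables from the *-schema-create.sql and
	// *-schema.sql files.
	if err := sdk.CreateSchemasAndTables(ctx); err != nil {
		log.Fatal(err)
	}

	// Submit one detached IMPORT INTO job per discovered table and wait for it.
	metas, err := sdk.GetTableMetas(ctx)
	if err != nil {
		log.Fatal(err)
	}
	for _, m := range metas {
		meta, err := sdk.GetTableMetaByName(ctx, m.Database, m.Table)
		if err != nil {
			log.Fatal(err)
		}
		stmt, err := sdk.GenerateImportSQL(meta, &importsdk.ImportOptions{Thread: 4, Detached: true})
		if err != nil {
			log.Fatal(err)
		}
		jobID, err := sdk.SubmitJob(ctx, stmt)
		if err != nil {
			log.Fatal(err)
		}
		// Poll until the job reports "finished" (status string as seen in the test above).
		for {
			status, err := sdk.GetJobStatus(ctx, jobID)
			if err != nil {
				log.Fatal(err)
			}
			if status.Status == "finished" {
				break
			}
			time.Sleep(time.Second)
		}
	}
}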