importinto: refactor and add interface for import into sdk (#65094)

ref pingcap/tidb#65092
Author: GMHDBJD
Date: 2025-12-23 15:10:08 +08:00 (committed by GitHub)
Parent: 9aa228717e
Commit: 0f706ffd90
24 changed files with 3297 additions and 659 deletions

Makefile

@ -546,6 +546,7 @@ mock_import: mockgen
tools/bin/mockgen -package mock github.com/pingcap/tidb/pkg/disttask/framework/planner LogicalPlan,PipelineSpec > pkg/disttask/framework/mock/plan_mock.go
tools/bin/mockgen -package mock github.com/pingcap/tidb/pkg/disttask/framework/storage Manager > pkg/disttask/framework/mock/storage_manager_mock.go
tools/bin/mockgen -package mock github.com/pingcap/tidb/pkg/ingestor/ingestcli Client,WriteClient > pkg/ingestor/ingestcli/mock/client_mock.go
tools/bin/mockgen -package mock github.com/pingcap/tidb/pkg/importsdk FileScanner,JobManager,SQLGenerator,SDK > pkg/importsdk/mock/sdk_mock.go
.PHONY: gen_mock
gen_mock: mockgen

pkg/importsdk/BUILD.bazel

@ -2,26 +2,45 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "importsdk",
srcs = ["sdk.go"],
srcs = [
"config.go",
"error.go",
"file_scanner.go",
"job_manager.go",
"model.go",
"pattern.go",
"sdk.go",
"sql_generator.go",
],
importpath = "github.com/pingcap/tidb/pkg/importsdk",
visibility = ["//visibility:public"],
deps = [
"//br/pkg/storage",
"//pkg/lightning/common",
"//pkg/lightning/config",
"//pkg/lightning/log",
"//pkg/lightning/mydump",
"//pkg/parser/mysql",
"@com_github_pingcap_errors//:errors",
"@org_uber_go_zap//:zap",
],
)
go_test(
name = "importsdk_test",
timeout = "short",
srcs = ["sdk_test.go"],
srcs = [
"config_test.go",
"file_scanner_test.go",
"job_manager_test.go",
"model_test.go",
"pattern_test.go",
"sdk_test.go",
"sql_generator_test.go",
],
embed = [":importsdk"],
flaky = True,
shard_count = 9,
shard_count = 20,
deps = [
"//pkg/lightning/config",
"//pkg/lightning/log",

pkg/importsdk/config.go (new file)

@ -0,0 +1,118 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importsdk
import (
"github.com/pingcap/tidb/pkg/lightning/config"
"github.com/pingcap/tidb/pkg/lightning/log"
"github.com/pingcap/tidb/pkg/parser/mysql"
)
// SDKOption customizes the SDK configuration
type SDKOption func(*SDKConfig)
// SDKConfig is the configuration for the SDK
type SDKConfig struct {
// Loader options
concurrency int
sqlMode mysql.SQLMode
fileRouteRules []*config.FileRouteRule
routes config.Routes
filter []string
charset string
maxScanFiles *int
skipInvalidFiles bool
// General options
logger log.Logger
}
func defaultSDKConfig() *SDKConfig {
return &SDKConfig{
concurrency: 4,
filter: config.GetDefaultFilter(),
logger: log.L(),
charset: "auto",
}
}
// WithConcurrency sets the number of concurrent DB/Table creation workers.
func WithConcurrency(n int) SDKOption {
return func(cfg *SDKConfig) {
if n > 0 {
cfg.concurrency = n
}
}
}
// WithLogger specifies a custom logger
func WithLogger(logger log.Logger) SDKOption {
return func(cfg *SDKConfig) {
cfg.logger = logger
}
}
// WithSQLMode specifies the SQL mode for schema parsing
func WithSQLMode(mode mysql.SQLMode) SDKOption {
return func(cfg *SDKConfig) {
cfg.sqlMode = mode
}
}
// WithFilter specifies a filter for the loader
func WithFilter(filter []string) SDKOption {
return func(cfg *SDKConfig) {
cfg.filter = filter
}
}
// WithFileRouters sets the file routing rules.
func WithFileRouters(rules []*config.FileRouteRule) SDKOption {
return func(c *SDKConfig) {
c.fileRouteRules = rules
}
}
// WithRoutes sets the table routing rules.
func WithRoutes(routes config.Routes) SDKOption {
return func(c *SDKConfig) {
c.routes = routes
}
}
// WithCharset specifies the character set for import (default "auto").
func WithCharset(cs string) SDKOption {
return func(cfg *SDKConfig) {
if cs != "" {
cfg.charset = cs
}
}
}
// WithMaxScanFiles sets a custom limit on the number of files to scan
func WithMaxScanFiles(limit int) SDKOption {
return func(cfg *SDKConfig) {
if limit > 0 {
cfg.maxScanFiles = &limit
}
}
}
// WithSkipInvalidFiles specifies whether the SDK should skip invalid files instead of returning an error when they are found
func WithSkipInvalidFiles(skip bool) SDKOption {
return func(cfg *SDKConfig) {
cfg.skipInvalidFiles = skip
}
}
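
A minimal sketch of how these functional options compose (illustrative only, not part of the patch; it mirrors the in-package usage in config_test.go below, since defaultSDKConfig is unexported):

package importsdk

// exampleConfig is a hypothetical helper showing how SDKOption values fold
// into an SDKConfig; the concrete option values are illustrative.
func exampleConfig() *SDKConfig {
	cfg := defaultSDKConfig()
	opts := []SDKOption{
		WithConcurrency(8),            // workers for schema/table creation and file scanning
		WithCharset("utf8mb4"),        // override the default "auto" charset
		WithFilter([]string{"db1.*"}), // only consider objects matching db1.*
		WithSkipInvalidFiles(true),    // skip tables whose files cannot be matched, instead of failing
	}
	for _, opt := range opts {
		opt(cfg)
	}
	return cfg
}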

pkg/importsdk/config_test.go (new file)

@ -0,0 +1,89 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importsdk
import (
"testing"
"github.com/pingcap/tidb/pkg/lightning/config"
"github.com/pingcap/tidb/pkg/lightning/log"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/stretchr/testify/require"
)
func TestDefaultSDKConfig(t *testing.T) {
cfg := defaultSDKConfig()
require.Equal(t, 4, cfg.concurrency)
require.Equal(t, config.GetDefaultFilter(), cfg.filter)
require.Equal(t, log.L(), cfg.logger)
require.Equal(t, "auto", cfg.charset)
}
func TestSDKOptions(t *testing.T) {
cfg := defaultSDKConfig()
// Test WithConcurrency
WithConcurrency(10)(cfg)
require.Equal(t, 10, cfg.concurrency)
WithConcurrency(-1)(cfg) // Should ignore invalid value
require.Equal(t, 10, cfg.concurrency)
// Test WithLogger
logger := log.L()
WithLogger(logger)(cfg)
require.Equal(t, logger, cfg.logger)
// Test WithSQLMode
mode := mysql.ModeStrictTransTables
WithSQLMode(mode)(cfg)
require.Equal(t, mode, cfg.sqlMode)
// Test WithFilter
filter := []string{"*.*"}
WithFilter(filter)(cfg)
require.Equal(t, filter, cfg.filter)
// Test WithFileRouters
routers := []*config.FileRouteRule{{Schema: "test"}}
WithFileRouters(routers)(cfg)
require.Equal(t, routers, cfg.fileRouteRules)
// Test WithCharset
WithCharset("utf8mb4")(cfg)
require.Equal(t, "utf8mb4", cfg.charset)
WithCharset("")(cfg) // Should ignore empty value
require.Equal(t, "utf8mb4", cfg.charset)
// Test WithMaxScanFiles
WithMaxScanFiles(100)(cfg)
require.NotNil(t, cfg.maxScanFiles)
require.Equal(t, 100, *cfg.maxScanFiles)
// Test WithSkipInvalidFiles
WithSkipInvalidFiles(true)(cfg)
require.True(t, cfg.skipInvalidFiles)
// Test WithRoutes
routes := config.Routes{
{
SchemaPattern: "source_db",
TablePattern: "source_table",
TargetSchema: "target_db",
TargetTable: "target_table",
},
}
WithRoutes(routes)(cfg)
require.Equal(t, routes, cfg.routes)
}

pkg/importsdk/error.go (new file)

@ -0,0 +1,46 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importsdk
import "github.com/pingcap/errors"
var (
// ErrNoDatabasesFound indicates that the dump source contains no recognizable databases.
ErrNoDatabasesFound = errors.New("no databases found in the source path")
// ErrSchemaNotFound indicates the target schema doesn't exist in the dump source.
ErrSchemaNotFound = errors.New("schema not found")
// ErrTableNotFound indicates the target table doesn't exist in the dump source.
ErrTableNotFound = errors.New("table not found")
// ErrNoTableDataFiles indicates a table has zero data files and thus cannot proceed.
ErrNoTableDataFiles = errors.New("no data files for table")
// ErrWildcardNotSpecific indicates a wildcard cannot uniquely match the table's files.
ErrWildcardNotSpecific = errors.New("cannot generate a unique wildcard pattern for the table's data files")
// ErrJobNotFound indicates the job is not found.
ErrJobNotFound = errors.New("job not found")
// ErrNoJobIDReturned indicates that the submit job query did not return a job ID.
ErrNoJobIDReturned = errors.New("no job id returned")
// ErrInvalidOptions indicates the options are invalid.
ErrInvalidOptions = errors.New("invalid options")
// ErrMultipleFieldsDefinedNullBy indicates that multiple FIELDS_DEFINED_NULL_BY values are defined, which is not supported.
ErrMultipleFieldsDefinedNullBy = errors.New("IMPORT INTO only supports one FIELDS_DEFINED_NULL_BY value")
// ErrParseStorageURL indicates that the storage backend URL is invalid.
ErrParseStorageURL = errors.New("failed to parse storage backend URL")
// ErrCreateExternalStorage indicates that the external storage cannot be created.
ErrCreateExternalStorage = errors.New("failed to create external storage")
// ErrCreateLoader indicates that the MyDump loader cannot be created.
ErrCreateLoader = errors.New("failed to create MyDump loader")
// ErrCreateSchema indicates that creating schemas and tables failed.
ErrCreateSchema = errors.New("failed to create schemas and tables")
)
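
Callers are expected to distinguish these sentinels beneath the errors.Annotatef context added by the scanner and job manager; a small sketch (hypothetical in-package helper, not part of the patch):

package importsdk

import "github.com/pingcap/errors"

// isMissingInDump reports whether err means the dump simply does not contain
// the requested object. errors.Cause unwraps the errors.Annotatef wrapping
// used throughout this package. Hypothetical helper, for illustration only.
func isMissingInDump(err error) bool {
	switch errors.Cause(err) {
	case ErrSchemaNotFound, ErrTableNotFound, ErrNoTableDataFiles:
		return true
	default:
		return false
	}
}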

pkg/importsdk/file_scanner.go (new file)

@ -0,0 +1,271 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importsdk
import (
"context"
"database/sql"
"strings"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/br/pkg/storage"
"github.com/pingcap/tidb/pkg/lightning/common"
"github.com/pingcap/tidb/pkg/lightning/log"
"github.com/pingcap/tidb/pkg/lightning/mydump"
"go.uber.org/zap"
)
// FileScanner defines the interface for scanning files
type FileScanner interface {
CreateSchemasAndTables(ctx context.Context) error
CreateSchemaAndTableByName(ctx context.Context, schema, table string) error
GetTableMetas(ctx context.Context) ([]*TableMeta, error)
GetTableMetaByName(ctx context.Context, db, table string) (*TableMeta, error)
GetTotalSize(ctx context.Context) int64
Close() error
}
type fileScanner struct {
sourcePath string
db *sql.DB
store storage.ExternalStorage
loader *mydump.MDLoader
logger log.Logger
config *SDKConfig
}
// NewFileScanner creates a new FileScanner
func NewFileScanner(ctx context.Context, sourcePath string, db *sql.DB, cfg *SDKConfig) (FileScanner, error) {
u, err := storage.ParseBackend(sourcePath, nil)
if err != nil {
return nil, errors.Annotatef(ErrParseStorageURL, "source=%s, err=%v", sourcePath, err)
}
store, err := storage.New(ctx, u, &storage.ExternalStorageOptions{})
if err != nil {
return nil, errors.Annotatef(ErrCreateExternalStorage, "source=%s, err=%v", sourcePath, err)
}
ldrCfg := mydump.LoaderConfig{
SourceURL: sourcePath,
Filter: cfg.filter,
FileRouters: cfg.fileRouteRules,
DefaultFileRules: len(cfg.fileRouteRules) == 0,
CharacterSet: cfg.charset,
Routes: cfg.routes,
}
var loaderOptions []mydump.MDLoaderSetupOption
if cfg.maxScanFiles != nil && *cfg.maxScanFiles > 0 {
loaderOptions = append(loaderOptions, mydump.WithMaxScanFiles(*cfg.maxScanFiles))
}
if cfg.concurrency > 0 {
loaderOptions = append(loaderOptions, mydump.WithScanFileConcurrency(cfg.concurrency))
}
// TODO: we can skip some time-consuming operations in constructFileInfo (like getting the real size of a compressed file).
loader, err := mydump.NewLoaderWithStore(ctx, ldrCfg, store, loaderOptions...)
if err != nil {
if loader == nil || !errors.ErrorEqual(err, common.ErrTooManySourceFiles) {
return nil, errors.Annotatef(ErrCreateLoader, "source=%s, charset=%s, err=%v", sourcePath, cfg.charset, err)
}
}
return &fileScanner{
sourcePath: sourcePath,
db: db,
store: store,
loader: loader,
logger: cfg.logger,
config: cfg,
}, nil
}
func (s *fileScanner) CreateSchemasAndTables(ctx context.Context) error {
dbMetas := s.loader.GetDatabases()
if len(dbMetas) == 0 {
return errors.Annotatef(ErrNoDatabasesFound, "source=%s", s.sourcePath)
}
// Create all schemas and tables
importer := mydump.NewSchemaImporter(
s.logger,
s.config.sqlMode,
s.db,
s.store,
s.config.concurrency,
)
err := importer.Run(ctx, dbMetas)
if err != nil {
return errors.Annotatef(ErrCreateSchema, "source=%s, db_count=%d, err=%v", s.sourcePath, len(dbMetas), err)
}
return nil
}
// CreateSchemaAndTableByName creates the specified database and table schema from the source
func (s *fileScanner) CreateSchemaAndTableByName(ctx context.Context, schema, table string) error {
dbMetas := s.loader.GetDatabases()
// Find the specific table
for _, dbMeta := range dbMetas {
if dbMeta.Name != schema {
continue
}
for _, tblMeta := range dbMeta.Tables {
if tblMeta.Name != table {
continue
}
importer := mydump.NewSchemaImporter(
s.logger,
s.config.sqlMode,
s.db,
s.store,
s.config.concurrency,
)
err := importer.Run(ctx, []*mydump.MDDatabaseMeta{{
Name: dbMeta.Name,
SchemaFile: dbMeta.SchemaFile,
Tables: []*mydump.MDTableMeta{tblMeta},
}})
if err != nil {
return errors.Annotatef(ErrCreateSchema, "source=%s, schema=%s, table=%s, err=%v", s.sourcePath, schema, table, err)
}
return nil
}
return errors.Annotatef(ErrTableNotFound, "schema=%s, table=%s", schema, table)
}
return errors.Annotatef(ErrSchemaNotFound, "schema=%s", schema)
}
func (s *fileScanner) GetTableMetas(context.Context) ([]*TableMeta, error) {
dbMetas := s.loader.GetDatabases()
allFiles := s.loader.GetAllFiles()
var results []*TableMeta
for _, dbMeta := range dbMetas {
for _, tblMeta := range dbMeta.Tables {
tableMeta, err := s.buildTableMeta(dbMeta, tblMeta, allFiles)
if err != nil {
if s.config.skipInvalidFiles {
s.logger.Warn("skipping table due to invalid files", zap.String("database", dbMeta.Name), zap.String("table", tblMeta.Name), zap.Error(err))
continue
}
return nil, err
}
results = append(results, tableMeta)
}
}
return results, nil
}
func (s *fileScanner) GetTotalSize(ctx context.Context) int64 {
var total int64
dbMetas := s.loader.GetDatabases()
for _, dbMeta := range dbMetas {
for _, tblMeta := range dbMeta.Tables {
total += tblMeta.TotalSize
}
}
return total
}
func (s *fileScanner) Close() error {
if s.store != nil {
s.store.Close()
}
return nil
}
func (s *fileScanner) buildTableMeta(
dbMeta *mydump.MDDatabaseMeta,
tblMeta *mydump.MDTableMeta,
allDataFiles map[string]mydump.FileInfo,
) (*TableMeta, error) {
tableMeta := &TableMeta{
Database: dbMeta.Name,
Table: tblMeta.Name,
DataFiles: make([]DataFileMeta, 0, len(tblMeta.DataFiles)),
SchemaFile: tblMeta.SchemaFile.FileMeta.Path,
}
// Process data files
dataFiles, totalSize := processDataFiles(tblMeta.DataFiles)
tableMeta.DataFiles = dataFiles
tableMeta.TotalSize = totalSize
if len(tblMeta.DataFiles) == 0 {
s.logger.Warn("table has no data files", zap.String("database", dbMeta.Name), zap.String("table", tblMeta.Name))
return tableMeta, nil
}
wildcard, err := generateWildcardPath(tblMeta.DataFiles, allDataFiles)
if err != nil {
return nil, errors.Trace(err)
}
uri := s.store.URI()
// IMPORT INTO only supports absolute paths
uri = strings.TrimPrefix(uri, "file://")
tableMeta.WildcardPath = strings.TrimSuffix(uri, "/") + "/" + wildcard
return tableMeta, nil
}
// processDataFiles converts mydump data files to DataFileMeta and calculates total size
func processDataFiles(files []mydump.FileInfo) ([]DataFileMeta, int64) {
dataFiles := make([]DataFileMeta, 0, len(files))
var totalSize int64
for _, dataFile := range files {
fileMeta := createDataFileMeta(dataFile)
dataFiles = append(dataFiles, fileMeta)
totalSize += dataFile.FileMeta.RealSize
}
return dataFiles, totalSize
}
// createDataFileMeta creates a DataFileMeta from a mydump.FileInfo
func createDataFileMeta(file mydump.FileInfo) DataFileMeta {
return DataFileMeta{
Path: file.FileMeta.Path,
Size: file.FileMeta.RealSize,
Format: file.FileMeta.Type,
Compression: file.FileMeta.Compression,
}
}
func (s *fileScanner) GetTableMetaByName(ctx context.Context, db, table string) (*TableMeta, error) {
dbMetas := s.loader.GetDatabases()
allFiles := s.loader.GetAllFiles()
for _, dbMeta := range dbMetas {
if dbMeta.Name != db {
continue
}
for _, tblMeta := range dbMeta.Tables {
if tblMeta.Name != table {
continue
}
return s.buildTableMeta(dbMeta, tblMeta, allFiles)
}
}
return nil, errors.Annotatef(ErrTableNotFound, "table %s.%s not found", db, table)
}
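
An end-to-end sketch of the scanner flow (illustrative, not part of the patch; written in-package because SDKConfig is assembled via the unexported defaultSDKConfig): parse the source, create the schemas, then collect per-table metadata whose WildcardPath can feed an IMPORT INTO statement.

package importsdk

import (
	"context"
	"database/sql"
)

// exampleScan is a hypothetical helper; the dump path is illustrative and db
// is an open connection to the target TiDB cluster.
func exampleScan(ctx context.Context, db *sql.DB) ([]*TableMeta, error) {
	cfg := defaultSDKConfig()
	WithConcurrency(8)(cfg)
	WithSkipInvalidFiles(true)(cfg)

	scanner, err := NewFileScanner(ctx, "file:///data/dump", db, cfg)
	if err != nil {
		return nil, err
	}
	defer scanner.Close()

	// Materialise every database and table found in the dump, then return the
	// per-table metadata (data files, total size, wildcard path).
	if err := scanner.CreateSchemasAndTables(ctx); err != nil {
		return nil, err
	}
	return scanner.GetTableMetas(ctx)
}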

pkg/importsdk/file_scanner_test.go (new file)

@ -0,0 +1,177 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importsdk
import (
"context"
"os"
"path/filepath"
"regexp"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/pingcap/tidb/pkg/lightning/config"
"github.com/pingcap/tidb/pkg/lightning/mydump"
filter "github.com/pingcap/tidb/pkg/util/table-filter"
"github.com/stretchr/testify/require"
)
func TestCreateDataFileMeta(t *testing.T) {
fi := mydump.FileInfo{
TableName: filter.Table{
Schema: "db",
Name: "table",
},
FileMeta: mydump.SourceFileMeta{
Path: "s3://bucket/path/to/f",
FileSize: 123,
Type: mydump.SourceTypeCSV,
Compression: mydump.CompressionGZ,
RealSize: 456,
},
}
df := createDataFileMeta(fi)
require.Equal(t, "s3://bucket/path/to/f", df.Path)
require.Equal(t, int64(456), df.Size)
require.Equal(t, mydump.SourceTypeCSV, df.Format)
require.Equal(t, mydump.CompressionGZ, df.Compression)
}
func TestProcessDataFiles(t *testing.T) {
files := []mydump.FileInfo{
{FileMeta: mydump.SourceFileMeta{Path: "s3://bucket/a", RealSize: 10}},
{FileMeta: mydump.SourceFileMeta{Path: "s3://bucket/b", RealSize: 20}},
}
dfm, total := processDataFiles(files)
require.Len(t, dfm, 2)
require.Equal(t, int64(30), total)
require.Equal(t, "s3://bucket/a", dfm[0].Path)
require.Equal(t, "s3://bucket/b", dfm[1].Path)
}
func TestFileScanner(t *testing.T) {
tmpDir := t.TempDir()
ctx := context.Background()
require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "db1-schema-create.sql"), []byte("CREATE DATABASE db1;"), 0644))
require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "db1.t1-schema.sql"), []byte("CREATE TABLE t1 (id INT);"), 0644))
require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "db1.t1.001.csv"), []byte("1\n2"), 0644))
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
cfg := defaultSDKConfig()
scanner, err := NewFileScanner(ctx, "file://"+tmpDir, db, cfg)
require.NoError(t, err)
defer scanner.Close()
t.Run("GetTotalSize", func(t *testing.T) {
size := scanner.GetTotalSize(ctx)
require.Equal(t, int64(3), size)
})
t.Run("GetTableMetas", func(t *testing.T) {
metas, err := scanner.GetTableMetas(ctx)
require.NoError(t, err)
require.Len(t, metas, 1)
require.Equal(t, "db1", metas[0].Database)
require.Equal(t, "t1", metas[0].Table)
require.Equal(t, int64(3), metas[0].TotalSize)
require.Len(t, metas[0].DataFiles, 1)
})
t.Run("GetTableMetaByName", func(t *testing.T) {
meta, err := scanner.GetTableMetaByName(ctx, "db1", "t1")
require.NoError(t, err)
require.Equal(t, "db1", meta.Database)
require.Equal(t, "t1", meta.Table)
_, err = scanner.GetTableMetaByName(ctx, "db1", "nonexistent")
require.Error(t, err)
})
t.Run("CreateSchemasAndTables", func(t *testing.T) {
mock.ExpectQuery("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA.*").WillReturnRows(sqlmock.NewRows([]string{"SCHEMA_NAME"}))
mock.ExpectExec(regexp.QuoteMeta("CREATE DATABASE IF NOT EXISTS `db1`")).WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec(regexp.QuoteMeta("CREATE TABLE IF NOT EXISTS `db1`.`t1`")).WillReturnResult(sqlmock.NewResult(0, 0))
err := scanner.CreateSchemasAndTables(ctx)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("CreateSchemaAndTableByName", func(t *testing.T) {
mock.ExpectQuery("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA.*").WillReturnRows(sqlmock.NewRows([]string{"SCHEMA_NAME"}))
mock.ExpectExec(regexp.QuoteMeta("CREATE DATABASE IF NOT EXISTS `db1`")).WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec(regexp.QuoteMeta("CREATE TABLE IF NOT EXISTS `db1`.`t1`")).WillReturnResult(sqlmock.NewResult(0, 0))
err := scanner.CreateSchemaAndTableByName(ctx, "db1", "t1")
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
err = scanner.CreateSchemaAndTableByName(ctx, "db1", "nonexistent")
require.Error(t, err)
})
}
func TestFileScannerWithSkipInvalidFiles(t *testing.T) {
tmpDir := t.TempDir()
ctx := context.Background()
require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "data1.csv"), []byte("1"), 0644))
require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "data2.csv"), []byte("1"), 0644))
require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "data3.csv"), []byte("1"), 0644))
db, _, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rules := []*config.FileRouteRule{
{
Pattern: "data[1-2].csv",
Schema: "db1",
Table: "t1",
Type: "csv",
},
{
Pattern: "data3.csv",
Schema: "db1",
Table: "t2",
Type: "csv",
},
}
cfg := defaultSDKConfig()
WithFileRouters(rules)(cfg)
scanner, err := NewFileScanner(ctx, "file://"+tmpDir, db, cfg)
require.NoError(t, err)
defer scanner.Close()
metas, err := scanner.GetTableMetas(ctx)
require.Error(t, err)
require.Nil(t, metas)
cfg.skipInvalidFiles = true
scanner2, err := NewFileScanner(ctx, "file://"+tmpDir, db, cfg)
require.NoError(t, err)
defer scanner2.Close()
metas2, err := scanner2.GetTableMetas(ctx)
require.NoError(t, err)
require.Len(t, metas2, 1)
require.Equal(t, "t2", metas2[0].Table)
}

pkg/importsdk/job_manager.go (new file)

@ -0,0 +1,258 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importsdk
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/pingcap/errors"
)
// JobManager defines the interface for managing import jobs
type JobManager interface {
SubmitJob(ctx context.Context, query string) (int64, error)
GetJobStatus(ctx context.Context, jobID int64) (*JobStatus, error)
CancelJob(ctx context.Context, jobID int64) error
GetGroupSummary(ctx context.Context, groupKey string) (*GroupStatus, error)
GetJobsByGroup(ctx context.Context, groupKey string) ([]*JobStatus, error)
}
const timeLayout = "2006-01-02 15:04:05"
type jobManager struct {
db *sql.DB
}
// NewJobManager creates a new JobManager
func NewJobManager(db *sql.DB) JobManager {
return &jobManager{
db: db,
}
}
// SubmitJob submits an import job and returns the job ID
func (m *jobManager) SubmitJob(ctx context.Context, query string) (int64, error) {
rows, err := m.db.QueryContext(ctx, query)
if err != nil {
return 0, errors.Trace(err)
}
defer rows.Close()
if rows.Next() {
status, err := scanJobStatus(rows)
if err != nil {
return 0, errors.Trace(err)
}
return status.JobID, nil
}
if err := rows.Err(); err != nil {
return 0, errors.Trace(err)
}
return 0, ErrNoJobIDReturned
}
// GetJobStatus gets the status of an import job
func (m *jobManager) GetJobStatus(ctx context.Context, jobID int64) (*JobStatus, error) {
query := fmt.Sprintf("SHOW IMPORT JOB %d", jobID)
rows, err := m.db.QueryContext(ctx, query)
if err != nil {
return nil, errors.Trace(err)
}
defer rows.Close()
if rows.Next() {
return scanJobStatus(rows)
}
if err := rows.Err(); err != nil {
return nil, errors.Trace(err)
}
return nil, ErrJobNotFound
}
// GetGroupSummary returns aggregated information for the specified group key.
func (m *jobManager) GetGroupSummary(ctx context.Context, groupKey string) (*GroupStatus, error) {
if groupKey == "" {
return nil, ErrInvalidOptions
}
query := fmt.Sprintf("SHOW IMPORT GROUP '%s'", strings.ReplaceAll(groupKey, "'", "''"))
rows, err := m.db.QueryContext(ctx, query)
if err != nil {
return nil, errors.Trace(err)
}
defer rows.Close()
if rows.Next() {
status, err := scanGroupStatus(rows)
if err != nil {
return nil, errors.Trace(err)
}
return status, nil
}
if err := rows.Err(); err != nil {
return nil, errors.Trace(err)
}
return nil, ErrJobNotFound
}
// GetJobsByGroup returns all jobs for the specified group key.
func (m *jobManager) GetJobsByGroup(ctx context.Context, groupKey string) ([]*JobStatus, error) {
if groupKey == "" {
return nil, ErrInvalidOptions
}
query := fmt.Sprintf("SHOW IMPORT JOBS WHERE GROUP_KEY = '%s'", strings.ReplaceAll(groupKey, "'", "''"))
rows, err := m.db.QueryContext(ctx, query)
if err != nil {
return nil, errors.Trace(err)
}
defer rows.Close()
var jobs []*JobStatus
for rows.Next() {
status, err := scanJobStatus(rows)
if err != nil {
return nil, errors.Trace(err)
}
jobs = append(jobs, status)
}
if err := rows.Err(); err != nil {
return nil, errors.Trace(err)
}
return jobs, nil
}
func scanJobStatus(rows *sql.Rows) (*JobStatus, error) {
var (
id int64
groupKey sql.NullString
dataSource string
targetTable string
tableID int64
phase string
status string
sourceFileSize string
importedRows sql.NullInt64
resultMessage sql.NullString
createTimeStr string
startTimeStr sql.NullString
endTimeStr sql.NullString
createdBy string
updateTimeStr sql.NullString
step sql.NullString
processedSize sql.NullString
totalSize sql.NullString
percent sql.NullString
speed sql.NullString
eta sql.NullString
)
err := rows.Scan(
&id, &groupKey, &dataSource, &targetTable, &tableID,
&phase, &status, &sourceFileSize, &importedRows, &resultMessage,
&createTimeStr, &startTimeStr, &endTimeStr, &createdBy, &updateTimeStr,
&step, &processedSize, &totalSize, &percent, &speed, &eta,
)
if err != nil {
return nil, errors.Trace(err)
}
// Parse times
createTime := parseTime(createTimeStr)
startTime := parseNullTime(startTimeStr)
endTime := parseNullTime(endTimeStr)
updateTime := parseNullTime(updateTimeStr)
return &JobStatus{
JobID: id,
GroupKey: groupKey.String,
DataSource: dataSource,
TargetTable: targetTable,
TableID: tableID,
Phase: phase,
Status: status,
SourceFileSize: sourceFileSize,
ImportedRows: importedRows.Int64,
ResultMessage: resultMessage.String,
CreateTime: createTime,
StartTime: startTime,
EndTime: endTime,
CreatedBy: createdBy,
UpdateTime: updateTime,
Step: step.String,
ProcessedSize: processedSize.String,
TotalSize: totalSize.String,
Percent: percent.String,
Speed: speed.String,
ETA: eta.String,
}, nil
}
func scanGroupStatus(rows *sql.Rows) (*GroupStatus, error) {
var (
groupKey string
totalJobs int64
pending int64
running int64
completed int64
failed int64
cancelled int64
firstCreateTime sql.NullString
lastUpdateTime sql.NullString
)
if err := rows.Scan(&groupKey, &totalJobs, &pending, &running, &completed, &failed, &cancelled, &firstCreateTime, &lastUpdateTime); err != nil {
return nil, errors.Trace(err)
}
return &GroupStatus{
GroupKey: groupKey,
TotalJobs: totalJobs,
Pending: pending,
Running: running,
Completed: completed,
Failed: failed,
Cancelled: cancelled,
FirstJobCreateTime: parseNullTime(firstCreateTime),
LastJobUpdateTime: parseNullTime(lastUpdateTime),
}, nil
}
// CancelJob cancels an import job
func (m *jobManager) CancelJob(ctx context.Context, jobID int64) error {
query := fmt.Sprintf("CANCEL IMPORT JOB %d", jobID)
_, err := m.db.ExecContext(ctx, query)
return errors.Trace(err)
}
func parseTime(s string) time.Time {
t, _ := time.Parse(timeLayout, s)
return t
}
func parseNullTime(ns sql.NullString) time.Time {
if !ns.Valid {
return time.Time{}
}
return parseTime(ns.String)
}
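
A sketch of the intended job lifecycle on top of JobManager (illustrative, not part of the patch): submit a detached IMPORT INTO statement, then poll until the job reaches a terminal state. The SQL text and poll interval are placeholders.

package importsdk

import (
	"context"
	"database/sql"
	"time"
)

// exampleRunJob is a hypothetical helper; a detached IMPORT INTO returns the
// job row immediately, which is what SubmitJob scans for the job ID.
func exampleRunJob(ctx context.Context, db *sql.DB) (*JobStatus, error) {
	mgr := NewJobManager(db)
	jobID, err := mgr.SubmitJob(ctx,
		"IMPORT INTO `db1`.`t1` FROM 's3://bucket/db1.t1.*.csv' WITH DETACHED")
	if err != nil {
		return nil, err
	}
	for {
		status, err := mgr.GetJobStatus(ctx, jobID)
		if err != nil {
			return nil, err
		}
		if status.IsCompleted() { // finished, failed, or cancelled
			return status, nil
		}
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-time.After(5 * time.Second):
		}
	}
}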

pkg/importsdk/job_manager_test.go (new file)

@ -0,0 +1,242 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importsdk
import (
"context"
"database/sql"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
)
func TestSubmitJob(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
manager := NewJobManager(db)
ctx := context.Background()
sqlQuery := "IMPORT INTO ..."
// Columns expected by scanJobStatus
cols := []string{
"Job_ID", "Group_Key", "Data_Source", "Target_Table", "Table_ID",
"Phase", "Status", "Source_File_Size", "Imported_Rows", "Result_Message",
"Create_Time", "Start_Time", "End_Time", "Created_By", "Update_Time",
"Step", "Processed_Size", "Total_Size", "Percent", "Speed", "ETA",
}
// Case 1: Success
rows := sqlmock.NewRows(cols).AddRow(
int64(123), "", "s3://bucket/file.csv", "db.table", int64(1),
"import", "finished", "100MB", int64(1000), "success",
"2023-01-01 10:00:00", "2023-01-01 10:00:01", "2023-01-01 10:00:02", "user", "2023-01-01 10:00:02",
"", "100MB", "100MB", "100%", "10MB/s", "0s",
)
mock.ExpectQuery(sqlQuery).WillReturnRows(rows)
jobID, err := manager.SubmitJob(ctx, sqlQuery)
require.NoError(t, err)
require.Equal(t, int64(123), jobID)
// Case 2: No rows returned
mock.ExpectQuery(sqlQuery).WillReturnRows(sqlmock.NewRows(cols))
_, err = manager.SubmitJob(ctx, sqlQuery)
require.Error(t, err)
require.Contains(t, err.Error(), "no job id returned")
// Case 3: Error
mock.ExpectQuery(sqlQuery).WillReturnError(sql.ErrConnDone)
_, err = manager.SubmitJob(ctx, sqlQuery)
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestGetJobStatus(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
manager := NewJobManager(db)
ctx := context.Background()
jobID := int64(123)
// Columns expected by GetJobStatus (same as scanJobStatus)
cols := []string{
"Job_ID", "Group_Key", "Data_Source", "Target_Table", "Table_ID",
"Phase", "Status", "Source_File_Size", "Imported_Rows", "Result_Message",
"Create_Time", "Start_Time", "End_Time", "Created_By", "Update_Time",
"Step", "Processed_Size", "Total_Size", "Percent", "Speed", "ETA",
}
// Case 1: Success
rows := sqlmock.NewRows(cols).AddRow(
jobID, "", "s3://bucket/file.csv", "db.table", int64(1),
"import", "finished", "100MB", int64(1000), "success",
"2023-01-01 10:00:00", "2023-01-01 10:00:01", "2023-01-01 10:00:02", "user", "2023-01-01 10:00:02",
"", "100MB", "100MB", "100%", "10MB/s", "0s",
)
mock.ExpectQuery("SHOW IMPORT JOB 123").WillReturnRows(rows)
status, err := manager.GetJobStatus(ctx, jobID)
require.NoError(t, err)
require.Equal(t, jobID, status.JobID)
require.Equal(t, "finished", status.Status)
// Case 2: Job not found
mock.ExpectQuery("SHOW IMPORT JOB 123").WillReturnRows(sqlmock.NewRows(cols))
_, err = manager.GetJobStatus(ctx, jobID)
require.ErrorIs(t, err, ErrJobNotFound)
// Case 3: Error
mock.ExpectQuery("SHOW IMPORT JOB 123").WillReturnError(sql.ErrConnDone)
_, err = manager.GetJobStatus(ctx, jobID)
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestCancelJob(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
manager := NewJobManager(db)
ctx := context.Background()
jobID := int64(123)
// Case 1: Success
mock.ExpectExec("CANCEL IMPORT JOB 123").WillReturnResult(sqlmock.NewResult(0, 0))
err = manager.CancelJob(ctx, jobID)
require.NoError(t, err)
// Case 2: Error
mock.ExpectExec("CANCEL IMPORT JOB 123").WillReturnError(sql.ErrConnDone)
err = manager.CancelJob(ctx, jobID)
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestGetGroupSummary(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
manager := NewJobManager(db)
ctx := context.Background()
groupKey := "test_group"
// Columns expected by scanGroupStatus
cols := []string{
"Group_Key", "Total_Jobs", "Pending", "Running", "Completed", "Failed", "Cancelled",
"First_Job_Create_Time", "Last_Job_Update_Time",
}
// Case 1: Success
rows := sqlmock.NewRows(cols).AddRow(
groupKey, int64(10), int64(1), int64(2), int64(3), int64(2), int64(2),
"2023-01-01 10:00:00", "2023-01-01 12:00:00",
)
mock.ExpectQuery("SHOW IMPORT GROUP 'test_group'").WillReturnRows(rows)
summary, err := manager.GetGroupSummary(ctx, groupKey)
require.NoError(t, err)
require.Equal(t, groupKey, summary.GroupKey)
require.Equal(t, int64(10), summary.TotalJobs)
require.Equal(t, int64(1), summary.Pending)
require.Equal(t, int64(2), summary.Running)
require.Equal(t, int64(3), summary.Completed)
require.Equal(t, int64(2), summary.Failed)
require.Equal(t, int64(2), summary.Cancelled)
// Case 2: Empty Group Key
_, err = manager.GetGroupSummary(ctx, "")
require.ErrorIs(t, err, ErrInvalidOptions)
// Case 3: Group not found
mock.ExpectQuery("SHOW IMPORT GROUP 'test_group'").WillReturnRows(sqlmock.NewRows(cols))
_, err = manager.GetGroupSummary(ctx, groupKey)
require.ErrorIs(t, err, ErrJobNotFound)
// Case 4: Error
mock.ExpectQuery("SHOW IMPORT GROUP 'test_group'").WillReturnError(sql.ErrConnDone)
_, err = manager.GetGroupSummary(ctx, groupKey)
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestGetJobsByGroup(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
manager := NewJobManager(db)
ctx := context.Background()
groupKey := "test_group"
// Columns expected by scanJobStatus
cols := []string{
"Job_ID", "Group_Key", "Data_Source", "Target_Table", "Table_ID",
"Phase", "Status", "Source_File_Size", "Imported_Rows", "Result_Message",
"Create_Time", "Start_Time", "End_Time", "Created_By", "Update_Time",
"Step", "Processed_Size", "Total_Size", "Percent", "Speed", "ETA",
}
// Case 1: Success
rows := sqlmock.NewRows(cols).
AddRow(
int64(1), groupKey, "s3://bucket/file1.csv", "db.t1", int64(1),
"import", "finished", "100MB", int64(1000), "success",
"2023-01-01 10:00:00", "2023-01-01 10:00:01", "2023-01-01 10:00:02", "user", "2023-01-01 10:00:02",
"", "100MB", "100MB", "100%", "10MB/s", "0s",
).
AddRow(
int64(2), groupKey, "s3://bucket/file2.csv", "db.t2", int64(2),
"import", "running", "200MB", int64(500), "",
"2023-01-01 10:00:00", "2023-01-01 10:00:01", "", "user", "2023-01-01 10:00:02",
"", "100MB", "200MB", "50%", "10MB/s", "10s",
)
mock.ExpectQuery("SHOW IMPORT JOBS WHERE GROUP_KEY = 'test_group'").WillReturnRows(rows)
jobs, err := manager.GetJobsByGroup(ctx, groupKey)
require.NoError(t, err)
require.Len(t, jobs, 2)
require.Equal(t, int64(1), jobs[0].JobID)
require.Equal(t, "finished", jobs[0].Status)
require.Equal(t, int64(2), jobs[1].JobID)
require.Equal(t, "running", jobs[1].Status)
// Case 2: Empty Group Key
_, err = manager.GetJobsByGroup(ctx, "")
require.ErrorIs(t, err, ErrInvalidOptions)
// Case 3: No jobs found (empty list)
mock.ExpectQuery("SHOW IMPORT JOBS WHERE GROUP_KEY = 'test_group'").WillReturnRows(sqlmock.NewRows(cols))
jobs, err = manager.GetJobsByGroup(ctx, groupKey)
require.NoError(t, err)
require.Empty(t, jobs)
// Case 4: Error
mock.ExpectQuery("SHOW IMPORT JOBS WHERE GROUP_KEY = 'test_group'").WillReturnError(sql.ErrConnDone)
_, err = manager.GetJobsByGroup(ctx, groupKey)
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}

pkg/importsdk/mock/BUILD.bazel (new file)

@ -0,0 +1,12 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "mock",
srcs = ["sdk_mock.go"],
importpath = "github.com/pingcap/tidb/pkg/importsdk/mock",
visibility = ["//visibility:public"],
deps = [
"//pkg/importsdk",
"@org_uber_go_mock//gomock",
],
)

pkg/importsdk/mock/sdk_mock.go (new file, generated)

@ -0,0 +1,480 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/pingcap/tidb/pkg/importsdk (interfaces: FileScanner,JobManager,SQLGenerator,SDK)
//
// Generated by this command:
//
// mockgen -package mock github.com/pingcap/tidb/pkg/importsdk FileScanner,JobManager,SQLGenerator,SDK
//
// Package mock is a generated GoMock package.
package mock
import (
context "context"
reflect "reflect"
importsdk "github.com/pingcap/tidb/pkg/importsdk"
gomock "go.uber.org/mock/gomock"
)
// MockFileScanner is a mock of FileScanner interface.
type MockFileScanner struct {
ctrl *gomock.Controller
recorder *MockFileScannerMockRecorder
}
// MockFileScannerMockRecorder is the mock recorder for MockFileScanner.
type MockFileScannerMockRecorder struct {
mock *MockFileScanner
}
// NewMockFileScanner creates a new mock instance.
func NewMockFileScanner(ctrl *gomock.Controller) *MockFileScanner {
mock := &MockFileScanner{ctrl: ctrl}
mock.recorder = &MockFileScannerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockFileScanner) EXPECT() *MockFileScannerMockRecorder {
return m.recorder
}
// ISGOMOCK indicates that this struct is a gomock mock.
func (m *MockFileScanner) ISGOMOCK() struct{} {
return struct{}{}
}
// Close mocks base method.
func (m *MockFileScanner) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockFileScannerMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockFileScanner)(nil).Close))
}
// CreateSchemaAndTableByName mocks base method.
func (m *MockFileScanner) CreateSchemaAndTableByName(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateSchemaAndTableByName", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// CreateSchemaAndTableByName indicates an expected call of CreateSchemaAndTableByName.
func (mr *MockFileScannerMockRecorder) CreateSchemaAndTableByName(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaAndTableByName", reflect.TypeOf((*MockFileScanner)(nil).CreateSchemaAndTableByName), arg0, arg1, arg2)
}
// CreateSchemasAndTables mocks base method.
func (m *MockFileScanner) CreateSchemasAndTables(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateSchemasAndTables", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// CreateSchemasAndTables indicates an expected call of CreateSchemasAndTables.
func (mr *MockFileScannerMockRecorder) CreateSchemasAndTables(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemasAndTables", reflect.TypeOf((*MockFileScanner)(nil).CreateSchemasAndTables), arg0)
}
// GetTableMetaByName mocks base method.
func (m *MockFileScanner) GetTableMetaByName(arg0 context.Context, arg1, arg2 string) (*importsdk.TableMeta, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTableMetaByName", arg0, arg1, arg2)
ret0, _ := ret[0].(*importsdk.TableMeta)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTableMetaByName indicates an expected call of GetTableMetaByName.
func (mr *MockFileScannerMockRecorder) GetTableMetaByName(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTableMetaByName", reflect.TypeOf((*MockFileScanner)(nil).GetTableMetaByName), arg0, arg1, arg2)
}
// GetTableMetas mocks base method.
func (m *MockFileScanner) GetTableMetas(arg0 context.Context) ([]*importsdk.TableMeta, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTableMetas", arg0)
ret0, _ := ret[0].([]*importsdk.TableMeta)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTableMetas indicates an expected call of GetTableMetas.
func (mr *MockFileScannerMockRecorder) GetTableMetas(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTableMetas", reflect.TypeOf((*MockFileScanner)(nil).GetTableMetas), arg0)
}
// GetTotalSize mocks base method.
func (m *MockFileScanner) GetTotalSize(arg0 context.Context) int64 {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTotalSize", arg0)
ret0, _ := ret[0].(int64)
return ret0
}
// GetTotalSize indicates an expected call of GetTotalSize.
func (mr *MockFileScannerMockRecorder) GetTotalSize(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTotalSize", reflect.TypeOf((*MockFileScanner)(nil).GetTotalSize), arg0)
}
// MockJobManager is a mock of JobManager interface.
type MockJobManager struct {
ctrl *gomock.Controller
recorder *MockJobManagerMockRecorder
}
// MockJobManagerMockRecorder is the mock recorder for MockJobManager.
type MockJobManagerMockRecorder struct {
mock *MockJobManager
}
// NewMockJobManager creates a new mock instance.
func NewMockJobManager(ctrl *gomock.Controller) *MockJobManager {
mock := &MockJobManager{ctrl: ctrl}
mock.recorder = &MockJobManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockJobManager) EXPECT() *MockJobManagerMockRecorder {
return m.recorder
}
// ISGOMOCK indicates that this struct is a gomock mock.
func (m *MockJobManager) ISGOMOCK() struct{} {
return struct{}{}
}
// CancelJob mocks base method.
func (m *MockJobManager) CancelJob(arg0 context.Context, arg1 int64) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CancelJob", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// CancelJob indicates an expected call of CancelJob.
func (mr *MockJobManagerMockRecorder) CancelJob(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelJob", reflect.TypeOf((*MockJobManager)(nil).CancelJob), arg0, arg1)
}
// GetGroupSummary mocks base method.
func (m *MockJobManager) GetGroupSummary(arg0 context.Context, arg1 string) (*importsdk.GroupStatus, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupSummary", arg0, arg1)
ret0, _ := ret[0].(*importsdk.GroupStatus)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupSummary indicates an expected call of GetGroupSummary.
func (mr *MockJobManagerMockRecorder) GetGroupSummary(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupSummary", reflect.TypeOf((*MockJobManager)(nil).GetGroupSummary), arg0, arg1)
}
// GetJobStatus mocks base method.
func (m *MockJobManager) GetJobStatus(arg0 context.Context, arg1 int64) (*importsdk.JobStatus, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetJobStatus", arg0, arg1)
ret0, _ := ret[0].(*importsdk.JobStatus)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetJobStatus indicates an expected call of GetJobStatus.
func (mr *MockJobManagerMockRecorder) GetJobStatus(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetJobStatus", reflect.TypeOf((*MockJobManager)(nil).GetJobStatus), arg0, arg1)
}
// GetJobsByGroup mocks base method.
func (m *MockJobManager) GetJobsByGroup(arg0 context.Context, arg1 string) ([]*importsdk.JobStatus, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetJobsByGroup", arg0, arg1)
ret0, _ := ret[0].([]*importsdk.JobStatus)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetJobsByGroup indicates an expected call of GetJobsByGroup.
func (mr *MockJobManagerMockRecorder) GetJobsByGroup(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetJobsByGroup", reflect.TypeOf((*MockJobManager)(nil).GetJobsByGroup), arg0, arg1)
}
// SubmitJob mocks base method.
func (m *MockJobManager) SubmitJob(arg0 context.Context, arg1 string) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SubmitJob", arg0, arg1)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SubmitJob indicates an expected call of SubmitJob.
func (mr *MockJobManagerMockRecorder) SubmitJob(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubmitJob", reflect.TypeOf((*MockJobManager)(nil).SubmitJob), arg0, arg1)
}
// MockSQLGenerator is a mock of SQLGenerator interface.
type MockSQLGenerator struct {
ctrl *gomock.Controller
recorder *MockSQLGeneratorMockRecorder
}
// MockSQLGeneratorMockRecorder is the mock recorder for MockSQLGenerator.
type MockSQLGeneratorMockRecorder struct {
mock *MockSQLGenerator
}
// NewMockSQLGenerator creates a new mock instance.
func NewMockSQLGenerator(ctrl *gomock.Controller) *MockSQLGenerator {
mock := &MockSQLGenerator{ctrl: ctrl}
mock.recorder = &MockSQLGeneratorMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSQLGenerator) EXPECT() *MockSQLGeneratorMockRecorder {
return m.recorder
}
// ISGOMOCK indicates that this struct is a gomock mock.
func (m *MockSQLGenerator) ISGOMOCK() struct{} {
return struct{}{}
}
// GenerateImportSQL mocks base method.
func (m *MockSQLGenerator) GenerateImportSQL(arg0 *importsdk.TableMeta, arg1 *importsdk.ImportOptions) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GenerateImportSQL", arg0, arg1)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GenerateImportSQL indicates an expected call of GenerateImportSQL.
func (mr *MockSQLGeneratorMockRecorder) GenerateImportSQL(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateImportSQL", reflect.TypeOf((*MockSQLGenerator)(nil).GenerateImportSQL), arg0, arg1)
}
// MockSDK is a mock of SDK interface.
type MockSDK struct {
ctrl *gomock.Controller
recorder *MockSDKMockRecorder
}
// MockSDKMockRecorder is the mock recorder for MockSDK.
type MockSDKMockRecorder struct {
mock *MockSDK
}
// NewMockSDK creates a new mock instance.
func NewMockSDK(ctrl *gomock.Controller) *MockSDK {
mock := &MockSDK{ctrl: ctrl}
mock.recorder = &MockSDKMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSDK) EXPECT() *MockSDKMockRecorder {
return m.recorder
}
// ISGOMOCK indicates that this struct is a gomock mock.
func (m *MockSDK) ISGOMOCK() struct{} {
return struct{}{}
}
// CancelJob mocks base method.
func (m *MockSDK) CancelJob(arg0 context.Context, arg1 int64) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CancelJob", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// CancelJob indicates an expected call of CancelJob.
func (mr *MockSDKMockRecorder) CancelJob(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelJob", reflect.TypeOf((*MockSDK)(nil).CancelJob), arg0, arg1)
}
// Close mocks base method.
func (m *MockSDK) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockSDKMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSDK)(nil).Close))
}
// CreateSchemaAndTableByName mocks base method.
func (m *MockSDK) CreateSchemaAndTableByName(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateSchemaAndTableByName", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// CreateSchemaAndTableByName indicates an expected call of CreateSchemaAndTableByName.
func (mr *MockSDKMockRecorder) CreateSchemaAndTableByName(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaAndTableByName", reflect.TypeOf((*MockSDK)(nil).CreateSchemaAndTableByName), arg0, arg1, arg2)
}
// CreateSchemasAndTables mocks base method.
func (m *MockSDK) CreateSchemasAndTables(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateSchemasAndTables", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// CreateSchemasAndTables indicates an expected call of CreateSchemasAndTables.
func (mr *MockSDKMockRecorder) CreateSchemasAndTables(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemasAndTables", reflect.TypeOf((*MockSDK)(nil).CreateSchemasAndTables), arg0)
}
// GenerateImportSQL mocks base method.
func (m *MockSDK) GenerateImportSQL(arg0 *importsdk.TableMeta, arg1 *importsdk.ImportOptions) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GenerateImportSQL", arg0, arg1)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GenerateImportSQL indicates an expected call of GenerateImportSQL.
func (mr *MockSDKMockRecorder) GenerateImportSQL(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateImportSQL", reflect.TypeOf((*MockSDK)(nil).GenerateImportSQL), arg0, arg1)
}
// GetGroupSummary mocks base method.
func (m *MockSDK) GetGroupSummary(arg0 context.Context, arg1 string) (*importsdk.GroupStatus, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupSummary", arg0, arg1)
ret0, _ := ret[0].(*importsdk.GroupStatus)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupSummary indicates an expected call of GetGroupSummary.
func (mr *MockSDKMockRecorder) GetGroupSummary(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupSummary", reflect.TypeOf((*MockSDK)(nil).GetGroupSummary), arg0, arg1)
}
// GetJobStatus mocks base method.
func (m *MockSDK) GetJobStatus(arg0 context.Context, arg1 int64) (*importsdk.JobStatus, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetJobStatus", arg0, arg1)
ret0, _ := ret[0].(*importsdk.JobStatus)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetJobStatus indicates an expected call of GetJobStatus.
func (mr *MockSDKMockRecorder) GetJobStatus(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetJobStatus", reflect.TypeOf((*MockSDK)(nil).GetJobStatus), arg0, arg1)
}
// GetJobsByGroup mocks base method.
func (m *MockSDK) GetJobsByGroup(arg0 context.Context, arg1 string) ([]*importsdk.JobStatus, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetJobsByGroup", arg0, arg1)
ret0, _ := ret[0].([]*importsdk.JobStatus)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetJobsByGroup indicates an expected call of GetJobsByGroup.
func (mr *MockSDKMockRecorder) GetJobsByGroup(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetJobsByGroup", reflect.TypeOf((*MockSDK)(nil).GetJobsByGroup), arg0, arg1)
}
// GetTableMetaByName mocks base method.
func (m *MockSDK) GetTableMetaByName(arg0 context.Context, arg1, arg2 string) (*importsdk.TableMeta, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTableMetaByName", arg0, arg1, arg2)
ret0, _ := ret[0].(*importsdk.TableMeta)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTableMetaByName indicates an expected call of GetTableMetaByName.
func (mr *MockSDKMockRecorder) GetTableMetaByName(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTableMetaByName", reflect.TypeOf((*MockSDK)(nil).GetTableMetaByName), arg0, arg1, arg2)
}
// GetTableMetas mocks base method.
func (m *MockSDK) GetTableMetas(arg0 context.Context) ([]*importsdk.TableMeta, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTableMetas", arg0)
ret0, _ := ret[0].([]*importsdk.TableMeta)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTableMetas indicates an expected call of GetTableMetas.
func (mr *MockSDKMockRecorder) GetTableMetas(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTableMetas", reflect.TypeOf((*MockSDK)(nil).GetTableMetas), arg0)
}
// GetTotalSize mocks base method.
func (m *MockSDK) GetTotalSize(arg0 context.Context) int64 {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTotalSize", arg0)
ret0, _ := ret[0].(int64)
return ret0
}
// GetTotalSize indicates an expected call of GetTotalSize.
func (mr *MockSDKMockRecorder) GetTotalSize(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTotalSize", reflect.TypeOf((*MockSDK)(nil).GetTotalSize), arg0)
}
// SubmitJob mocks base method.
func (m *MockSDK) SubmitJob(arg0 context.Context, arg1 string) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SubmitJob", arg0, arg1)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SubmitJob indicates an expected call of SubmitJob.
func (mr *MockSDKMockRecorder) SubmitJob(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubmitJob", reflect.TypeOf((*MockSDK)(nil).SubmitJob), arg0, arg1)
}
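
A sketch of how a consumer test could wire the generated mock (illustrative; the package name and expectations are placeholders, and the import path follows the BUILD file above):

package consumer_test

import (
	"context"
	"testing"

	"github.com/pingcap/tidb/pkg/importsdk/mock"
	"github.com/stretchr/testify/require"
	"go.uber.org/mock/gomock"
)

func TestWithMockSDK(t *testing.T) {
	ctrl := gomock.NewController(t) // cleanup is registered on t automatically
	sdk := mock.NewMockSDK(ctrl)

	// Illustrative expectations: one job is submitted and then cancelled.
	sdk.EXPECT().SubmitJob(gomock.Any(), gomock.Any()).Return(int64(42), nil)
	sdk.EXPECT().CancelJob(gomock.Any(), int64(42)).Return(nil)

	id, err := sdk.SubmitJob(context.Background(), "IMPORT INTO ...")
	require.NoError(t, err)
	require.NoError(t, sdk.CancelJob(context.Background(), id))
}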

pkg/importsdk/model.go (new file)

@ -0,0 +1,119 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importsdk
import (
"time"
"github.com/pingcap/tidb/pkg/lightning/config"
"github.com/pingcap/tidb/pkg/lightning/mydump"
)
// TableMeta contains metadata for a table to be imported
type TableMeta struct {
Database string
Table string
DataFiles []DataFileMeta
TotalSize int64 // In bytes
WildcardPath string // Wildcard pattern that matches only this table's data files
SchemaFile string // Path to the table schema file, if available
}
// DataFileMeta contains metadata for a data file
type DataFileMeta struct {
Path string
Size int64
Format mydump.SourceType
Compression mydump.Compression
}
// ImportOptions wraps the options for the IMPORT INTO statement.
// It reuses structures from executor/importer where possible.
type ImportOptions struct {
Format string
CSVConfig *config.CSVConfig
Thread int
DiskQuota string
MaxWriteSpeed string
SplitFile bool
RecordErrors int64
Detached bool
CloudStorageURI string
GroupKey string
SkipRows int
CharacterSet string
ChecksumTable string
DisableTiKVImportMode bool
DisablePrecheck bool
ResourceParameters string
}
// GroupStatus represents the aggregated status for a group of import jobs.
type GroupStatus struct {
GroupKey string
TotalJobs int64
Pending int64
Running int64
Completed int64
Failed int64
Cancelled int64
FirstJobCreateTime time.Time
LastJobUpdateTime time.Time
}
// JobStatus represents the status of an import job.
type JobStatus struct {
JobID int64
GroupKey string
DataSource string
TargetTable string
TableID int64
Phase string
Status string
SourceFileSize string
ImportedRows int64
ResultMessage string
CreateTime time.Time
StartTime time.Time
EndTime time.Time
CreatedBy string
UpdateTime time.Time
Step string
ProcessedSize string
TotalSize string
Percent string
Speed string
ETA string
}
// IsFinished returns true if the job is finished successfully.
func (s *JobStatus) IsFinished() bool {
return s.Status == "finished"
}
// IsFailed returns true if the job failed.
func (s *JobStatus) IsFailed() bool {
return s.Status == "failed"
}
// IsCancelled returns true if the job was cancelled.
func (s *JobStatus) IsCancelled() bool {
return s.Status == "cancelled"
}
// IsCompleted returns true if the job is in a terminal state.
func (s *JobStatus) IsCompleted() bool {
return s.IsFinished() || s.IsFailed() || s.IsCancelled()
}
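
A small sketch of how the terminal-state helpers above might be used to poll a detached job; it assumes GetJobStatus is part of the JobManager interface embedded in SDK and returns (*JobStatus, error), as exercised in the integration test later in this change.

// Hypothetical helper, not part of this change: poll until the job reaches a
// terminal state (finished, failed, or cancelled) or the context is done.
package main

import (
	"context"
	"time"

	"github.com/pingcap/tidb/pkg/importsdk"
)

func waitForJob(ctx context.Context, sdk importsdk.SDK, jobID int64) (*importsdk.JobStatus, error) {
	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()
	for {
		status, err := sdk.GetJobStatus(ctx, jobID)
		if err != nil {
			return nil, err
		}
		if status.IsCompleted() {
			return status, nil
		}
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-ticker.C:
		}
	}
}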

View File

@ -0,0 +1,82 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importsdk
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestJobStatus(t *testing.T) {
tests := []struct {
status string
isFinished bool
isFailed bool
isCancelled bool
isCompleted bool
}{
{
status: "finished",
isFinished: true,
isFailed: false,
isCancelled: false,
isCompleted: true,
},
{
status: "failed",
isFinished: false,
isFailed: true,
isCancelled: false,
isCompleted: true,
},
{
status: "cancelled",
isFinished: false,
isFailed: false,
isCancelled: true,
isCompleted: true,
},
{
status: "running",
isFinished: false,
isFailed: false,
isCancelled: false,
isCompleted: false,
},
{
status: "pending",
isFinished: false,
isFailed: false,
isCancelled: false,
isCompleted: false,
},
{
status: "unknown",
isFinished: false,
isFailed: false,
isCancelled: false,
isCompleted: false,
},
}
for _, tt := range tests {
job := &JobStatus{Status: tt.status}
require.Equal(t, tt.isFinished, job.IsFinished(), "status: %s", tt.status)
require.Equal(t, tt.isFailed, job.IsFailed(), "status: %s", tt.status)
require.Equal(t, tt.isCancelled, job.IsCancelled(), "status: %s", tt.status)
require.Equal(t, tt.isCompleted, job.IsCompleted(), "status: %s", tt.status)
}
}

175
pkg/importsdk/pattern.go Normal file
View File

@ -0,0 +1,175 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importsdk
import (
"path/filepath"
"strings"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/lightning/mydump"
)
// generateWildcardPath creates a wildcard pattern path that matches only this table's files
func generateWildcardPath(
files []mydump.FileInfo,
allFiles map[string]mydump.FileInfo,
) (string, error) {
tableFiles := make(map[string]struct{}, len(files))
for _, df := range files {
tableFiles[df.FileMeta.Path] = struct{}{}
}
if len(files) == 0 {
return "", errors.Annotate(ErrNoTableDataFiles, "cannot generate wildcard pattern because the table has no data files")
}
// If there's only one file, we can just return its path
if len(files) == 1 {
return files[0].FileMeta.Path, nil
}
// Try Mydumper-specific pattern first
p := generateMydumperPattern(files[0])
if p != "" && isValidPattern(p, tableFiles, allFiles) {
return p, nil
}
// Fallback to generic prefix/suffix pattern
paths := make([]string, 0, len(files))
for _, file := range files {
paths = append(paths, file.FileMeta.Path)
}
p = generatePrefixSuffixPattern(paths)
if p != "" && isValidPattern(p, tableFiles, allFiles) {
return p, nil
}
return "", errors.Annotatef(ErrWildcardNotSpecific, "failed to find a wildcard that matches all and only the table's files.")
}
// isValidPattern checks if a wildcard pattern matches only the table's files
func isValidPattern(pattern string, tableFiles map[string]struct{}, allFiles map[string]mydump.FileInfo) bool {
if pattern == "" {
return false
}
for path := range allFiles {
isMatch, err := filepath.Match(pattern, path)
if err != nil {
return false // Invalid pattern
}
_, isTableFile := tableFiles[path]
// If pattern matches a file that's not from our table, it's invalid
if isMatch && !isTableFile {
return false
}
// If pattern doesn't match our table's file, it's also invalid
if !isMatch && isTableFile {
return false
}
}
return true
}
// generateMydumperPattern generates a wildcard pattern for Mydumper-formatted data files
// belonging to a specific table, based on their naming convention.
// It returns a pattern string that matches all data files for the table, or an empty string if not applicable.
func generateMydumperPattern(file mydump.FileInfo) string {
dbName, tableName := file.TableName.Schema, file.TableName.Name
if dbName == "" || tableName == "" {
return ""
}
// compute dirPrefix and basename
full := file.FileMeta.Path
dirPrefix, name := "", full
if idx := strings.LastIndex(full, "/"); idx >= 0 {
dirPrefix = full[:idx+1]
name = full[idx+1:]
}
// compression ext from filename when compression exists (last suffix like .gz/.zst)
compExt := ""
if file.FileMeta.Compression != mydump.CompressionNone {
compExt = filepath.Ext(name)
}
// data ext after stripping compression ext
base := strings.TrimSuffix(name, compExt)
dataExt := filepath.Ext(base)
return dirPrefix + dbName + "." + tableName + ".*" + dataExt + compExt
}
// longestCommonPrefix finds the longest string that is a prefix of all strings in the slice
func longestCommonPrefix(strs []string) string {
if len(strs) == 0 {
return ""
}
prefix := strs[0]
for _, s := range strs[1:] {
i := 0
for i < len(prefix) && i < len(s) && prefix[i] == s[i] {
i++
}
prefix = prefix[:i]
if prefix == "" {
break
}
}
return prefix
}
// longestCommonSuffix finds the longest string that is a suffix of all strings in the slice, starting after the given prefix length
func longestCommonSuffix(strs []string, prefixLen int) string {
if len(strs) == 0 {
return ""
}
suffix := strs[0][prefixLen:]
for _, s := range strs[1:] {
remaining := s[prefixLen:]
i := 0
for i < len(suffix) && i < len(remaining) && suffix[len(suffix)-i-1] == remaining[len(remaining)-i-1] {
i++
}
suffix = suffix[len(suffix)-i:]
if suffix == "" {
break
}
}
return suffix
}
// generatePrefixSuffixPattern returns a wildcard pattern that matches all and only the given paths
// by finding the longest common prefix and suffix among them, and placing a '*' wildcard in between.
func generatePrefixSuffixPattern(paths []string) string {
if len(paths) == 0 {
return ""
}
if len(paths) == 1 {
return paths[0]
}
prefix := longestCommonPrefix(paths)
suffix := longestCommonSuffix(paths, len(prefix))
return prefix + "*" + suffix
}

View File

@ -0,0 +1,208 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importsdk
import (
"testing"
"github.com/pingcap/tidb/pkg/lightning/log"
"github.com/pingcap/tidb/pkg/lightning/mydump"
filter "github.com/pingcap/tidb/pkg/util/table-filter"
"github.com/stretchr/testify/require"
)
func TestLongestCommonPrefix(t *testing.T) {
strs := []string{"s3://bucket/foo/bar/baz1", "s3://bucket/foo/bar/baz2", "s3://bucket/foo/bar/baz"}
p := longestCommonPrefix(strs)
require.Equal(t, "s3://bucket/foo/bar/baz", p)
// no common prefix
require.Equal(t, "", longestCommonPrefix([]string{"a", "b"}))
// empty inputs
require.Equal(t, "", longestCommonPrefix(nil))
require.Equal(t, "", longestCommonPrefix([]string{}))
}
func TestLongestCommonSuffix(t *testing.T) {
strs := []string{"abcXYZ", "defXYZ", "XYZ"}
s := longestCommonSuffix(strs, 0)
require.Equal(t, "XYZ", s)
// no common suffix
require.Equal(t, "", longestCommonSuffix([]string{"a", "b"}, 0))
// empty inputs
require.Equal(t, "", longestCommonSuffix(nil, 0))
require.Equal(t, "", longestCommonSuffix([]string{}, 0))
// same prefix
require.Equal(t, "", longestCommonSuffix([]string{"abc", "abc"}, 3))
require.Equal(t, "f", longestCommonSuffix([]string{"abcdf", "abcef"}, 3))
}
func TestGeneratePrefixSuffixPattern(t *testing.T) {
paths := []string{"pre_middle_suf", "pre_most_suf"}
pattern := generatePrefixSuffixPattern(paths)
// common prefix "pre_m", suffix "_suf"
require.Equal(t, "pre_m*_suf", pattern)
// empty inputs
require.Equal(t, "", generatePrefixSuffixPattern(nil))
require.Equal(t, "", generatePrefixSuffixPattern([]string{}))
// only one file
require.Equal(t, "pre_middle_suf", generatePrefixSuffixPattern([]string{"pre_middle_suf"}))
// no common prefix/suffix
paths2 := []string{"foo", "bar"}
require.Equal(t, "*", generatePrefixSuffixPattern(paths2))
// overlapping prefix/suffix
paths3 := []string{"aaabaaa", "aaa"}
require.Equal(t, "aaa*", generatePrefixSuffixPattern(paths3))
}
func generateFileMetas(t *testing.T, paths []string) []mydump.FileInfo {
t.Helper()
files := make([]mydump.FileInfo, 0, len(paths))
fileRouter, err := mydump.NewDefaultFileRouter(log.L())
require.NoError(t, err)
for _, p := range paths {
res, err := fileRouter.Route(p)
require.NoError(t, err)
files = append(files, mydump.FileInfo{
TableName: res.Table,
FileMeta: mydump.SourceFileMeta{
Path: p,
Type: res.Type,
Compression: res.Compression,
SortKey: res.Key,
},
})
}
return files
}
func TestGenerateMydumperPattern(t *testing.T) {
paths := []string{"db.tb.0001.sql", "db.tb.0002.sql"}
p := generateMydumperPattern(generateFileMetas(t, paths)[0])
require.Equal(t, "db.tb.*.sql", p)
paths2 := []string{"s3://bucket/dir/db.tb.0001.sql", "s3://bucket/dir/db.tb.0002.sql"}
p2 := generateMydumperPattern(generateFileMetas(t, paths2)[0])
require.Equal(t, "s3://bucket/dir/db.tb.*.sql", p2)
// not mydumper pattern
require.Equal(t, "", generateMydumperPattern(mydump.FileInfo{
TableName: filter.Table{},
}))
}
func TestValidatePattern(t *testing.T) {
tableFiles := map[string]struct{}{
"a.txt": {}, "b.txt": {},
}
// only table files in allFiles
smallAll := map[string]mydump.FileInfo{
"a.txt": {}, "b.txt": {},
}
require.True(t, isValidPattern("*.txt", tableFiles, smallAll))
// allFiles includes an extra file => invalid
fullAll := map[string]mydump.FileInfo{
"a.txt": {}, "b.txt": {}, "c.txt": {},
}
require.False(t, isValidPattern("*.txt", tableFiles, fullAll))
// If pattern doesn't match our table's file, it's also invalid
require.False(t, isValidPattern("*.csv", tableFiles, smallAll))
// empty pattern => invalid
require.False(t, isValidPattern("", tableFiles, smallAll))
}
func TestGenerateWildcardPath(t *testing.T) {
// Helper to create allFiles map
createAllFiles := func(paths []string) map[string]mydump.FileInfo {
allFiles := make(map[string]mydump.FileInfo)
for _, p := range paths {
allFiles[p] = mydump.FileInfo{
FileMeta: mydump.SourceFileMeta{Path: p},
}
}
return allFiles
}
// No files
files1 := []mydump.FileInfo{}
allFiles1 := createAllFiles([]string{})
_, err := generateWildcardPath(files1, allFiles1)
require.Error(t, err)
require.Contains(t, err.Error(), "no data files for table")
// Single file
files2 := generateFileMetas(t, []string{"db.tb.0001.sql"})
allFiles2 := createAllFiles([]string{"db.tb.0001.sql"})
path2, err := generateWildcardPath(files2, allFiles2)
require.NoError(t, err)
require.Equal(t, "db.tb.0001.sql", path2)
// Mydumper pattern succeeds
files3 := generateFileMetas(t, []string{"db.tb.0001.sql.gz", "db.tb.0002.sql.gz"})
allFiles3 := createAllFiles([]string{"db.tb.0001.sql.gz", "db.tb.0002.sql.gz"})
path3, err := generateWildcardPath(files3, allFiles3)
require.NoError(t, err)
require.Equal(t, "db.tb.*.sql.gz", path3)
// Mydumper pattern fails, fallback to prefix/suffix succeeds
files4 := []mydump.FileInfo{
{
TableName: filter.Table{Schema: "db", Name: "tb"},
FileMeta: mydump.SourceFileMeta{
Path: "a.sql",
Type: mydump.SourceTypeSQL,
Compression: mydump.CompressionNone,
},
},
{
TableName: filter.Table{Schema: "db", Name: "tb"},
FileMeta: mydump.SourceFileMeta{
Path: "b.sql",
Type: mydump.SourceTypeSQL,
Compression: mydump.CompressionNone,
},
},
}
allFiles4 := map[string]mydump.FileInfo{
files4[0].FileMeta.Path: files4[0],
files4[1].FileMeta.Path: files4[1],
}
path4, err := generateWildcardPath(files4, allFiles4)
require.NoError(t, err)
require.Equal(t, "*.sql", path4)
allFiles4["db-schema.sql"] = mydump.FileInfo{
FileMeta: mydump.SourceFileMeta{
Path: "db-schema.sql",
Type: mydump.SourceTypeSQL,
Compression: mydump.CompressionNone,
},
}
_, err = generateWildcardPath(files4, allFiles4)
require.Error(t, err)
require.Contains(t, err.Error(), "cannot generate a unique wildcard pattern")
}

View File

@ -17,61 +17,21 @@ package importsdk
import (
"context"
"database/sql"
"path/filepath"
"strings"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/br/pkg/storage"
"github.com/pingcap/tidb/pkg/lightning/config"
"github.com/pingcap/tidb/pkg/lightning/log"
"github.com/pingcap/tidb/pkg/lightning/mydump"
"github.com/pingcap/tidb/pkg/parser/mysql"
)
// SDK defines the interface for cloud import services
type SDK interface {
// CreateSchemasAndTables creates all database schemas and tables from source path
CreateSchemasAndTables(ctx context.Context) error
// GetTableMetas returns metadata for all tables in the source path
GetTableMetas(ctx context.Context) ([]*TableMeta, error)
// GetTableMetaByName returns metadata for a specific table
GetTableMetaByName(ctx context.Context, schema, table string) (*TableMeta, error)
// GetTotalSize returns the cumulative size (in bytes) of all data files under the source path
GetTotalSize(ctx context.Context) int64
// Close releases resources used by the SDK
FileScanner
JobManager
SQLGenerator
Close() error
}
// TableMeta contains metadata for a table to be imported
type TableMeta struct {
Database string
Table string
DataFiles []DataFileMeta
TotalSize int64 // In bytes
WildcardPath string // Wildcard pattern that matches only this table's data files
SchemaFile string // Path to the table schema file, if available
}
// DataFileMeta contains metadata for a data file
type DataFileMeta struct {
Path string
Size int64
Format mydump.SourceType
Compression mydump.Compression
}
// ImportSDK implements SDK interface
type ImportSDK struct {
sourcePath string
db *sql.DB
store storage.ExternalStorage
loader *mydump.MDLoader
logger log.Logger
config *sdkConfig
// importSDK implements SDK interface
type importSDK struct {
FileScanner
JobManager
SQLGenerator
}
// NewImportSDK creates a new SDK instance
@ -81,410 +41,22 @@ func NewImportSDK(ctx context.Context, sourcePath string, db *sql.DB, options ..
opt(cfg)
}
u, err := storage.ParseBackend(sourcePath, nil)
scanner, err := NewFileScanner(ctx, sourcePath, db, cfg)
if err != nil {
return nil, errors.Annotatef(err, "failed to parse storage backend URL (source=%s). Please verify the URL format and credentials", sourcePath)
}
store, err := storage.New(ctx, u, &storage.ExternalStorageOptions{})
if err != nil {
return nil, errors.Annotatef(err, "failed to create external storage (source=%s). Check network/connectivity and permissions", sourcePath)
return nil, err
}
ldrCfg := mydump.LoaderConfig{
SourceURL: sourcePath,
Filter: cfg.filter,
FileRouters: cfg.fileRouteRules,
DefaultFileRules: len(cfg.fileRouteRules) == 0,
CharacterSet: cfg.charset,
}
jobManager := NewJobManager(db)
sqlGenerator := NewSQLGenerator()
loader, err := mydump.NewLoaderWithStore(ctx, ldrCfg, store)
if err != nil {
return nil, errors.Annotatef(err, "failed to create MyDump loader (source=%s, charset=%s, filter=%v). Please check dump layout and router rules", sourcePath, cfg.charset, cfg.filter)
}
return &ImportSDK{
sourcePath: sourcePath,
db: db,
store: store,
loader: loader,
logger: cfg.logger,
config: cfg,
return &importSDK{
FileScanner: scanner,
JobManager: jobManager,
SQLGenerator: sqlGenerator,
}, nil
}
// SDKOption customizes the SDK configuration
type SDKOption func(*sdkConfig)
type sdkConfig struct {
// Loader options
concurrency int
sqlMode mysql.SQLMode
fileRouteRules []*config.FileRouteRule
filter []string
charset string
// General options
logger log.Logger
// Close releases resources used by the SDK
func (sdk *importSDK) Close() error {
return sdk.FileScanner.Close()
}
func defaultSDKConfig() *sdkConfig {
return &sdkConfig{
concurrency: 4,
filter: config.GetDefaultFilter(),
logger: log.L(),
charset: "auto",
}
}
// WithConcurrency sets the number of concurrent DB/Table creation workers.
func WithConcurrency(n int) SDKOption {
return func(cfg *sdkConfig) {
if n > 0 {
cfg.concurrency = n
}
}
}
// WithLogger specifies a custom logger
func WithLogger(logger log.Logger) SDKOption {
return func(cfg *sdkConfig) {
cfg.logger = logger
}
}
// WithSQLMode specifies the SQL mode for schema parsing
func WithSQLMode(mode mysql.SQLMode) SDKOption {
return func(cfg *sdkConfig) {
cfg.sqlMode = mode
}
}
// WithFilter specifies a filter for the loader
func WithFilter(filter []string) SDKOption {
return func(cfg *sdkConfig) {
cfg.filter = filter
}
}
// WithFileRouters specifies custom file routing rules
func WithFileRouters(routers []*config.FileRouteRule) SDKOption {
return func(cfg *sdkConfig) {
cfg.fileRouteRules = routers
}
}
// WithCharset specifies the character set for import (default "auto").
func WithCharset(cs string) SDKOption {
return func(cfg *sdkConfig) {
if cs != "" {
cfg.charset = cs
}
}
}
// CreateSchemasAndTables implements the CloudImportSDK interface
func (sdk *ImportSDK) CreateSchemasAndTables(ctx context.Context) error {
dbMetas := sdk.loader.GetDatabases()
if len(dbMetas) == 0 {
return errors.Annotatef(ErrNoDatabasesFound, "source=%s. Ensure the path contains valid dump files (*.sql, *.csv, *.parquet, etc.) and filter rules are correct", sdk.sourcePath)
}
// Create all schemas and tables
importer := mydump.NewSchemaImporter(
sdk.logger,
sdk.config.sqlMode,
sdk.db,
sdk.store,
sdk.config.concurrency,
)
err := importer.Run(ctx, dbMetas)
if err != nil {
return errors.Annotatef(err, "creating schemas and tables failed (source=%s, db_count=%d, concurrency=%d)", sdk.sourcePath, len(dbMetas), sdk.config.concurrency)
}
return nil
}
// GetTableMetas implements the CloudImportSDK interface
func (sdk *ImportSDK) GetTableMetas(context.Context) ([]*TableMeta, error) {
dbMetas := sdk.loader.GetDatabases()
allFiles := sdk.loader.GetAllFiles()
var results []*TableMeta
for _, dbMeta := range dbMetas {
for _, tblMeta := range dbMeta.Tables {
tableMeta, err := sdk.buildTableMeta(dbMeta, tblMeta, allFiles)
if err != nil {
return nil, errors.Wrapf(err, "failed to build metadata for table %s.%s",
dbMeta.Name, tblMeta.Name)
}
results = append(results, tableMeta)
}
}
return results, nil
}
// GetTableMetaByName implements CloudImportSDK interface
func (sdk *ImportSDK) GetTableMetaByName(_ context.Context, schema, table string) (*TableMeta, error) {
dbMetas := sdk.loader.GetDatabases()
allFiles := sdk.loader.GetAllFiles()
// Find the specific table
for _, dbMeta := range dbMetas {
if dbMeta.Name != schema {
continue
}
for _, tblMeta := range dbMeta.Tables {
if tblMeta.Name != table {
continue
}
return sdk.buildTableMeta(dbMeta, tblMeta, allFiles)
}
return nil, errors.Annotatef(ErrTableNotFound, "schema=%s, table=%s", schema, table)
}
return nil, errors.Annotatef(ErrSchemaNotFound, "schema=%s", schema)
}
// GetTotalSize implements CloudImportSDK interface
func (sdk *ImportSDK) GetTotalSize(ctx context.Context) int64 {
var total int64
dbMetas := sdk.loader.GetDatabases()
for _, dbMeta := range dbMetas {
for _, tblMeta := range dbMeta.Tables {
total += tblMeta.TotalSize
}
}
return total
}
// buildTableMeta creates a TableMeta from database and table metadata
func (sdk *ImportSDK) buildTableMeta(
dbMeta *mydump.MDDatabaseMeta,
tblMeta *mydump.MDTableMeta,
allDataFiles map[string]mydump.FileInfo,
) (*TableMeta, error) {
tableMeta := &TableMeta{
Database: dbMeta.Name,
Table: tblMeta.Name,
DataFiles: make([]DataFileMeta, 0, len(tblMeta.DataFiles)),
SchemaFile: tblMeta.SchemaFile.FileMeta.Path,
}
// Process data files
dataFiles, totalSize := processDataFiles(tblMeta.DataFiles)
tableMeta.DataFiles = dataFiles
tableMeta.TotalSize = totalSize
wildcard, err := generateWildcardPath(tblMeta.DataFiles, allDataFiles)
if err != nil {
return nil, errors.Annotatef(err, "failed to build wildcard for table=%s.%s", dbMeta.Name, tblMeta.Name)
}
tableMeta.WildcardPath = strings.TrimSuffix(sdk.store.URI(), "/") + "/" + wildcard
return tableMeta, nil
}
// Close implements CloudImportSDK interface
func (sdk *ImportSDK) Close() error {
// close external storage
if sdk.store != nil {
sdk.store.Close()
}
return nil
}
// processDataFiles converts mydump data files to DataFileMeta and calculates total size
func processDataFiles(files []mydump.FileInfo) ([]DataFileMeta, int64) {
dataFiles := make([]DataFileMeta, 0, len(files))
var totalSize int64
for _, dataFile := range files {
fileMeta := createDataFileMeta(dataFile)
dataFiles = append(dataFiles, fileMeta)
totalSize += dataFile.FileMeta.RealSize
}
return dataFiles, totalSize
}
// createDataFileMeta creates a DataFileMeta from a mydump.DataFile
func createDataFileMeta(file mydump.FileInfo) DataFileMeta {
return DataFileMeta{
Path: file.FileMeta.Path,
Size: file.FileMeta.RealSize,
Format: file.FileMeta.Type,
Compression: file.FileMeta.Compression,
}
}
// generateWildcardPath creates a wildcard pattern path that matches only this table's files
func generateWildcardPath(
files []mydump.FileInfo,
allFiles map[string]mydump.FileInfo,
) (string, error) {
tableFiles := make(map[string]struct{}, len(files))
for _, df := range files {
tableFiles[df.FileMeta.Path] = struct{}{}
}
if len(files) == 0 {
return "", errors.Annotate(ErrNoTableDataFiles, "cannot generate wildcard pattern because the table has no data files")
}
// If there's only one file, we can just return its path
if len(files) == 1 {
return files[0].FileMeta.Path, nil
}
// Try Mydumper-specific pattern first
p := generateMydumperPattern(files[0])
if p != "" && isValidPattern(p, tableFiles, allFiles) {
return p, nil
}
// Fallback to generic prefix/suffix pattern
paths := make([]string, 0, len(files))
for _, file := range files {
paths = append(paths, file.FileMeta.Path)
}
p = generatePrefixSuffixPattern(paths)
if p != "" && isValidPattern(p, tableFiles, allFiles) {
return p, nil
}
return "", errors.Annotatef(ErrWildcardNotSpecific, "failed to find a wildcard that matches all and only the table's files.")
}
// isValidPattern checks if a wildcard pattern matches only the table's files
func isValidPattern(pattern string, tableFiles map[string]struct{}, allFiles map[string]mydump.FileInfo) bool {
if pattern == "" {
return false
}
for path := range allFiles {
isMatch, err := filepath.Match(pattern, path)
if err != nil {
return false // Invalid pattern
}
_, isTableFile := tableFiles[path]
// If pattern matches a file that's not from our table, it's invalid
if isMatch && !isTableFile {
return false
}
// If pattern doesn't match our table's file, it's also invalid
if !isMatch && isTableFile {
return false
}
}
return true
}
// generateMydumperPattern generates a wildcard pattern for Mydumper-formatted data files
// belonging to a specific table, based on their naming convention.
// It returns a pattern string that matches all data files for the table, or an empty string if not applicable.
func generateMydumperPattern(file mydump.FileInfo) string {
dbName, tableName := file.TableName.Schema, file.TableName.Name
if dbName == "" || tableName == "" {
return ""
}
// compute dirPrefix and basename
full := file.FileMeta.Path
dirPrefix, name := "", full
if idx := strings.LastIndex(full, "/"); idx >= 0 {
dirPrefix = full[:idx+1]
name = full[idx+1:]
}
// compression ext from filename when compression exists (last suffix like .gz/.zst)
compExt := ""
if file.FileMeta.Compression != mydump.CompressionNone {
compExt = filepath.Ext(name)
}
// data ext after stripping compression ext
base := strings.TrimSuffix(name, compExt)
dataExt := filepath.Ext(base)
return dirPrefix + dbName + "." + tableName + ".*" + dataExt + compExt
}
// longestCommonPrefix finds the longest string that is a prefix of all strings in the slice
func longestCommonPrefix(strs []string) string {
if len(strs) == 0 {
return ""
}
prefix := strs[0]
for _, s := range strs[1:] {
i := 0
for i < len(prefix) && i < len(s) && prefix[i] == s[i] {
i++
}
prefix = prefix[:i]
if prefix == "" {
break
}
}
return prefix
}
// longestCommonSuffix finds the longest string that is a suffix of all strings in the slice, starting after the given prefix length
func longestCommonSuffix(strs []string, prefixLen int) string {
if len(strs) == 0 {
return ""
}
suffix := strs[0][prefixLen:]
for _, s := range strs[1:] {
remaining := s[prefixLen:]
i := 0
for i < len(suffix) && i < len(remaining) && suffix[len(suffix)-i-1] == remaining[len(remaining)-i-1] {
i++
}
suffix = suffix[len(suffix)-i:]
if suffix == "" {
break
}
}
return suffix
}
// generatePrefixSuffixPattern returns a wildcard pattern that matches all and only the given paths
// by finding the longest common prefix and suffix among them, and placing a '*' wildcard in between.
func generatePrefixSuffixPattern(paths []string) string {
if len(paths) == 0 {
return ""
}
if len(paths) == 1 {
return paths[0]
}
prefix := longestCommonPrefix(paths)
suffix := longestCommonSuffix(paths, len(prefix))
return prefix + "*" + suffix
}
// TODO: add error code and doc for cloud sdk
// Sentinel errors to categorize common failure scenarios for clearer user messages.
var (
// ErrNoDatabasesFound indicates that the dump source contains no recognizable databases.
ErrNoDatabasesFound = errors.New("no databases found in the source path")
// ErrSchemaNotFound indicates the target schema doesn't exist in the dump source.
ErrSchemaNotFound = errors.New("schema not found")
// ErrTableNotFound indicates the target table doesn't exist in the dump source.
ErrTableNotFound = errors.New("table not found")
// ErrNoTableDataFiles indicates a table has zero data files and thus cannot proceed.
ErrNoTableDataFiles = errors.New("no data files for table")
// ErrWildcardNotSpecific indicates a wildcard cannot uniquely match the table's files.
ErrWildcardNotSpecific = errors.New("cannot generate a unique wildcard pattern for the table's data files")
)
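
Taken together, the refactored pieces compose as below; a rough end-to-end sketch that uses only methods shown in this change, with a placeholder source URL and placeholder import options.

// Hypothetical driver code, not part of this change: scan a dump, create the
// target schemas and tables, and submit one IMPORT INTO job per table.
package main

import (
	"context"
	"database/sql"

	"github.com/pingcap/tidb/pkg/importsdk"
)

func runImport(ctx context.Context, db *sql.DB) error {
	sdk, err := importsdk.NewImportSDK(ctx, "s3://bucket/dump", db, importsdk.WithConcurrency(8))
	if err != nil {
		return err
	}
	defer sdk.Close()

	if err := sdk.CreateSchemasAndTables(ctx); err != nil {
		return err
	}
	metas, err := sdk.GetTableMetas(ctx)
	if err != nil {
		return err
	}
	for _, meta := range metas {
		stmt, err := sdk.GenerateImportSQL(meta, &importsdk.ImportOptions{Format: "csv", Detached: true})
		if err != nil {
			return err
		}
		if _, err := sdk.SubmitJob(ctx, stmt); err != nil {
			return err
		}
	}
	return nil
}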

View File

@ -27,8 +27,6 @@ import (
"github.com/pingcap/tidb/pkg/lightning/log"
"github.com/pingcap/tidb/pkg/lightning/mydump"
"github.com/pingcap/tidb/pkg/parser/mysql"
filter "github.com/pingcap/tidb/pkg/util/table-filter"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
@ -355,219 +353,82 @@ func (s *mockGCSSuite) TestOnlyDataFiles() {
s.NoError(err)
}
func TestLongestCommonPrefix(t *testing.T) {
strs := []string{"s3://bucket/foo/bar/baz1", "s3://bucket/foo/bar/baz2", "s3://bucket/foo/bar/baz"}
p := longestCommonPrefix(strs)
require.Equal(t, "s3://bucket/foo/bar/baz", p)
// no common prefix
require.Equal(t, "", longestCommonPrefix([]string{"a", "b"}))
// empty inputs
require.Equal(t, "", longestCommonPrefix(nil))
require.Equal(t, "", longestCommonPrefix([]string{}))
func (s *mockGCSSuite) TestScanLimitation() {
s.server.CreateBucketWithOpts(fakestorage.CreateBucketOpts{Name: "limitation"})
s.server.CreateObject(fakestorage.Object{
ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "limitation", Name: "db2.tb2.001.csv"},
Content: []byte("5,e\n6,f\n"),
})
s.server.CreateObject(fakestorage.Object{
ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "limitation", Name: "db2.tb2.002.csv"},
Content: []byte("7,g\n8,h\n"),
})
db, _, err := sqlmock.New()
s.NoError(err)
defer db.Close()
importSDK, err := NewImportSDK(
context.Background(),
fmt.Sprintf("gs://limitation?endpoint=%s&access-key=aaaaaa&secret-access-key=bbbbbb", gcsEndpoint),
db,
WithCharset("utf8"),
WithConcurrency(8),
WithFilter([]string{"*.*"}),
WithSQLMode(mysql.ModeANSIQuotes),
WithLogger(log.L()),
WithSkipInvalidFiles(true),
WithMaxScanFiles(1),
)
s.NoError(err)
defer importSDK.Close()
metas, err := importSDK.GetTableMetas(context.Background())
s.NoError(err)
s.Len(metas, 1)
s.Len(metas[0].DataFiles, 1)
}
func TestLongestCommonSuffix(t *testing.T) {
strs := []string{"abcXYZ", "defXYZ", "XYZ"}
s := longestCommonSuffix(strs, 0)
require.Equal(t, "XYZ", s)
// no common suffix
require.Equal(t, "", longestCommonSuffix([]string{"a", "b"}, 0))
// empty inputs
require.Equal(t, "", longestCommonSuffix(nil, 0))
require.Equal(t, "", longestCommonSuffix([]string{}, 0))
// same prefix
require.Equal(t, "", longestCommonSuffix([]string{"abc", "abc"}, 3))
require.Equal(t, "f", longestCommonSuffix([]string{"abcdf", "abcef"}, 3))
}
func TestGeneratePrefixSuffixPattern(t *testing.T) {
paths := []string{"pre_middle_suf", "pre_most_suf"}
pattern := generatePrefixSuffixPattern(paths)
// common prefix "pre_m", suffix "_suf"
require.Equal(t, "pre_m*_suf", pattern)
// empty inputs
require.Equal(t, "", generatePrefixSuffixPattern(nil))
require.Equal(t, "", generatePrefixSuffixPattern([]string{}))
// only one file
require.Equal(t, "pre_middle_suf", generatePrefixSuffixPattern([]string{"pre_middle_suf"}))
// no common prefix/suffix
paths2 := []string{"foo", "bar"}
require.Equal(t, "*", generatePrefixSuffixPattern(paths2))
// overlapping prefix/suffix
paths3 := []string{"aaabaaa", "aaa"}
require.Equal(t, "aaa*", generatePrefixSuffixPattern(paths3))
}
func generateFileMetas(t *testing.T, paths []string) []mydump.FileInfo {
t.Helper()
files := make([]mydump.FileInfo, 0, len(paths))
fileRouter, err := mydump.NewDefaultFileRouter(log.L())
require.NoError(t, err)
for _, p := range paths {
res, err := fileRouter.Route(p)
require.NoError(t, err)
files = append(files, mydump.FileInfo{
TableName: res.Table,
FileMeta: mydump.SourceFileMeta{
Path: p,
Type: res.Type,
Compression: res.Compression,
SortKey: res.Key,
},
func (s *mockGCSSuite) TestCreateTableMetaByName() {
for i := 0; i < 2; i++ {
s.server.CreateObject(fakestorage.Object{
ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "specific-table-test", Name: fmt.Sprintf("db%d-schema-create.sql", i)},
Content: []byte(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS db%d;\n", i)),
})
}
return files
}
func TestGenerateMydumperPattern(t *testing.T) {
paths := []string{"db.tb.0001.sql", "db.tb.0002.sql"}
p := generateMydumperPattern(generateFileMetas(t, paths)[0])
require.Equal(t, "db.tb.*.sql", p)
paths2 := []string{"s3://bucket/dir/db.tb.0001.sql", "s3://bucket/dir/db.tb.0002.sql"}
p2 := generateMydumperPattern(generateFileMetas(t, paths2)[0])
require.Equal(t, "s3://bucket/dir/db.tb.*.sql", p2)
// not mydumper pattern
require.Equal(t, "", generateMydumperPattern(mydump.FileInfo{
TableName: filter.Table{},
}))
}
func TestCreateDataFileMeta(t *testing.T) {
fi := mydump.FileInfo{
TableName: filter.Table{
Schema: "db",
Name: "table",
},
FileMeta: mydump.SourceFileMeta{
Path: "s3://bucket/path/to/f",
FileSize: 123,
Type: mydump.SourceTypeCSV,
Compression: mydump.CompressionGZ,
RealSize: 456,
},
}
df := createDataFileMeta(fi)
require.Equal(t, "s3://bucket/path/to/f", df.Path)
require.Equal(t, int64(456), df.Size)
require.Equal(t, mydump.SourceTypeCSV, df.Format)
require.Equal(t, mydump.CompressionGZ, df.Compression)
}
func TestProcessDataFiles(t *testing.T) {
files := []mydump.FileInfo{
{FileMeta: mydump.SourceFileMeta{Path: "s3://bucket/a", RealSize: 10}},
{FileMeta: mydump.SourceFileMeta{Path: "s3://bucket/b", RealSize: 20}},
}
dfm, total := processDataFiles(files)
require.Len(t, dfm, 2)
require.Equal(t, int64(30), total)
require.Equal(t, "s3://bucket/a", dfm[0].Path)
require.Equal(t, "s3://bucket/b", dfm[1].Path)
}
func TestValidatePattern(t *testing.T) {
tableFiles := map[string]struct{}{
"a.txt": {}, "b.txt": {},
}
// only table files in allFiles
smallAll := map[string]mydump.FileInfo{
"a.txt": {}, "b.txt": {},
}
require.True(t, isValidPattern("*.txt", tableFiles, smallAll))
// allFiles includes an extra file => invalid
fullAll := map[string]mydump.FileInfo{
"a.txt": {}, "b.txt": {}, "c.txt": {},
}
require.False(t, isValidPattern("*.txt", tableFiles, fullAll))
// If pattern doesn't match our table's file, it's also invalid
require.False(t, isValidPattern("*.csv", tableFiles, smallAll))
// empty pattern => invalid
require.False(t, isValidPattern("", tableFiles, smallAll))
}
func TestGenerateWildcardPath(t *testing.T) {
// Helper to create allFiles map
createAllFiles := func(paths []string) map[string]mydump.FileInfo {
allFiles := make(map[string]mydump.FileInfo)
for _, p := range paths {
allFiles[p] = mydump.FileInfo{
FileMeta: mydump.SourceFileMeta{Path: p},
}
for j := 0; j != 2; j++ {
tableName := fmt.Sprintf("db%d.tb%d", i, j)
s.server.CreateObject(fakestorage.Object{
ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "specific-table-test", Name: fmt.Sprintf("%s-schema.sql", tableName)},
Content: []byte(fmt.Sprintf("CREATE TABLE IF NOT EXISTS db%d.tb%d (a INT, b VARCHAR(10));\n", i, j)),
})
s.server.CreateObject(fakestorage.Object{
ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "specific-table-test", Name: fmt.Sprintf("%s.001.sql", tableName)},
Content: []byte(fmt.Sprintf("INSERT INTO db%d.tb%d VALUES (1,'a'),(2,'b');\n", i, j)),
})
s.server.CreateObject(fakestorage.Object{
ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "specific-table-test", Name: fmt.Sprintf("%s.002.sql", tableName)},
Content: []byte(fmt.Sprintf("INSERT INTO db%d.tb%d VALUES (3,'c'),(4,'d');\n", i, j)),
})
}
return allFiles
}
// No files
files1 := []mydump.FileInfo{}
allFiles1 := createAllFiles([]string{})
_, err := generateWildcardPath(files1, allFiles1)
require.Error(t, err)
require.Contains(t, err.Error(), "no data files for table")
db, mock, err := sqlmock.New()
s.NoError(err)
defer db.Close()
// Single file
files2 := generateFileMetas(t, []string{"db.tb.0001.sql"})
allFiles2 := createAllFiles([]string{"db.tb.0001.sql"})
path2, err := generateWildcardPath(files2, allFiles2)
require.NoError(t, err)
require.Equal(t, "db.tb.0001.sql", path2)
mock.ExpectQuery(`SELECT SCHEMA_NAME FROM information_schema.SCHEMATA`).
WillReturnRows(sqlmock.NewRows([]string{"SCHEMA_NAME"}))
mock.ExpectExec("CREATE DATABASE IF NOT EXISTS `db1`;").
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(regexp.QuoteMeta("CREATE TABLE IF NOT EXISTS `db1`.`tb1` (`a` INT,`b` VARCHAR(10));")).
WillReturnResult(sqlmock.NewResult(0, 1))
// Mydumper pattern succeeds
files3 := generateFileMetas(t, []string{"db.tb.0001.sql.gz", "db.tb.0002.sql.gz"})
allFiles3 := createAllFiles([]string{"db.tb.0001.sql.gz", "db.tb.0002.sql.gz"})
path3, err := generateWildcardPath(files3, allFiles3)
require.NoError(t, err)
require.Equal(t, "db.tb.*.sql.gz", path3)
importSDK, err := NewImportSDK(
context.Background(),
fmt.Sprintf("gs://specific-table-test?endpoint=%s&access-key=aaaaaa&secret-access-key=bbbbbb", gcsEndpoint),
db,
WithConcurrency(1),
)
s.NoError(err)
defer importSDK.Close()
// Mydumper pattern fails, fallback to prefix/suffix succeeds
files4 := []mydump.FileInfo{
{
TableName: filter.Table{Schema: "db", Name: "tb"},
FileMeta: mydump.SourceFileMeta{
Path: "a.sql",
Type: mydump.SourceTypeSQL,
Compression: mydump.CompressionNone,
},
},
{
TableName: filter.Table{Schema: "db", Name: "tb"},
FileMeta: mydump.SourceFileMeta{
Path: "b.sql",
Type: mydump.SourceTypeSQL,
Compression: mydump.CompressionNone,
},
},
}
allFiles4 := map[string]mydump.FileInfo{
files4[0].FileMeta.Path: files4[0],
files4[1].FileMeta.Path: files4[1],
}
path4, err := generateWildcardPath(files4, allFiles4)
require.NoError(t, err)
require.Equal(t, "*.sql", path4)
allFiles4["db-schema.sql"] = mydump.FileInfo{
FileMeta: mydump.SourceFileMeta{
Path: "db-schema.sql",
Type: mydump.SourceTypeSQL,
Compression: mydump.CompressionNone,
},
}
_, err = generateWildcardPath(files4, allFiles4)
require.Error(t, err)
require.Contains(t, err.Error(), "cannot generate a unique wildcard pattern")
err = importSDK.CreateSchemaAndTableByName(context.Background(), "db1", "tb1")
s.NoError(err)
}

View File

@ -0,0 +1,161 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importsdk
import (
"fmt"
"net/url"
"strings"
"github.com/pingcap/tidb/pkg/lightning/config"
)
// SQLGenerator defines the interface for generating IMPORT INTO SQL
type SQLGenerator interface {
GenerateImportSQL(tableMeta *TableMeta, options *ImportOptions) (string, error)
}
type sqlGenerator struct{}
// NewSQLGenerator creates a new SQLGenerator
func NewSQLGenerator() SQLGenerator {
return &sqlGenerator{}
}
// GenerateImportSQL generates the IMPORT INTO SQL statement
func (g *sqlGenerator) GenerateImportSQL(tableMeta *TableMeta, options *ImportOptions) (string, error) {
var sb strings.Builder
sb.WriteString("IMPORT INTO ")
sb.WriteString(escapeIdentifier(tableMeta.Database))
sb.WriteString(".")
sb.WriteString(escapeIdentifier(tableMeta.Table))
path := tableMeta.WildcardPath
if options.ResourceParameters != "" {
u, err := url.Parse(path)
if err == nil {
if u.RawQuery != "" {
u.RawQuery += "&" + options.ResourceParameters
} else {
u.RawQuery = options.ResourceParameters
}
path = u.String()
}
}
sb.WriteString(" FROM '")
sb.WriteString(path)
sb.WriteString("'")
if options.Format != "" {
sb.WriteString(" FORMAT '")
sb.WriteString(options.Format)
sb.WriteString("'")
}
opts, err := g.buildOptions(options)
if err != nil {
return "", err
}
if len(opts) > 0 {
sb.WriteString(" WITH ")
sb.WriteString(strings.Join(opts, ", "))
}
return sb.String(), nil
}
func (g *sqlGenerator) buildOptions(options *ImportOptions) ([]string, error) {
var opts []string
if options.Thread > 0 {
opts = append(opts, fmt.Sprintf("THREAD=%d", options.Thread))
}
if options.DiskQuota != "" {
opts = append(opts, fmt.Sprintf("DISK_QUOTA='%s'", options.DiskQuota))
}
if options.MaxWriteSpeed != "" {
opts = append(opts, fmt.Sprintf("MAX_WRITE_SPEED='%s'", options.MaxWriteSpeed))
}
if options.SplitFile {
opts = append(opts, "SPLIT_FILE")
}
if options.RecordErrors > 0 {
opts = append(opts, fmt.Sprintf("RECORD_ERRORS=%d", options.RecordErrors))
}
if options.Detached {
opts = append(opts, "DETACHED")
}
if options.CloudStorageURI != "" {
opts = append(opts, fmt.Sprintf("CLOUD_STORAGE_URI='%s'", options.CloudStorageURI))
}
if options.GroupKey != "" {
opts = append(opts, fmt.Sprintf("GROUP_KEY='%s'", escapeString(options.GroupKey)))
}
if options.SkipRows > 0 {
opts = append(opts, fmt.Sprintf("SKIP_ROWS=%d", options.SkipRows))
}
if options.CharacterSet != "" {
opts = append(opts, fmt.Sprintf("CHARACTER_SET='%s'", escapeString(options.CharacterSet)))
}
if options.ChecksumTable != "" {
opts = append(opts, fmt.Sprintf("CHECKSUM_TABLE='%s'", escapeString(options.ChecksumTable)))
}
if options.DisableTiKVImportMode {
opts = append(opts, "DISABLE_TIKV_IMPORT_MODE")
}
if options.DisablePrecheck {
opts = append(opts, "DISABLE_PRECHECK")
}
if options.CSVConfig != nil && options.Format == "csv" {
csvOpts, err := g.buildCSVOptions(options.CSVConfig)
if err != nil {
return nil, err
}
opts = append(opts, csvOpts...)
}
return opts, nil
}
func (g *sqlGenerator) buildCSVOptions(csvConfig *config.CSVConfig) ([]string, error) {
var opts []string
if csvConfig.FieldsTerminatedBy != "" {
opts = append(opts, fmt.Sprintf("FIELDS_TERMINATED_BY='%s'", escapeString(csvConfig.FieldsTerminatedBy)))
}
if csvConfig.FieldsEnclosedBy != "" {
opts = append(opts, fmt.Sprintf("FIELDS_ENCLOSED_BY='%s'", escapeString(csvConfig.FieldsEnclosedBy)))
}
if csvConfig.FieldsEscapedBy != "" {
opts = append(opts, fmt.Sprintf("FIELDS_ESCAPED_BY='%s'", escapeString(csvConfig.FieldsEscapedBy)))
}
if csvConfig.LinesTerminatedBy != "" {
opts = append(opts, fmt.Sprintf("LINES_TERMINATED_BY='%s'", escapeString(csvConfig.LinesTerminatedBy)))
}
if len(csvConfig.FieldNullDefinedBy) > 0 {
if len(csvConfig.FieldNullDefinedBy) > 1 {
return nil, ErrMultipleFieldsDefinedNullBy
}
opts = append(opts, fmt.Sprintf("FIELDS_DEFINED_NULL_BY='%s'", escapeString(csvConfig.FieldNullDefinedBy[0])))
}
return opts, nil
}
func escapeIdentifier(s string) string {
return "`" + strings.ReplaceAll(s, "`", "``") + "`"
}
func escapeString(s string) string {
return strings.ReplaceAll(strings.ReplaceAll(s, "\\", "\\\\"), "'", "''")
}

View File

@ -0,0 +1,173 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importsdk
import (
"testing"
"github.com/pingcap/tidb/pkg/lightning/config"
"github.com/stretchr/testify/require"
)
func TestGenerateImportSQL(t *testing.T) {
gen := NewSQLGenerator()
defaultTableMeta := &TableMeta{
Database: "test_db",
Table: "test_table",
WildcardPath: "s3://bucket/path/*.csv",
}
tests := []struct {
name string
tableMeta *TableMeta
options *ImportOptions
expected string
expectedError string
}{
{
name: "Basic",
options: &ImportOptions{
Format: "csv",
},
expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv' FORMAT 'csv'",
},
{
name: "With S3 Credentials",
options: &ImportOptions{
Format: "csv",
ResourceParameters: "access-key=ak&endpoint=http%3A%2F%2Fminio%3A9000&secret-access-key=sk",
},
expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv?access-key=ak&endpoint=http%3A%2F%2Fminio%3A9000&secret-access-key=sk' FORMAT 'csv'",
},
{
name: "With Options",
options: &ImportOptions{
Format: "csv",
Thread: 4,
Detached: true,
MaxWriteSpeed: "100MiB",
},
expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv' FORMAT 'csv' WITH THREAD=4, MAX_WRITE_SPEED='100MiB', DETACHED",
},
{
name: "All Options",
options: &ImportOptions{
Format: "csv",
Thread: 8,
DiskQuota: "100GiB",
MaxWriteSpeed: "200MiB",
SplitFile: true,
RecordErrors: 100,
Detached: true,
CloudStorageURI: "s3://bucket/storage",
GroupKey: "group1",
SkipRows: 1,
CharacterSet: "utf8mb4",
ChecksumTable: "test_db.checksum_table",
DisableTiKVImportMode: true,
DisablePrecheck: true,
},
expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv' FORMAT 'csv' WITH THREAD=8, DISK_QUOTA='100GiB', MAX_WRITE_SPEED='200MiB', SPLIT_FILE, RECORD_ERRORS=100, DETACHED, CLOUD_STORAGE_URI='s3://bucket/storage', GROUP_KEY='group1', SKIP_ROWS=1, CHARACTER_SET='utf8mb4', CHECKSUM_TABLE='test_db.checksum_table', DISABLE_TIKV_IMPORT_MODE, DISABLE_PRECHECK",
},
{
name: "With CSV Config",
options: &ImportOptions{
Format: "csv",
CSVConfig: &config.CSVConfig{
FieldsTerminatedBy: ",",
FieldsEnclosedBy: "\"",
FieldsEscapedBy: "\\",
LinesTerminatedBy: "\n",
FieldNullDefinedBy: []string{"NULL"},
},
},
expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv' FORMAT 'csv' WITH FIELDS_TERMINATED_BY=',', FIELDS_ENCLOSED_BY='\"', FIELDS_ESCAPED_BY='\\\\', LINES_TERMINATED_BY='\n', FIELDS_DEFINED_NULL_BY='NULL'",
},
{
name: "With Cloud Storage URI",
options: &ImportOptions{
Format: "parquet",
CloudStorageURI: "s3://bucket/storage",
},
expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv' FORMAT 'parquet' WITH CLOUD_STORAGE_URI='s3://bucket/storage'",
},
{
name: "Multiple Null Defined By",
options: &ImportOptions{
Format: "csv",
CSVConfig: &config.CSVConfig{
FieldNullDefinedBy: []string{"NULL", "\\N"},
},
},
expectedError: "IMPORT INTO only supports one FIELDS_DEFINED_NULL_BY value",
},
{
name: "Resource Parameters Append",
tableMeta: &TableMeta{
Database: "test_db",
Table: "test_table",
WildcardPath: "s3://bucket/path/*.csv?foo=bar",
},
options: &ImportOptions{
Format: "csv",
ResourceParameters: "access-key=ak",
},
expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv?foo=bar&access-key=ak' FORMAT 'csv'",
},
{
name: "Escaping",
options: &ImportOptions{
Format: "csv",
GroupKey: "group'1",
CharacterSet: "utf8'mb4",
CSVConfig: &config.CSVConfig{
FieldsTerminatedBy: "'",
FieldsEnclosedBy: "\\",
},
},
expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv' FORMAT 'csv' WITH GROUP_KEY='group''1', CHARACTER_SET='utf8''mb4', FIELDS_TERMINATED_BY='''', FIELDS_ENCLOSED_BY='\\\\'",
},
{
name: "Identifier Escaping",
tableMeta: &TableMeta{
Database: "test`db",
Table: "test`table",
WildcardPath: "s3://bucket/path/*.csv",
},
options: &ImportOptions{
Format: "csv",
},
expected: "IMPORT INTO `test``db`.`test``table` FROM 's3://bucket/path/*.csv' FORMAT 'csv'",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tm := tt.tableMeta
if tm == nil {
tm = defaultTableMeta
}
sql, err := gen.GenerateImportSQL(tm, tt.options)
if tt.expectedError != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tt.expectedError)
} else {
require.NoError(t, err)
require.Equal(t, tt.expected, sql)
}
})
}
}

View File

@ -4,6 +4,7 @@ go_library(
name = "testkit",
srcs = [
"asynctestkit.go",
"db_driver.go",
"dbtestkit.go",
"mocksessionmanager.go",
"mockstore.go",
@ -71,8 +72,12 @@ go_library(
go_test(
name = "testkit_test",
timeout = "short",
srcs = ["testkit_test.go"],
srcs = [
"db_driver_test.go",
"testkit_test.go",
],
embed = [":testkit"],
flaky = True,
shard_count = 2,
deps = ["@com_github_stretchr_testify//require"],
)

308
pkg/testkit/db_driver.go Normal file
View File

@ -0,0 +1,308 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package testkit
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"io"
"strconv"
"sync"
"sync/atomic"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/session/sessionapi"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/sqlexec"
)
var (
tkMapMu sync.RWMutex
tkMap = make(map[string]*TestKit)
tkIDSeq int64
)
func init() {
sql.Register("testkit", &testKitDriver{})
}
// CreateMockDB creates a *sql.DB that uses the TestKit's store to create sessions.
func CreateMockDB(tk *TestKit) *sql.DB {
id := strconv.FormatInt(atomic.AddInt64(&tkIDSeq, 1), 10)
tkMapMu.Lock()
tkMap[id] = tk
tkMapMu.Unlock()
db, err := sql.Open("testkit", id)
if err != nil {
panic(err)
}
return db
}
type testKitDriver struct{}
func (d *testKitDriver) Open(name string) (driver.Conn, error) {
tkMapMu.RLock()
tk, ok := tkMap[name]
tkMapMu.RUnlock()
if !ok {
return nil, fmt.Errorf("testkit not found for %s", name)
}
se := NewSession(tk.t, tk.store)
return &testKitConn{se: se}, nil
}
type testKitConn struct {
se sessionapi.Session
}
func (c *testKitConn) Prepare(query string) (driver.Stmt, error) {
return &testKitStmt{c: c, query: query}, nil
}
func (c *testKitConn) Close() error {
c.se.Close()
return nil
}
func (c *testKitConn) Begin() (driver.Tx, error) {
_, err := c.Exec("BEGIN", nil)
if err != nil {
return nil, err
}
return &testKitTxn{c: c}, nil
}
func (c *testKitConn) Exec(query string, args []driver.Value) (driver.Result, error) {
return nil, driver.ErrSkip
}
func (c *testKitConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
qArgs := make([]any, len(args))
for i, a := range args {
qArgs[i] = a.Value
}
rs, err := c.execute(ctx, query, qArgs)
if err != nil {
return nil, err
}
if rs != nil {
if err := rs.Close(); err != nil {
return nil, err
}
}
return &testKitResult{
lastInsertID: int64(c.se.LastInsertID()),
rowsAffected: int64(c.se.AffectedRows()),
}, nil
}
func (c *testKitConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
qArgs := make([]any, len(args))
for i, a := range args {
qArgs[i] = a.Value
}
rs, err := c.execute(ctx, query, qArgs)
if err != nil {
return nil, err
}
if rs == nil {
return nil, nil
}
return &testKitRows{rs: rs}, nil
}
func (c *testKitConn) execute(ctx context.Context, sql string, args []any) (sqlexec.RecordSet, error) {
// Set the command value to ComQuery, so that the process info can be updated correctly
c.se.SetCommandValue(mysql.ComQuery)
defer c.se.SetCommandValue(mysql.ComSleep)
if len(args) == 0 {
rss, err := c.se.Execute(ctx, sql)
if err != nil {
return nil, err
}
if len(rss) == 0 {
return nil, nil
}
return rss[0], nil
}
stmtID, _, _, err := c.se.PrepareStmt(sql)
if err != nil {
return nil, errors.Trace(err)
}
params := expression.Args2Expressions4Test(args...)
rs, err := c.se.ExecutePreparedStmt(ctx, stmtID, params)
if err != nil {
return rs, errors.Trace(err)
}
err = c.se.DropPreparedStmt(stmtID)
if err != nil {
return rs, errors.Trace(err)
}
return rs, nil
}
type testKitStmt struct {
c *testKitConn
query string
}
func (s *testKitStmt) Close() error {
return nil
}
func (s *testKitStmt) NumInput() int {
return -1
}
func (s *testKitStmt) Exec(args []driver.Value) (driver.Result, error) {
return nil, driver.ErrSkip
}
func (s *testKitStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
return s.c.ExecContext(ctx, s.query, args)
}
func (s *testKitStmt) Query(args []driver.Value) (driver.Rows, error) {
return nil, driver.ErrSkip
}
func (s *testKitStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
return s.c.QueryContext(ctx, s.query, args)
}
type testKitResult struct {
lastInsertID int64
rowsAffected int64
}
func (r *testKitResult) LastInsertId() (int64, error) {
return r.lastInsertID, nil
}
func (r *testKitResult) RowsAffected() (int64, error) {
return r.rowsAffected, nil
}
type testKitRows struct {
rs sqlexec.RecordSet
chunk *chunk.Chunk
it *chunk.Iterator4Chunk
}
func (r *testKitRows) Columns() []string {
fields := r.rs.Fields()
cols := make([]string, len(fields))
for i, f := range fields {
cols[i] = f.Column.Name.O
}
return cols
}
func (r *testKitRows) Close() error {
return r.rs.Close()
}
func (r *testKitRows) Next(dest []driver.Value) error {
if r.chunk == nil {
r.chunk = r.rs.NewChunk(nil)
}
var row chunk.Row
if r.it == nil {
err := r.rs.Next(context.Background(), r.chunk)
if err != nil {
return err
}
if r.chunk.NumRows() == 0 {
return io.EOF
}
r.it = chunk.NewIterator4Chunk(r.chunk)
row = r.it.Begin()
} else {
row = r.it.Next()
if row.IsEmpty() {
err := r.rs.Next(context.Background(), r.chunk)
if err != nil {
return err
}
if r.chunk.NumRows() == 0 {
return io.EOF
}
r.it = chunk.NewIterator4Chunk(r.chunk)
row = r.it.Begin()
}
}
for i := range row.Len() {
d := row.GetDatum(i, &r.rs.Fields()[i].Column.FieldType)
// Handle NULL
if d.IsNull() {
dest[i] = nil
} else {
// Convert to appropriate type if needed, or just return string/bytes/int/float
// driver.Value allows int64, float64, bool, []byte, string, time.Time
// Datum.GetValue() returns interface{} which might be compatible.
v := d.GetValue()
switch x := v.(type) {
case []byte:
dest[i] = x
case string:
dest[i] = x
case int64:
dest[i] = x
case uint64:
dest[i] = x
case float64:
dest[i] = x
case float32:
dest[i] = x
case types.Time:
dest[i] = x.String()
case types.Duration:
dest[i] = x.String()
case *types.MyDecimal:
dest[i] = x.String()
default:
dest[i] = x
}
}
}
return nil
}
type testKitTxn struct {
c *testKitConn
}
func (t *testKitTxn) Commit() error {
_, err := t.c.Exec("COMMIT", nil)
return err
}
func (t *testKitTxn) Rollback() error {
_, err := t.c.Exec("ROLLBACK", nil)
return err
}

View File

@ -0,0 +1,79 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package testkit_test
import (
"testing"
"github.com/pingcap/tidb/pkg/testkit"
"github.com/stretchr/testify/require"
)
func TestMockDB(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
db := testkit.CreateMockDB(tk)
defer db.Close()
var err error
_, err = db.Exec("use test")
require.NoError(t, err)
_, err = db.Exec("create table t (id int, v varchar(255))")
require.NoError(t, err)
_, err = db.Exec("insert into t values (1, 'a'), (2, 'b')")
require.NoError(t, err)
// Test Query
rows, err := db.Query("select * from t order by id")
require.NoError(t, err)
defer rows.Close()
var id int
var v string
require.True(t, rows.Next())
require.NoError(t, rows.Scan(&id, &v))
require.Equal(t, 1, id)
require.Equal(t, "a", v)
require.True(t, rows.Next())
require.NoError(t, rows.Scan(&id, &v))
require.Equal(t, 2, id)
require.Equal(t, "b", v)
require.False(t, rows.Next())
require.NoError(t, rows.Err())
// Test Exec
res, err := db.Exec("insert into t values (3, 'c')")
require.NoError(t, err)
affected, err := res.RowsAffected()
require.NoError(t, err)
require.Equal(t, int64(1), affected)
// Verify with TestKit
tk.MustExec("use test")
tk.MustQuery("select * from t where id = 3").Check(testkit.Rows("3 c"))
// Test Prepare
stmt, err := db.Prepare("select v from t where id = ?")
require.NoError(t, err)
defer stmt.Close()
row := stmt.QueryRow(2)
require.NoError(t, row.Scan(&v))
require.Equal(t, "b", v)
}

View File

@ -11,6 +11,7 @@ go_test(
"one_parquet_test.go",
"parquet_test.go",
"precheck_test.go",
"sdk_test.go",
"util_test.go",
],
embedsrcs = [
@ -35,6 +36,7 @@ go_test(
"//pkg/domain",
"//pkg/domain/infosync",
"//pkg/executor/importer",
"//pkg/importsdk",
"//pkg/infoschema",
"//pkg/kv",
"//pkg/lightning/backend/local",

View File

@ -0,0 +1,180 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package importintotest
import (
"context"
"os"
"path/filepath"
"time"
"github.com/pingcap/tidb/pkg/importsdk"
"github.com/pingcap/tidb/pkg/testkit"
"github.com/pingcap/tidb/pkg/testkit/testfailpoint"
)
func (s *mockGCSSuite) TestImportSDK() {
ctx := context.Background()
// 1. Prepare source data (local files)
tmpDir := s.T().TempDir()
// Create schema files for CreateSchemasAndTables
// DB schema
err := os.WriteFile(filepath.Join(tmpDir, "importsdk_test-schema-create.sql"), []byte("CREATE DATABASE importsdk_test;"), 0644)
s.NoError(err)
// Table schema
err = os.WriteFile(filepath.Join(tmpDir, "importsdk_test.t-schema.sql"), []byte("CREATE TABLE t (id int, v varchar(255));"), 0644)
s.NoError(err)
// Data file: mydumper format: {schema}.{table}.{seq}.csv
fileName := "importsdk_test.t.001.csv"
content := []byte("1,test1\n2,test2")
err = os.WriteFile(filepath.Join(tmpDir, fileName), content, 0644)
s.NoError(err)
// 2. Prepare DB connection
// Ensure clean state
s.tk.MustExec("DROP DATABASE IF EXISTS importsdk_test")
// Create a *sql.DB using the testkit driver
db := testkit.CreateMockDB(s.tk)
defer db.Close()
// 3. Initialize SDK
// Use local file path.
sdk, err := importsdk.NewImportSDK(ctx, "file://"+tmpDir, db)
s.NoError(err)
defer sdk.Close()
// 4. Test FileScanner.CreateSchemasAndTables
err = sdk.CreateSchemasAndTables(ctx)
s.NoError(err)
// Verify DB and Table exist
s.tk.MustExec("USE importsdk_test")
s.tk.MustExec("SHOW CREATE TABLE t")
// 4.1 Test FileScanner.CreateSchemaAndTableByName
s.tk.MustExec("DROP TABLE t")
err = sdk.CreateSchemaAndTableByName(ctx, "importsdk_test", "t")
s.NoError(err)
s.tk.MustExec("SHOW CREATE TABLE t")
// 5. Test FileScanner.GetTableMetas
metas, err := sdk.GetTableMetas(ctx)
s.NoError(err)
s.Len(metas, 1)
s.Equal("importsdk_test", metas[0].Database)
s.Equal("t", metas[0].Table)
s.Equal(int64(len(content)), metas[0].TotalSize)
// 6. Test FileScanner.GetTableMetaByName
meta, err := sdk.GetTableMetaByName(ctx, "importsdk_test", "t")
s.NoError(err)
s.Equal("importsdk_test", meta.Database)
s.Equal("t", meta.Table)
// 7. Test FileScanner.GetTotalSize
totalSize := sdk.GetTotalSize(ctx)
s.Equal(int64(len(content)), totalSize)
// 8. Test SQLGenerator.GenerateImportSQL
opts := &importsdk.ImportOptions{
Thread: 4,
Detached: true,
}
sql, err := sdk.GenerateImportSQL(meta, opts)
s.NoError(err)
s.Contains(sql, "IMPORT INTO `importsdk_test`.`t` FROM")
s.Contains(sql, "THREAD=4")
s.Contains(sql, "DETACHED")
// 9. Test JobManager.SubmitJob
jobID, err := sdk.SubmitJob(ctx, sql)
s.NoError(err)
s.Greater(jobID, int64(0))
// 10. Test JobManager.GetJobStatus
status, err := sdk.GetJobStatus(ctx, jobID)
s.NoError(err)
s.Equal(jobID, status.JobID)
s.Equal("`importsdk_test`.`t`", status.TargetTable)
// Wait for job to finish
s.Eventually(func() bool {
status, err := sdk.GetJobStatus(ctx, jobID)
s.NoError(err)
return status.Status == "finished"
}, 30*time.Second, 500*time.Millisecond)
// Verify data
s.tk.MustQuery("SELECT * FROM importsdk_test.t").Check(testkit.Rows("1 test1", "2 test2"))
// 11. Test JobManager.CancelJob with failpoint
s.tk.MustExec("TRUNCATE TABLE t")
testfailpoint.Enable(s.T(), "github.com/pingcap/tidb/pkg/disttask/importinto/syncBeforeJobStarted", "pause")
jobID2, err := sdk.SubmitJob(ctx, sql)
s.NoError(err)
// Cancel the job
err = sdk.CancelJob(ctx, jobID2)
s.NoError(err)
// Disable the failpoint so the paused job can proceed and observe the cancellation
testfailpoint.Disable(s.T(), "github.com/pingcap/tidb/pkg/disttask/importinto/syncBeforeJobStarted")
// Verify status is cancelled
s.Eventually(func() bool {
status, err := sdk.GetJobStatus(ctx, jobID2)
s.NoError(err)
return status.Status == "cancelled"
}, 30*time.Second, 500*time.Millisecond)
// 12. Test JobManager.GetGroupSummary
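// Jobs submitted with the same GroupKey are aggregated into a single summary.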
// Submit a job with GroupKey
groupKey := "test_group_key"
optsWithGroup := &importsdk.ImportOptions{
Thread: 4,
Detached: true,
GroupKey: groupKey,
}
sqlWithGroup, err := sdk.GenerateImportSQL(meta, optsWithGroup)
s.NoError(err)
jobID3, err := sdk.SubmitJob(ctx, sqlWithGroup)
s.NoError(err)
s.Greater(jobID3, int64(0))
// Wait for job to finish
s.Eventually(func() bool {
status, err := sdk.GetJobStatus(ctx, jobID3)
s.NoError(err)
return status.Status == "finished"
}, 30*time.Second, 500*time.Millisecond)
// Get group summary
groupSummary, err := sdk.GetGroupSummary(ctx, groupKey)
s.NoError(err)
s.Equal(groupKey, groupSummary.GroupKey)
s.Equal(int64(1), groupSummary.TotalJobs)
s.Equal(int64(1), groupSummary.Completed)
// 13. Test JobManager.GetJobsByGroup
jobs, err := sdk.GetJobsByGroup(ctx, groupKey)
s.NoError(err)
s.Len(jobs, 1)
s.Equal(jobID3, jobs[0].JobID)
s.Equal("finished", jobs[0].Status)
}