importinto: refactor and add interface for import into sdk (#65094)
ref pingcap/tidb#65092

Makefile (+1 line)
@@ -546,6 +546,7 @@ mock_import: mockgen
	tools/bin/mockgen -package mock github.com/pingcap/tidb/pkg/disttask/framework/planner LogicalPlan,PipelineSpec > pkg/disttask/framework/mock/plan_mock.go
	tools/bin/mockgen -package mock github.com/pingcap/tidb/pkg/disttask/framework/storage Manager > pkg/disttask/framework/mock/storage_manager_mock.go
	tools/bin/mockgen -package mock github.com/pingcap/tidb/pkg/ingestor/ingestcli Client,WriteClient > pkg/ingestor/ingestcli/mock/client_mock.go
	tools/bin/mockgen -package mock github.com/pingcap/tidb/pkg/importsdk FileScanner,JobManager,SQLGenerator,SDK > pkg/importsdk/mock/sdk_mock.go

.PHONY: gen_mock
gen_mock: mockgen

pkg/importsdk/BUILD.bazel
@@ -2,26 +2,45 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
    name = "importsdk",
-   srcs = ["sdk.go"],
+   srcs = [
+       "config.go",
+       "error.go",
+       "file_scanner.go",
+       "job_manager.go",
+       "model.go",
+       "pattern.go",
+       "sdk.go",
+       "sql_generator.go",
+   ],
    importpath = "github.com/pingcap/tidb/pkg/importsdk",
    visibility = ["//visibility:public"],
    deps = [
        "//br/pkg/storage",
        "//pkg/lightning/common",
        "//pkg/lightning/config",
        "//pkg/lightning/log",
        "//pkg/lightning/mydump",
        "//pkg/parser/mysql",
        "@com_github_pingcap_errors//:errors",
        "@org_uber_go_zap//:zap",
    ],
)

go_test(
    name = "importsdk_test",
    timeout = "short",
-   srcs = ["sdk_test.go"],
+   srcs = [
+       "config_test.go",
+       "file_scanner_test.go",
+       "job_manager_test.go",
+       "model_test.go",
+       "pattern_test.go",
+       "sdk_test.go",
+       "sql_generator_test.go",
+   ],
    embed = [":importsdk"],
    flaky = True,
-   shard_count = 9,
+   shard_count = 20,
    deps = [
        "//pkg/lightning/config",
        "//pkg/lightning/log",

pkg/importsdk/config.go (new file, +118 lines)
@@ -0,0 +1,118 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package importsdk

import (
	"github.com/pingcap/tidb/pkg/lightning/config"
	"github.com/pingcap/tidb/pkg/lightning/log"
	"github.com/pingcap/tidb/pkg/parser/mysql"
)

// SDKOption customizes the SDK configuration.
type SDKOption func(*SDKConfig)

// SDKConfig is the configuration for the SDK.
type SDKConfig struct {
	// Loader options
	concurrency      int
	sqlMode          mysql.SQLMode
	fileRouteRules   []*config.FileRouteRule
	routes           config.Routes
	filter           []string
	charset          string
	maxScanFiles     *int
	skipInvalidFiles bool

	// General options
	logger log.Logger
}

func defaultSDKConfig() *SDKConfig {
	return &SDKConfig{
		concurrency: 4,
		filter:      config.GetDefaultFilter(),
		logger:      log.L(),
		charset:     "auto",
	}
}

// WithConcurrency sets the number of concurrent DB/Table creation workers.
func WithConcurrency(n int) SDKOption {
	return func(cfg *SDKConfig) {
		if n > 0 {
			cfg.concurrency = n
		}
	}
}

// WithLogger specifies a custom logger.
func WithLogger(logger log.Logger) SDKOption {
	return func(cfg *SDKConfig) {
		cfg.logger = logger
	}
}

// WithSQLMode specifies the SQL mode for schema parsing.
func WithSQLMode(mode mysql.SQLMode) SDKOption {
	return func(cfg *SDKConfig) {
		cfg.sqlMode = mode
	}
}

// WithFilter specifies a table filter for the loader.
func WithFilter(filter []string) SDKOption {
	return func(cfg *SDKConfig) {
		cfg.filter = filter
	}
}

// WithFileRouters sets the file routing rules.
func WithFileRouters(rules []*config.FileRouteRule) SDKOption {
	return func(c *SDKConfig) {
		c.fileRouteRules = rules
	}
}

// WithRoutes sets the table routing rules.
func WithRoutes(routes config.Routes) SDKOption {
	return func(c *SDKConfig) {
		c.routes = routes
	}
}

// WithCharset specifies the character set for import (default "auto").
func WithCharset(cs string) SDKOption {
	return func(cfg *SDKConfig) {
		if cs != "" {
			cfg.charset = cs
		}
	}
}

// WithMaxScanFiles specifies a custom limit on the number of files to scan.
func WithMaxScanFiles(limit int) SDKOption {
	return func(cfg *SDKConfig) {
		if limit > 0 {
			cfg.maxScanFiles = &limit
		}
	}
}

// WithSkipInvalidFiles specifies whether the SDK should skip invalid files instead of raising an error.
func WithSkipInvalidFiles(skip bool) SDKOption {
	return func(cfg *SDKConfig) {
		cfg.skipInvalidFiles = skip
	}
}
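
These options follow Go's functional-option pattern: a constructor folds each SDKOption over the defaults. A minimal in-package sketch of that application step (the newConfig helper here is hypothetical, for illustration only; this hunk does not show the SDK constructor itself):

// newConfig builds an SDKConfig from the defaults plus caller options.
// Hypothetical helper; the real constructor is not part of this hunk.
func newConfig(opts ...SDKOption) *SDKConfig {
	cfg := defaultSDKConfig()
	for _, opt := range opts {
		// Each option mutates cfg in place; invalid values (e.g. a
		// non-positive concurrency) are ignored by the option itself.
		opt(cfg)
	}
	return cfg
}

// e.g. cfg := newConfig(WithConcurrency(8), WithCharset("utf8mb4"))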

pkg/importsdk/config_test.go (new file, +89 lines)
@@ -0,0 +1,89 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package importsdk

import (
	"testing"

	"github.com/pingcap/tidb/pkg/lightning/config"
	"github.com/pingcap/tidb/pkg/lightning/log"
	"github.com/pingcap/tidb/pkg/parser/mysql"
	"github.com/stretchr/testify/require"
)

func TestDefaultSDKConfig(t *testing.T) {
	cfg := defaultSDKConfig()
	require.Equal(t, 4, cfg.concurrency)
	require.Equal(t, config.GetDefaultFilter(), cfg.filter)
	require.Equal(t, log.L(), cfg.logger)
	require.Equal(t, "auto", cfg.charset)
}

func TestSDKOptions(t *testing.T) {
	cfg := defaultSDKConfig()

	// Test WithConcurrency
	WithConcurrency(10)(cfg)
	require.Equal(t, 10, cfg.concurrency)
	WithConcurrency(-1)(cfg) // Should ignore invalid value
	require.Equal(t, 10, cfg.concurrency)

	// Test WithLogger
	logger := log.L()
	WithLogger(logger)(cfg)
	require.Equal(t, logger, cfg.logger)

	// Test WithSQLMode
	mode := mysql.ModeStrictTransTables
	WithSQLMode(mode)(cfg)
	require.Equal(t, mode, cfg.sqlMode)

	// Test WithFilter
	filter := []string{"*.*"}
	WithFilter(filter)(cfg)
	require.Equal(t, filter, cfg.filter)

	// Test WithFileRouters
	routers := []*config.FileRouteRule{{Schema: "test"}}
	WithFileRouters(routers)(cfg)
	require.Equal(t, routers, cfg.fileRouteRules)

	// Test WithCharset
	WithCharset("utf8mb4")(cfg)
	require.Equal(t, "utf8mb4", cfg.charset)
	WithCharset("")(cfg) // Should ignore empty value
	require.Equal(t, "utf8mb4", cfg.charset)

	// Test WithMaxScanFiles
	WithMaxScanFiles(100)(cfg)
	require.NotNil(t, cfg.maxScanFiles)
	require.Equal(t, 100, *cfg.maxScanFiles)

	// Test WithSkipInvalidFiles
	WithSkipInvalidFiles(true)(cfg)
	require.True(t, cfg.skipInvalidFiles)

	// Test WithRoutes
	routes := config.Routes{
		{
			SchemaPattern: "source_db",
			TablePattern:  "source_table",
			TargetSchema:  "target_db",
			TargetTable:   "target_table",
		},
	}
	WithRoutes(routes)(cfg)
	require.Equal(t, routes, cfg.routes)
}

pkg/importsdk/error.go (new file, +46 lines)
@@ -0,0 +1,46 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package importsdk

import "github.com/pingcap/errors"

var (
	// ErrNoDatabasesFound indicates that the dump source contains no recognizable databases.
	ErrNoDatabasesFound = errors.New("no databases found in the source path")
	// ErrSchemaNotFound indicates the target schema doesn't exist in the dump source.
	ErrSchemaNotFound = errors.New("schema not found")
	// ErrTableNotFound indicates the target table doesn't exist in the dump source.
	ErrTableNotFound = errors.New("table not found")
	// ErrNoTableDataFiles indicates a table has zero data files and thus cannot proceed.
	ErrNoTableDataFiles = errors.New("no data files for table")
	// ErrWildcardNotSpecific indicates a wildcard cannot uniquely match the table's files.
	ErrWildcardNotSpecific = errors.New("cannot generate a unique wildcard pattern for the table's data files")
	// ErrJobNotFound indicates the job is not found.
	ErrJobNotFound = errors.New("job not found")
	// ErrNoJobIDReturned indicates that the submit job query did not return a job ID.
	ErrNoJobIDReturned = errors.New("no job id returned")
	// ErrInvalidOptions indicates the options are invalid.
	ErrInvalidOptions = errors.New("invalid options")
	// ErrMultipleFieldsDefinedNullBy indicates that multiple FIELDS_DEFINED_NULL_BY values are defined, which is not supported.
	ErrMultipleFieldsDefinedNullBy = errors.New("IMPORT INTO only supports one FIELDS_DEFINED_NULL_BY value")
	// ErrParseStorageURL indicates that the storage backend URL is invalid.
	ErrParseStorageURL = errors.New("failed to parse storage backend URL")
	// ErrCreateExternalStorage indicates that the external storage cannot be created.
	ErrCreateExternalStorage = errors.New("failed to create external storage")
	// ErrCreateLoader indicates that the MyDump loader cannot be created.
	ErrCreateLoader = errors.New("failed to create MyDump loader")
	// ErrCreateSchema indicates that creating schemas and tables failed.
	ErrCreateSchema = errors.New("failed to create schemas and tables")
)
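
Because the SDK wraps these sentinels with errors.Annotatef, callers should compare against the unwrapped cause rather than the annotated error. A short in-package sketch using the pingcap/errors unwrapper:

// isTableMissing reports whether err originated from ErrTableNotFound,
// no matter how many times it was annotated along the way.
func isTableMissing(err error) bool {
	return errors.Cause(err) == ErrTableNotFound
}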

pkg/importsdk/file_scanner.go (new file, +271 lines)
@@ -0,0 +1,271 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package importsdk

import (
	"context"
	"database/sql"
	"strings"

	"github.com/pingcap/errors"
	"github.com/pingcap/tidb/br/pkg/storage"
	"github.com/pingcap/tidb/pkg/lightning/common"
	"github.com/pingcap/tidb/pkg/lightning/log"
	"github.com/pingcap/tidb/pkg/lightning/mydump"
	"go.uber.org/zap"
)

// FileScanner defines the interface for scanning source files.
type FileScanner interface {
	CreateSchemasAndTables(ctx context.Context) error
	CreateSchemaAndTableByName(ctx context.Context, schema, table string) error
	GetTableMetas(ctx context.Context) ([]*TableMeta, error)
	GetTableMetaByName(ctx context.Context, db, table string) (*TableMeta, error)
	GetTotalSize(ctx context.Context) int64
	Close() error
}

type fileScanner struct {
	sourcePath string
	db         *sql.DB
	store      storage.ExternalStorage
	loader     *mydump.MDLoader
	logger     log.Logger
	config     *SDKConfig
}

// NewFileScanner creates a new FileScanner.
func NewFileScanner(ctx context.Context, sourcePath string, db *sql.DB, cfg *SDKConfig) (FileScanner, error) {
	u, err := storage.ParseBackend(sourcePath, nil)
	if err != nil {
		return nil, errors.Annotatef(ErrParseStorageURL, "source=%s, err=%v", sourcePath, err)
	}
	store, err := storage.New(ctx, u, &storage.ExternalStorageOptions{})
	if err != nil {
		return nil, errors.Annotatef(ErrCreateExternalStorage, "source=%s, err=%v", sourcePath, err)
	}

	ldrCfg := mydump.LoaderConfig{
		SourceURL:        sourcePath,
		Filter:           cfg.filter,
		FileRouters:      cfg.fileRouteRules,
		DefaultFileRules: len(cfg.fileRouteRules) == 0,
		CharacterSet:     cfg.charset,
		Routes:           cfg.routes,
	}

	var loaderOptions []mydump.MDLoaderSetupOption
	if cfg.maxScanFiles != nil && *cfg.maxScanFiles > 0 {
		loaderOptions = append(loaderOptions, mydump.WithMaxScanFiles(*cfg.maxScanFiles))
	}
	if cfg.concurrency > 0 {
		loaderOptions = append(loaderOptions, mydump.WithScanFileConcurrency(cfg.concurrency))
	}

	// TODO: we can skip some time-consuming operations in constructFileInfo (like getting the real size of a compressed file).
	loader, err := mydump.NewLoaderWithStore(ctx, ldrCfg, store, loaderOptions...)
	if err != nil {
		if loader == nil || !errors.ErrorEqual(err, common.ErrTooManySourceFiles) {
			return nil, errors.Annotatef(ErrCreateLoader, "source=%s, charset=%s, err=%v", sourcePath, cfg.charset, err)
		}
	}

	return &fileScanner{
		sourcePath: sourcePath,
		db:         db,
		store:      store,
		loader:     loader,
		logger:     cfg.logger,
		config:     cfg,
	}, nil
}

func (s *fileScanner) CreateSchemasAndTables(ctx context.Context) error {
	dbMetas := s.loader.GetDatabases()
	if len(dbMetas) == 0 {
		return errors.Annotatef(ErrNoDatabasesFound, "source=%s", s.sourcePath)
	}

	// Create all schemas and tables
	importer := mydump.NewSchemaImporter(
		s.logger,
		s.config.sqlMode,
		s.db,
		s.store,
		s.config.concurrency,
	)

	err := importer.Run(ctx, dbMetas)
	if err != nil {
		return errors.Annotatef(ErrCreateSchema, "source=%s, db_count=%d, err=%v", s.sourcePath, len(dbMetas), err)
	}

	return nil
}

// CreateSchemaAndTableByName creates the specified database and table schema from the source.
func (s *fileScanner) CreateSchemaAndTableByName(ctx context.Context, schema, table string) error {
	dbMetas := s.loader.GetDatabases()
	// Find the specific table
	for _, dbMeta := range dbMetas {
		if dbMeta.Name != schema {
			continue
		}

		for _, tblMeta := range dbMeta.Tables {
			if tblMeta.Name != table {
				continue
			}

			importer := mydump.NewSchemaImporter(
				s.logger,
				s.config.sqlMode,
				s.db,
				s.store,
				s.config.concurrency,
			)

			err := importer.Run(ctx, []*mydump.MDDatabaseMeta{{
				Name:       dbMeta.Name,
				SchemaFile: dbMeta.SchemaFile,
				Tables:     []*mydump.MDTableMeta{tblMeta},
			}})
			if err != nil {
				return errors.Annotatef(ErrCreateSchema, "source=%s, schema=%s, table=%s, err=%v", s.sourcePath, schema, table, err)
			}

			return nil
		}

		return errors.Annotatef(ErrTableNotFound, "schema=%s, table=%s", schema, table)
	}

	return errors.Annotatef(ErrSchemaNotFound, "schema=%s", schema)
}

func (s *fileScanner) GetTableMetas(context.Context) ([]*TableMeta, error) {
	dbMetas := s.loader.GetDatabases()
	allFiles := s.loader.GetAllFiles()
	var results []*TableMeta
	for _, dbMeta := range dbMetas {
		for _, tblMeta := range dbMeta.Tables {
			tableMeta, err := s.buildTableMeta(dbMeta, tblMeta, allFiles)
			if err != nil {
				if s.config.skipInvalidFiles {
					s.logger.Warn("skipping table due to invalid files", zap.String("database", dbMeta.Name), zap.String("table", tblMeta.Name), zap.Error(err))
					continue
				}
				return nil, err
			}
			results = append(results, tableMeta)
		}
	}

	return results, nil
}

func (s *fileScanner) GetTotalSize(ctx context.Context) int64 {
	var total int64
	dbMetas := s.loader.GetDatabases()
	for _, dbMeta := range dbMetas {
		for _, tblMeta := range dbMeta.Tables {
			total += tblMeta.TotalSize
		}
	}
	return total
}

func (s *fileScanner) Close() error {
	if s.store != nil {
		s.store.Close()
	}
	return nil
}

func (s *fileScanner) buildTableMeta(
	dbMeta *mydump.MDDatabaseMeta,
	tblMeta *mydump.MDTableMeta,
	allDataFiles map[string]mydump.FileInfo,
) (*TableMeta, error) {
	tableMeta := &TableMeta{
		Database:   dbMeta.Name,
		Table:      tblMeta.Name,
		DataFiles:  make([]DataFileMeta, 0, len(tblMeta.DataFiles)),
		SchemaFile: tblMeta.SchemaFile.FileMeta.Path,
	}

	// Process data files
	dataFiles, totalSize := processDataFiles(tblMeta.DataFiles)
	tableMeta.DataFiles = dataFiles
	tableMeta.TotalSize = totalSize

	if len(tblMeta.DataFiles) == 0 {
		s.logger.Warn("table has no data files", zap.String("database", dbMeta.Name), zap.String("table", tblMeta.Name))
		return tableMeta, nil
	}

	wildcard, err := generateWildcardPath(tblMeta.DataFiles, allDataFiles)
	if err != nil {
		return nil, errors.Trace(err)
	}
	uri := s.store.URI()
	// IMPORT INTO only supports absolute paths
	uri = strings.TrimPrefix(uri, "file://")
	tableMeta.WildcardPath = strings.TrimSuffix(uri, "/") + "/" + wildcard

	return tableMeta, nil
}

// processDataFiles converts mydump data files to DataFileMeta and calculates the total size.
func processDataFiles(files []mydump.FileInfo) ([]DataFileMeta, int64) {
	dataFiles := make([]DataFileMeta, 0, len(files))
	var totalSize int64

	for _, dataFile := range files {
		fileMeta := createDataFileMeta(dataFile)
		dataFiles = append(dataFiles, fileMeta)
		totalSize += dataFile.FileMeta.RealSize
	}

	return dataFiles, totalSize
}

// createDataFileMeta creates a DataFileMeta from a mydump.FileInfo.
func createDataFileMeta(file mydump.FileInfo) DataFileMeta {
	return DataFileMeta{
		Path:        file.FileMeta.Path,
		Size:        file.FileMeta.RealSize,
		Format:      file.FileMeta.Type,
		Compression: file.FileMeta.Compression,
	}
}

func (s *fileScanner) GetTableMetaByName(ctx context.Context, db, table string) (*TableMeta, error) {
	dbMetas := s.loader.GetDatabases()
	allFiles := s.loader.GetAllFiles()

	for _, dbMeta := range dbMetas {
		if dbMeta.Name != db {
			continue
		}
		for _, tblMeta := range dbMeta.Tables {
			if tblMeta.Name != table {
				continue
			}
			return s.buildTableMeta(dbMeta, tblMeta, allFiles)
		}
	}
	return nil, errors.Annotatef(ErrTableNotFound, "table %s.%s not found", db, table)
}
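
Putting the scanner together, a minimal in-package usage sketch (assumes a reachable *sql.DB and a local dump directory; "fmt" would need to be added to this file's imports; error handling abbreviated):

// scanDump is a hypothetical caller, for illustration only.
func scanDump(ctx context.Context, db *sql.DB, dir string) error {
	cfg := defaultSDKConfig()
	WithSkipInvalidFiles(true)(cfg)

	scanner, err := NewFileScanner(ctx, "file://"+dir, db, cfg)
	if err != nil {
		return err
	}
	defer scanner.Close()

	metas, err := scanner.GetTableMetas(ctx)
	if err != nil {
		return err
	}
	for _, m := range metas {
		// WildcardPath is the absolute pattern later fed to IMPORT INTO ... FROM.
		fmt.Printf("%s.%s: %d bytes via %s\n", m.Database, m.Table, m.TotalSize, m.WildcardPath)
	}
	return nil
}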

pkg/importsdk/file_scanner_test.go (new file, +177 lines)
@@ -0,0 +1,177 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package importsdk

import (
	"context"
	"os"
	"path/filepath"
	"regexp"
	"testing"

	"github.com/DATA-DOG/go-sqlmock"
	"github.com/pingcap/tidb/pkg/lightning/config"
	"github.com/pingcap/tidb/pkg/lightning/mydump"
	filter "github.com/pingcap/tidb/pkg/util/table-filter"
	"github.com/stretchr/testify/require"
)

func TestCreateDataFileMeta(t *testing.T) {
	fi := mydump.FileInfo{
		TableName: filter.Table{
			Schema: "db",
			Name:   "table",
		},
		FileMeta: mydump.SourceFileMeta{
			Path:        "s3://bucket/path/to/f",
			FileSize:    123,
			Type:        mydump.SourceTypeCSV,
			Compression: mydump.CompressionGZ,
			RealSize:    456,
		},
	}
	df := createDataFileMeta(fi)
	require.Equal(t, "s3://bucket/path/to/f", df.Path)
	require.Equal(t, int64(456), df.Size)
	require.Equal(t, mydump.SourceTypeCSV, df.Format)
	require.Equal(t, mydump.CompressionGZ, df.Compression)
}

func TestProcessDataFiles(t *testing.T) {
	files := []mydump.FileInfo{
		{FileMeta: mydump.SourceFileMeta{Path: "s3://bucket/a", RealSize: 10}},
		{FileMeta: mydump.SourceFileMeta{Path: "s3://bucket/b", RealSize: 20}},
	}
	dfm, total := processDataFiles(files)
	require.Len(t, dfm, 2)
	require.Equal(t, int64(30), total)
	require.Equal(t, "s3://bucket/a", dfm[0].Path)
	require.Equal(t, "s3://bucket/b", dfm[1].Path)
}

func TestFileScanner(t *testing.T) {
	tmpDir := t.TempDir()
	ctx := context.Background()

	require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "db1-schema-create.sql"), []byte("CREATE DATABASE db1;"), 0644))
	require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "db1.t1-schema.sql"), []byte("CREATE TABLE t1 (id INT);"), 0644))
	require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "db1.t1.001.csv"), []byte("1\n2"), 0644))

	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	cfg := defaultSDKConfig()
	scanner, err := NewFileScanner(ctx, "file://"+tmpDir, db, cfg)
	require.NoError(t, err)
	defer scanner.Close()

	t.Run("GetTotalSize", func(t *testing.T) {
		size := scanner.GetTotalSize(ctx)
		require.Equal(t, int64(3), size)
	})

	t.Run("GetTableMetas", func(t *testing.T) {
		metas, err := scanner.GetTableMetas(ctx)
		require.NoError(t, err)
		require.Len(t, metas, 1)
		require.Equal(t, "db1", metas[0].Database)
		require.Equal(t, "t1", metas[0].Table)
		require.Equal(t, int64(3), metas[0].TotalSize)
		require.Len(t, metas[0].DataFiles, 1)
	})

	t.Run("GetTableMetaByName", func(t *testing.T) {
		meta, err := scanner.GetTableMetaByName(ctx, "db1", "t1")
		require.NoError(t, err)
		require.Equal(t, "db1", meta.Database)
		require.Equal(t, "t1", meta.Table)

		_, err = scanner.GetTableMetaByName(ctx, "db1", "nonexistent")
		require.Error(t, err)
	})

	t.Run("CreateSchemasAndTables", func(t *testing.T) {
		mock.ExpectQuery("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA.*").WillReturnRows(sqlmock.NewRows([]string{"SCHEMA_NAME"}))
		mock.ExpectExec(regexp.QuoteMeta("CREATE DATABASE IF NOT EXISTS `db1`")).WillReturnResult(sqlmock.NewResult(0, 0))
		mock.ExpectExec(regexp.QuoteMeta("CREATE TABLE IF NOT EXISTS `db1`.`t1`")).WillReturnResult(sqlmock.NewResult(0, 0))

		err := scanner.CreateSchemasAndTables(ctx)
		require.NoError(t, err)
		require.NoError(t, mock.ExpectationsWereMet())
	})

	t.Run("CreateSchemaAndTableByName", func(t *testing.T) {
		mock.ExpectQuery("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA.*").WillReturnRows(sqlmock.NewRows([]string{"SCHEMA_NAME"}))
		mock.ExpectExec(regexp.QuoteMeta("CREATE DATABASE IF NOT EXISTS `db1`")).WillReturnResult(sqlmock.NewResult(0, 0))
		mock.ExpectExec(regexp.QuoteMeta("CREATE TABLE IF NOT EXISTS `db1`.`t1`")).WillReturnResult(sqlmock.NewResult(0, 0))

		err := scanner.CreateSchemaAndTableByName(ctx, "db1", "t1")
		require.NoError(t, err)
		require.NoError(t, mock.ExpectationsWereMet())

		err = scanner.CreateSchemaAndTableByName(ctx, "db1", "nonexistent")
		require.Error(t, err)
	})
}

func TestFileScannerWithSkipInvalidFiles(t *testing.T) {
	tmpDir := t.TempDir()
	ctx := context.Background()

	require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "data1.csv"), []byte("1"), 0644))
	require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "data2.csv"), []byte("1"), 0644))
	require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "data3.csv"), []byte("1"), 0644))

	db, _, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	rules := []*config.FileRouteRule{
		{
			Pattern: "data[1-2].csv",
			Schema:  "db1",
			Table:   "t1",
			Type:    "csv",
		},
		{
			Pattern: "data3.csv",
			Schema:  "db1",
			Table:   "t2",
			Type:    "csv",
		},
	}

	cfg := defaultSDKConfig()
	WithFileRouters(rules)(cfg)

	scanner, err := NewFileScanner(ctx, "file://"+tmpDir, db, cfg)
	require.NoError(t, err)
	defer scanner.Close()

	metas, err := scanner.GetTableMetas(ctx)
	require.Error(t, err)
	require.Nil(t, metas)

	cfg.skipInvalidFiles = true
	scanner2, err := NewFileScanner(ctx, "file://"+tmpDir, db, cfg)
	require.NoError(t, err)
	defer scanner2.Close()

	metas2, err := scanner2.GetTableMetas(ctx)
	require.NoError(t, err)
	require.Len(t, metas2, 1)
	require.Equal(t, "t2", metas2[0].Table)
}

pkg/importsdk/job_manager.go (new file, +258 lines)
@@ -0,0 +1,258 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package importsdk

import (
	"context"
	"database/sql"
	"fmt"
	"strings"
	"time"

	"github.com/pingcap/errors"
)

// JobManager defines the interface for managing import jobs.
type JobManager interface {
	SubmitJob(ctx context.Context, query string) (int64, error)
	GetJobStatus(ctx context.Context, jobID int64) (*JobStatus, error)
	CancelJob(ctx context.Context, jobID int64) error
	GetGroupSummary(ctx context.Context, groupKey string) (*GroupStatus, error)
	GetJobsByGroup(ctx context.Context, groupKey string) ([]*JobStatus, error)
}

const timeLayout = "2006-01-02 15:04:05"

type jobManager struct {
	db *sql.DB
}

// NewJobManager creates a new JobManager.
func NewJobManager(db *sql.DB) JobManager {
	return &jobManager{
		db: db,
	}
}

// SubmitJob submits an import job and returns the job ID.
func (m *jobManager) SubmitJob(ctx context.Context, query string) (int64, error) {
	rows, err := m.db.QueryContext(ctx, query)
	if err != nil {
		return 0, errors.Trace(err)
	}
	defer rows.Close()

	if rows.Next() {
		status, err := scanJobStatus(rows)
		if err != nil {
			return 0, errors.Trace(err)
		}
		return status.JobID, nil
	}

	if err := rows.Err(); err != nil {
		return 0, errors.Trace(err)
	}

	return 0, ErrNoJobIDReturned
}

// GetJobStatus gets the status of an import job.
func (m *jobManager) GetJobStatus(ctx context.Context, jobID int64) (*JobStatus, error) {
	query := fmt.Sprintf("SHOW IMPORT JOB %d", jobID)
	rows, err := m.db.QueryContext(ctx, query)
	if err != nil {
		return nil, errors.Trace(err)
	}
	defer rows.Close()

	if rows.Next() {
		return scanJobStatus(rows)
	}

	if err := rows.Err(); err != nil {
		return nil, errors.Trace(err)
	}

	return nil, ErrJobNotFound
}

// GetGroupSummary returns aggregated information for the specified group key.
func (m *jobManager) GetGroupSummary(ctx context.Context, groupKey string) (*GroupStatus, error) {
	if groupKey == "" {
		return nil, ErrInvalidOptions
	}
	query := fmt.Sprintf("SHOW IMPORT GROUP '%s'", strings.ReplaceAll(groupKey, "'", "''"))
	rows, err := m.db.QueryContext(ctx, query)
	if err != nil {
		return nil, errors.Trace(err)
	}
	defer rows.Close()

	if rows.Next() {
		status, err := scanGroupStatus(rows)
		if err != nil {
			return nil, errors.Trace(err)
		}
		return status, nil
	}

	if err := rows.Err(); err != nil {
		return nil, errors.Trace(err)
	}
	return nil, ErrJobNotFound
}

// GetJobsByGroup returns all jobs for the specified group key.
func (m *jobManager) GetJobsByGroup(ctx context.Context, groupKey string) ([]*JobStatus, error) {
	if groupKey == "" {
		return nil, ErrInvalidOptions
	}
	query := fmt.Sprintf("SHOW IMPORT JOBS WHERE GROUP_KEY = '%s'", strings.ReplaceAll(groupKey, "'", "''"))
	rows, err := m.db.QueryContext(ctx, query)
	if err != nil {
		return nil, errors.Trace(err)
	}
	defer rows.Close()

	var jobs []*JobStatus
	for rows.Next() {
		status, err := scanJobStatus(rows)
		if err != nil {
			return nil, errors.Trace(err)
		}
		jobs = append(jobs, status)
	}

	if err := rows.Err(); err != nil {
		return nil, errors.Trace(err)
	}
	return jobs, nil
}

func scanJobStatus(rows *sql.Rows) (*JobStatus, error) {
	var (
		id             int64
		groupKey       sql.NullString
		dataSource     string
		targetTable    string
		tableID        int64
		phase          string
		status         string
		sourceFileSize string
		importedRows   sql.NullInt64
		resultMessage  sql.NullString
		createTimeStr  string
		startTimeStr   sql.NullString
		endTimeStr     sql.NullString
		createdBy      string
		updateTimeStr  sql.NullString
		step           sql.NullString
		processedSize  sql.NullString
		totalSize      sql.NullString
		percent        sql.NullString
		speed          sql.NullString
		eta            sql.NullString
	)

	err := rows.Scan(
		&id, &groupKey, &dataSource, &targetTable, &tableID,
		&phase, &status, &sourceFileSize, &importedRows, &resultMessage,
		&createTimeStr, &startTimeStr, &endTimeStr, &createdBy, &updateTimeStr,
		&step, &processedSize, &totalSize, &percent, &speed, &eta,
	)
	if err != nil {
		return nil, errors.Trace(err)
	}

	// Parse times
	createTime := parseTime(createTimeStr)
	startTime := parseNullTime(startTimeStr)
	endTime := parseNullTime(endTimeStr)
	updateTime := parseNullTime(updateTimeStr)

	return &JobStatus{
		JobID:          id,
		GroupKey:       groupKey.String,
		DataSource:     dataSource,
		TargetTable:    targetTable,
		TableID:        tableID,
		Phase:          phase,
		Status:         status,
		SourceFileSize: sourceFileSize,
		ImportedRows:   importedRows.Int64,
		ResultMessage:  resultMessage.String,
		CreateTime:     createTime,
		StartTime:      startTime,
		EndTime:        endTime,
		CreatedBy:      createdBy,
		UpdateTime:     updateTime,
		Step:           step.String,
		ProcessedSize:  processedSize.String,
		TotalSize:      totalSize.String,
		Percent:        percent.String,
		Speed:          speed.String,
		ETA:            eta.String,
	}, nil
}

func scanGroupStatus(rows *sql.Rows) (*GroupStatus, error) {
	var (
		groupKey        string
		totalJobs       int64
		pending         int64
		running         int64
		completed       int64
		failed          int64
		cancelled       int64
		firstCreateTime sql.NullString
		lastUpdateTime  sql.NullString
	)

	if err := rows.Scan(&groupKey, &totalJobs, &pending, &running, &completed, &failed, &cancelled, &firstCreateTime, &lastUpdateTime); err != nil {
		return nil, errors.Trace(err)
	}

	return &GroupStatus{
		GroupKey:           groupKey,
		TotalJobs:          totalJobs,
		Pending:            pending,
		Running:            running,
		Completed:          completed,
		Failed:             failed,
		Cancelled:          cancelled,
		FirstJobCreateTime: parseNullTime(firstCreateTime),
		LastJobUpdateTime:  parseNullTime(lastUpdateTime),
	}, nil
}

// CancelJob cancels an import job.
func (m *jobManager) CancelJob(ctx context.Context, jobID int64) error {
	query := fmt.Sprintf("CANCEL IMPORT JOB %d", jobID)
	_, err := m.db.ExecContext(ctx, query)
	return errors.Trace(err)
}

func parseTime(s string) time.Time {
	t, _ := time.Parse(timeLayout, s)
	return t
}

func parseNullTime(ns sql.NullString) time.Time {
	if !ns.Valid {
		return time.Time{}
	}
	return parseTime(ns.String)
}
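
A typical consumer submits a detached IMPORT INTO statement and then polls GetJobStatus until the job reaches a terminal state. A minimal in-package sketch ("finished" matches the tests below; "failed" and "cancelled" are assumed terminal states; "time" is already imported by this file):

// runImport is a hypothetical caller, for illustration only.
func runImport(ctx context.Context, db *sql.DB, stmt string) error {
	mgr := NewJobManager(db)
	// stmt is e.g. an IMPORT INTO ... WITH DETACHED statement.
	jobID, err := mgr.SubmitJob(ctx, stmt)
	if err != nil {
		return err
	}
	for {
		status, err := mgr.GetJobStatus(ctx, jobID)
		if err != nil {
			return err
		}
		// "finished" appears in the tests; the other two are assumptions.
		switch status.Status {
		case "finished", "failed", "cancelled":
			return nil
		}
		time.Sleep(5 * time.Second)
	}
}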

pkg/importsdk/job_manager_test.go (new file, +242 lines)
@@ -0,0 +1,242 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package importsdk

import (
	"context"
	"database/sql"
	"testing"

	"github.com/DATA-DOG/go-sqlmock"
	"github.com/stretchr/testify/require"
)

func TestSubmitJob(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	manager := NewJobManager(db)
	ctx := context.Background()
	sqlQuery := "IMPORT INTO ..."

	// Columns expected by scanJobStatus
	cols := []string{
		"Job_ID", "Group_Key", "Data_Source", "Target_Table", "Table_ID",
		"Phase", "Status", "Source_File_Size", "Imported_Rows", "Result_Message",
		"Create_Time", "Start_Time", "End_Time", "Created_By", "Update_Time",
		"Step", "Processed_Size", "Total_Size", "Percent", "Speed", "ETA",
	}

	// Case 1: Success
	rows := sqlmock.NewRows(cols).AddRow(
		int64(123), "", "s3://bucket/file.csv", "db.table", int64(1),
		"import", "finished", "100MB", int64(1000), "success",
		"2023-01-01 10:00:00", "2023-01-01 10:00:01", "2023-01-01 10:00:02", "user", "2023-01-01 10:00:02",
		"", "100MB", "100MB", "100%", "10MB/s", "0s",
	)
	mock.ExpectQuery(sqlQuery).WillReturnRows(rows)

	jobID, err := manager.SubmitJob(ctx, sqlQuery)
	require.NoError(t, err)
	require.Equal(t, int64(123), jobID)

	// Case 2: No rows returned
	mock.ExpectQuery(sqlQuery).WillReturnRows(sqlmock.NewRows(cols))
	_, err = manager.SubmitJob(ctx, sqlQuery)
	require.Error(t, err)
	require.Contains(t, err.Error(), "no job id returned")

	// Case 3: Error
	mock.ExpectQuery(sqlQuery).WillReturnError(sql.ErrConnDone)
	_, err = manager.SubmitJob(ctx, sqlQuery)
	require.Error(t, err)

	require.NoError(t, mock.ExpectationsWereMet())
}

func TestGetJobStatus(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	manager := NewJobManager(db)
	ctx := context.Background()
	jobID := int64(123)

	// Columns expected by GetJobStatus (same as scanJobStatus)
	cols := []string{
		"Job_ID", "Group_Key", "Data_Source", "Target_Table", "Table_ID",
		"Phase", "Status", "Source_File_Size", "Imported_Rows", "Result_Message",
		"Create_Time", "Start_Time", "End_Time", "Created_By", "Update_Time",
		"Step", "Processed_Size", "Total_Size", "Percent", "Speed", "ETA",
	}

	// Case 1: Success
	rows := sqlmock.NewRows(cols).AddRow(
		jobID, "", "s3://bucket/file.csv", "db.table", int64(1),
		"import", "finished", "100MB", int64(1000), "success",
		"2023-01-01 10:00:00", "2023-01-01 10:00:01", "2023-01-01 10:00:02", "user", "2023-01-01 10:00:02",
		"", "100MB", "100MB", "100%", "10MB/s", "0s",
	)
	mock.ExpectQuery("SHOW IMPORT JOB 123").WillReturnRows(rows)

	status, err := manager.GetJobStatus(ctx, jobID)
	require.NoError(t, err)
	require.Equal(t, jobID, status.JobID)
	require.Equal(t, "finished", status.Status)

	// Case 2: Job not found
	mock.ExpectQuery("SHOW IMPORT JOB 123").WillReturnRows(sqlmock.NewRows(cols))
	_, err = manager.GetJobStatus(ctx, jobID)
	require.ErrorIs(t, err, ErrJobNotFound)

	// Case 3: Error
	mock.ExpectQuery("SHOW IMPORT JOB 123").WillReturnError(sql.ErrConnDone)
	_, err = manager.GetJobStatus(ctx, jobID)
	require.Error(t, err)

	require.NoError(t, mock.ExpectationsWereMet())
}

func TestCancelJob(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	manager := NewJobManager(db)
	ctx := context.Background()
	jobID := int64(123)

	// Case 1: Success
	mock.ExpectExec("CANCEL IMPORT JOB 123").WillReturnResult(sqlmock.NewResult(0, 0))
	err = manager.CancelJob(ctx, jobID)
	require.NoError(t, err)

	// Case 2: Error
	mock.ExpectExec("CANCEL IMPORT JOB 123").WillReturnError(sql.ErrConnDone)
	err = manager.CancelJob(ctx, jobID)
	require.Error(t, err)

	require.NoError(t, mock.ExpectationsWereMet())
}

func TestGetGroupSummary(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	manager := NewJobManager(db)
	ctx := context.Background()
	groupKey := "test_group"

	// Columns expected by scanGroupStatus
	cols := []string{
		"Group_Key", "Total_Jobs", "Pending", "Running", "Completed", "Failed", "Cancelled",
		"First_Job_Create_Time", "Last_Job_Update_Time",
	}

	// Case 1: Success
	rows := sqlmock.NewRows(cols).AddRow(
		groupKey, int64(10), int64(1), int64(2), int64(3), int64(2), int64(2),
		"2023-01-01 10:00:00", "2023-01-01 12:00:00",
	)
	mock.ExpectQuery("SHOW IMPORT GROUP 'test_group'").WillReturnRows(rows)

	summary, err := manager.GetGroupSummary(ctx, groupKey)
	require.NoError(t, err)
	require.Equal(t, groupKey, summary.GroupKey)
	require.Equal(t, int64(10), summary.TotalJobs)
	require.Equal(t, int64(1), summary.Pending)
	require.Equal(t, int64(2), summary.Running)
	require.Equal(t, int64(3), summary.Completed)
	require.Equal(t, int64(2), summary.Failed)
	require.Equal(t, int64(2), summary.Cancelled)

	// Case 2: Empty Group Key
	_, err = manager.GetGroupSummary(ctx, "")
	require.ErrorIs(t, err, ErrInvalidOptions)

	// Case 3: Group not found
	mock.ExpectQuery("SHOW IMPORT GROUP 'test_group'").WillReturnRows(sqlmock.NewRows(cols))
	_, err = manager.GetGroupSummary(ctx, groupKey)
	require.ErrorIs(t, err, ErrJobNotFound)

	// Case 4: Error
	mock.ExpectQuery("SHOW IMPORT GROUP 'test_group'").WillReturnError(sql.ErrConnDone)
	_, err = manager.GetGroupSummary(ctx, groupKey)
	require.Error(t, err)

	require.NoError(t, mock.ExpectationsWereMet())
}

func TestGetJobsByGroup(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	manager := NewJobManager(db)
	ctx := context.Background()
	groupKey := "test_group"

	// Columns expected by scanJobStatus
	cols := []string{
		"Job_ID", "Group_Key", "Data_Source", "Target_Table", "Table_ID",
		"Phase", "Status", "Source_File_Size", "Imported_Rows", "Result_Message",
		"Create_Time", "Start_Time", "End_Time", "Created_By", "Update_Time",
		"Step", "Processed_Size", "Total_Size", "Percent", "Speed", "ETA",
	}

	// Case 1: Success
	rows := sqlmock.NewRows(cols).
		AddRow(
			int64(1), groupKey, "s3://bucket/file1.csv", "db.t1", int64(1),
			"import", "finished", "100MB", int64(1000), "success",
			"2023-01-01 10:00:00", "2023-01-01 10:00:01", "2023-01-01 10:00:02", "user", "2023-01-01 10:00:02",
			"", "100MB", "100MB", "100%", "10MB/s", "0s",
		).
		AddRow(
			int64(2), groupKey, "s3://bucket/file2.csv", "db.t2", int64(2),
			"import", "running", "200MB", int64(500), "",
			"2023-01-01 10:00:00", "2023-01-01 10:00:01", "", "user", "2023-01-01 10:00:02",
			"", "100MB", "200MB", "50%", "10MB/s", "10s",
		)
	mock.ExpectQuery("SHOW IMPORT JOBS WHERE GROUP_KEY = 'test_group'").WillReturnRows(rows)

	jobs, err := manager.GetJobsByGroup(ctx, groupKey)
	require.NoError(t, err)
	require.Len(t, jobs, 2)
	require.Equal(t, int64(1), jobs[0].JobID)
	require.Equal(t, "finished", jobs[0].Status)
	require.Equal(t, int64(2), jobs[1].JobID)
	require.Equal(t, "running", jobs[1].Status)

	// Case 2: Empty Group Key
	_, err = manager.GetJobsByGroup(ctx, "")
	require.ErrorIs(t, err, ErrInvalidOptions)

	// Case 3: No jobs found (empty list)
	mock.ExpectQuery("SHOW IMPORT JOBS WHERE GROUP_KEY = 'test_group'").WillReturnRows(sqlmock.NewRows(cols))
	jobs, err = manager.GetJobsByGroup(ctx, groupKey)
	require.NoError(t, err)
	require.Empty(t, jobs)

	// Case 4: Error
	mock.ExpectQuery("SHOW IMPORT JOBS WHERE GROUP_KEY = 'test_group'").WillReturnError(sql.ErrConnDone)
	_, err = manager.GetJobsByGroup(ctx, groupKey)
	require.Error(t, err)

	require.NoError(t, mock.ExpectationsWereMet())
}

pkg/importsdk/mock/BUILD.bazel (new file, +12 lines)
@@ -0,0 +1,12 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")

go_library(
    name = "mock",
    srcs = ["sdk_mock.go"],
    importpath = "github.com/pingcap/tidb/pkg/importsdk/mock",
    visibility = ["//visibility:public"],
    deps = [
        "//pkg/importsdk",
        "@org_uber_go_mock//gomock",
    ],
)
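
The generated mocks are consumed through gomock in the usual way; a minimal sketch of stubbing JobManager in a test (hypothetical test file, assuming the import path declared above):

package sometest

import (
	"context"
	"testing"

	"github.com/pingcap/tidb/pkg/importsdk/mock"
	"github.com/stretchr/testify/require"
	"go.uber.org/mock/gomock"
)

func TestWithMockJobManager(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	m := mock.NewMockJobManager(ctrl)
	// Stub SubmitJob to return a fixed job ID without touching a real DB.
	m.EXPECT().SubmitJob(gomock.Any(), gomock.Any()).Return(int64(42), nil)

	id, err := m.SubmitJob(context.Background(), "IMPORT INTO ...")
	require.NoError(t, err)
	require.Equal(t, int64(42), id)
}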

pkg/importsdk/mock/sdk_mock.go (new file, +480 lines)
@@ -0,0 +1,480 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/pingcap/tidb/pkg/importsdk (interfaces: FileScanner,JobManager,SQLGenerator,SDK)
//
// Generated by this command:
//
//	mockgen -package mock github.com/pingcap/tidb/pkg/importsdk FileScanner,JobManager,SQLGenerator,SDK
//

// Package mock is a generated GoMock package.
package mock

import (
	context "context"
	reflect "reflect"

	importsdk "github.com/pingcap/tidb/pkg/importsdk"
	gomock "go.uber.org/mock/gomock"
)

// MockFileScanner is a mock of FileScanner interface.
type MockFileScanner struct {
	ctrl     *gomock.Controller
	recorder *MockFileScannerMockRecorder
}

// MockFileScannerMockRecorder is the mock recorder for MockFileScanner.
type MockFileScannerMockRecorder struct {
	mock *MockFileScanner
}

// NewMockFileScanner creates a new mock instance.
func NewMockFileScanner(ctrl *gomock.Controller) *MockFileScanner {
	mock := &MockFileScanner{ctrl: ctrl}
	mock.recorder = &MockFileScannerMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockFileScanner) EXPECT() *MockFileScannerMockRecorder {
	return m.recorder
}

// ISGOMOCK indicates that this struct is a gomock mock.
func (m *MockFileScanner) ISGOMOCK() struct{} {
	return struct{}{}
}

// Close mocks base method.
func (m *MockFileScanner) Close() error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Close")
	ret0, _ := ret[0].(error)
	return ret0
}

// Close indicates an expected call of Close.
func (mr *MockFileScannerMockRecorder) Close() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockFileScanner)(nil).Close))
}

// CreateSchemaAndTableByName mocks base method.
func (m *MockFileScanner) CreateSchemaAndTableByName(arg0 context.Context, arg1, arg2 string) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "CreateSchemaAndTableByName", arg0, arg1, arg2)
	ret0, _ := ret[0].(error)
	return ret0
}

// CreateSchemaAndTableByName indicates an expected call of CreateSchemaAndTableByName.
func (mr *MockFileScannerMockRecorder) CreateSchemaAndTableByName(arg0, arg1, arg2 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaAndTableByName", reflect.TypeOf((*MockFileScanner)(nil).CreateSchemaAndTableByName), arg0, arg1, arg2)
}

// CreateSchemasAndTables mocks base method.
func (m *MockFileScanner) CreateSchemasAndTables(arg0 context.Context) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "CreateSchemasAndTables", arg0)
	ret0, _ := ret[0].(error)
	return ret0
}

// CreateSchemasAndTables indicates an expected call of CreateSchemasAndTables.
func (mr *MockFileScannerMockRecorder) CreateSchemasAndTables(arg0 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemasAndTables", reflect.TypeOf((*MockFileScanner)(nil).CreateSchemasAndTables), arg0)
}

// GetTableMetaByName mocks base method.
func (m *MockFileScanner) GetTableMetaByName(arg0 context.Context, arg1, arg2 string) (*importsdk.TableMeta, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetTableMetaByName", arg0, arg1, arg2)
	ret0, _ := ret[0].(*importsdk.TableMeta)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetTableMetaByName indicates an expected call of GetTableMetaByName.
func (mr *MockFileScannerMockRecorder) GetTableMetaByName(arg0, arg1, arg2 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTableMetaByName", reflect.TypeOf((*MockFileScanner)(nil).GetTableMetaByName), arg0, arg1, arg2)
}

// GetTableMetas mocks base method.
func (m *MockFileScanner) GetTableMetas(arg0 context.Context) ([]*importsdk.TableMeta, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetTableMetas", arg0)
	ret0, _ := ret[0].([]*importsdk.TableMeta)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetTableMetas indicates an expected call of GetTableMetas.
func (mr *MockFileScannerMockRecorder) GetTableMetas(arg0 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTableMetas", reflect.TypeOf((*MockFileScanner)(nil).GetTableMetas), arg0)
}

// GetTotalSize mocks base method.
func (m *MockFileScanner) GetTotalSize(arg0 context.Context) int64 {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetTotalSize", arg0)
	ret0, _ := ret[0].(int64)
	return ret0
}

// GetTotalSize indicates an expected call of GetTotalSize.
func (mr *MockFileScannerMockRecorder) GetTotalSize(arg0 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTotalSize", reflect.TypeOf((*MockFileScanner)(nil).GetTotalSize), arg0)
}

// MockJobManager is a mock of JobManager interface.
type MockJobManager struct {
	ctrl     *gomock.Controller
	recorder *MockJobManagerMockRecorder
}

// MockJobManagerMockRecorder is the mock recorder for MockJobManager.
type MockJobManagerMockRecorder struct {
	mock *MockJobManager
}

// NewMockJobManager creates a new mock instance.
func NewMockJobManager(ctrl *gomock.Controller) *MockJobManager {
	mock := &MockJobManager{ctrl: ctrl}
	mock.recorder = &MockJobManagerMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockJobManager) EXPECT() *MockJobManagerMockRecorder {
	return m.recorder
}

// ISGOMOCK indicates that this struct is a gomock mock.
func (m *MockJobManager) ISGOMOCK() struct{} {
	return struct{}{}
}

// CancelJob mocks base method.
func (m *MockJobManager) CancelJob(arg0 context.Context, arg1 int64) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "CancelJob", arg0, arg1)
	ret0, _ := ret[0].(error)
	return ret0
}

// CancelJob indicates an expected call of CancelJob.
func (mr *MockJobManagerMockRecorder) CancelJob(arg0, arg1 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelJob", reflect.TypeOf((*MockJobManager)(nil).CancelJob), arg0, arg1)
}

// GetGroupSummary mocks base method.
func (m *MockJobManager) GetGroupSummary(arg0 context.Context, arg1 string) (*importsdk.GroupStatus, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetGroupSummary", arg0, arg1)
	ret0, _ := ret[0].(*importsdk.GroupStatus)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetGroupSummary indicates an expected call of GetGroupSummary.
func (mr *MockJobManagerMockRecorder) GetGroupSummary(arg0, arg1 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupSummary", reflect.TypeOf((*MockJobManager)(nil).GetGroupSummary), arg0, arg1)
}

// GetJobStatus mocks base method.
func (m *MockJobManager) GetJobStatus(arg0 context.Context, arg1 int64) (*importsdk.JobStatus, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetJobStatus", arg0, arg1)
	ret0, _ := ret[0].(*importsdk.JobStatus)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetJobStatus indicates an expected call of GetJobStatus.
func (mr *MockJobManagerMockRecorder) GetJobStatus(arg0, arg1 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetJobStatus", reflect.TypeOf((*MockJobManager)(nil).GetJobStatus), arg0, arg1)
}

// GetJobsByGroup mocks base method.
func (m *MockJobManager) GetJobsByGroup(arg0 context.Context, arg1 string) ([]*importsdk.JobStatus, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetJobsByGroup", arg0, arg1)
	ret0, _ := ret[0].([]*importsdk.JobStatus)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetJobsByGroup indicates an expected call of GetJobsByGroup.
func (mr *MockJobManagerMockRecorder) GetJobsByGroup(arg0, arg1 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetJobsByGroup", reflect.TypeOf((*MockJobManager)(nil).GetJobsByGroup), arg0, arg1)
}

// SubmitJob mocks base method.
func (m *MockJobManager) SubmitJob(arg0 context.Context, arg1 string) (int64, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "SubmitJob", arg0, arg1)
	ret0, _ := ret[0].(int64)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// SubmitJob indicates an expected call of SubmitJob.
func (mr *MockJobManagerMockRecorder) SubmitJob(arg0, arg1 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubmitJob", reflect.TypeOf((*MockJobManager)(nil).SubmitJob), arg0, arg1)
}

// MockSQLGenerator is a mock of SQLGenerator interface.
type MockSQLGenerator struct {
	ctrl     *gomock.Controller
	recorder *MockSQLGeneratorMockRecorder
}

// MockSQLGeneratorMockRecorder is the mock recorder for MockSQLGenerator.
type MockSQLGeneratorMockRecorder struct {
	mock *MockSQLGenerator
}

// NewMockSQLGenerator creates a new mock instance.
func NewMockSQLGenerator(ctrl *gomock.Controller) *MockSQLGenerator {
	mock := &MockSQLGenerator{ctrl: ctrl}
	mock.recorder = &MockSQLGeneratorMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSQLGenerator) EXPECT() *MockSQLGeneratorMockRecorder {
	return m.recorder
}

// ISGOMOCK indicates that this struct is a gomock mock.
func (m *MockSQLGenerator) ISGOMOCK() struct{} {
	return struct{}{}
}

// GenerateImportSQL mocks base method.
func (m *MockSQLGenerator) GenerateImportSQL(arg0 *importsdk.TableMeta, arg1 *importsdk.ImportOptions) (string, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GenerateImportSQL", arg0, arg1)
	ret0, _ := ret[0].(string)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GenerateImportSQL indicates an expected call of GenerateImportSQL.
func (mr *MockSQLGeneratorMockRecorder) GenerateImportSQL(arg0, arg1 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateImportSQL", reflect.TypeOf((*MockSQLGenerator)(nil).GenerateImportSQL), arg0, arg1)
}

// MockSDK is a mock of SDK interface.
type MockSDK struct {
	ctrl     *gomock.Controller
	recorder *MockSDKMockRecorder
}

// MockSDKMockRecorder is the mock recorder for MockSDK.
type MockSDKMockRecorder struct {
	mock *MockSDK
}

// NewMockSDK creates a new mock instance.
func NewMockSDK(ctrl *gomock.Controller) *MockSDK {
	mock := &MockSDK{ctrl: ctrl}
	mock.recorder = &MockSDKMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSDK) EXPECT() *MockSDKMockRecorder {
	return m.recorder
}

// ISGOMOCK indicates that this struct is a gomock mock.
func (m *MockSDK) ISGOMOCK() struct{} {
	return struct{}{}
}

// CancelJob mocks base method.
func (m *MockSDK) CancelJob(arg0 context.Context, arg1 int64) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "CancelJob", arg0, arg1)
	ret0, _ := ret[0].(error)
	return ret0
}

// CancelJob indicates an expected call of CancelJob.
func (mr *MockSDKMockRecorder) CancelJob(arg0, arg1 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelJob", reflect.TypeOf((*MockSDK)(nil).CancelJob), arg0, arg1)
}

// Close mocks base method.
func (m *MockSDK) Close() error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Close")
	ret0, _ := ret[0].(error)
	return ret0
}

// Close indicates an expected call of Close.
func (mr *MockSDKMockRecorder) Close() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSDK)(nil).Close))
}

// CreateSchemaAndTableByName mocks base method.
func (m *MockSDK) CreateSchemaAndTableByName(arg0 context.Context, arg1, arg2 string) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "CreateSchemaAndTableByName", arg0, arg1, arg2)
	ret0, _ := ret[0].(error)
	return ret0
}

// CreateSchemaAndTableByName indicates an expected call of CreateSchemaAndTableByName.
|
||||
func (mr *MockSDKMockRecorder) CreateSchemaAndTableByName(arg0, arg1, arg2 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaAndTableByName", reflect.TypeOf((*MockSDK)(nil).CreateSchemaAndTableByName), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// CreateSchemasAndTables mocks base method.
|
||||
func (m *MockSDK) CreateSchemasAndTables(arg0 context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateSchemasAndTables", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CreateSchemasAndTables indicates an expected call of CreateSchemasAndTables.
|
||||
func (mr *MockSDKMockRecorder) CreateSchemasAndTables(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemasAndTables", reflect.TypeOf((*MockSDK)(nil).CreateSchemasAndTables), arg0)
|
||||
}
|
||||
|
||||
// GenerateImportSQL mocks base method.
|
||||
func (m *MockSDK) GenerateImportSQL(arg0 *importsdk.TableMeta, arg1 *importsdk.ImportOptions) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GenerateImportSQL", arg0, arg1)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GenerateImportSQL indicates an expected call of GenerateImportSQL.
|
||||
func (mr *MockSDKMockRecorder) GenerateImportSQL(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateImportSQL", reflect.TypeOf((*MockSDK)(nil).GenerateImportSQL), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetGroupSummary mocks base method.
|
||||
func (m *MockSDK) GetGroupSummary(arg0 context.Context, arg1 string) (*importsdk.GroupStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetGroupSummary", arg0, arg1)
|
||||
ret0, _ := ret[0].(*importsdk.GroupStatus)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetGroupSummary indicates an expected call of GetGroupSummary.
|
||||
func (mr *MockSDKMockRecorder) GetGroupSummary(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupSummary", reflect.TypeOf((*MockSDK)(nil).GetGroupSummary), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetJobStatus mocks base method.
|
||||
func (m *MockSDK) GetJobStatus(arg0 context.Context, arg1 int64) (*importsdk.JobStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetJobStatus", arg0, arg1)
|
||||
ret0, _ := ret[0].(*importsdk.JobStatus)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetJobStatus indicates an expected call of GetJobStatus.
|
||||
func (mr *MockSDKMockRecorder) GetJobStatus(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetJobStatus", reflect.TypeOf((*MockSDK)(nil).GetJobStatus), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetJobsByGroup mocks base method.
|
||||
func (m *MockSDK) GetJobsByGroup(arg0 context.Context, arg1 string) ([]*importsdk.JobStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetJobsByGroup", arg0, arg1)
|
||||
ret0, _ := ret[0].([]*importsdk.JobStatus)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetJobsByGroup indicates an expected call of GetJobsByGroup.
|
||||
func (mr *MockSDKMockRecorder) GetJobsByGroup(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetJobsByGroup", reflect.TypeOf((*MockSDK)(nil).GetJobsByGroup), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetTableMetaByName mocks base method.
|
||||
func (m *MockSDK) GetTableMetaByName(arg0 context.Context, arg1, arg2 string) (*importsdk.TableMeta, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetTableMetaByName", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(*importsdk.TableMeta)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetTableMetaByName indicates an expected call of GetTableMetaByName.
|
||||
func (mr *MockSDKMockRecorder) GetTableMetaByName(arg0, arg1, arg2 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTableMetaByName", reflect.TypeOf((*MockSDK)(nil).GetTableMetaByName), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// GetTableMetas mocks base method.
|
||||
func (m *MockSDK) GetTableMetas(arg0 context.Context) ([]*importsdk.TableMeta, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetTableMetas", arg0)
|
||||
ret0, _ := ret[0].([]*importsdk.TableMeta)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetTableMetas indicates an expected call of GetTableMetas.
|
||||
func (mr *MockSDKMockRecorder) GetTableMetas(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTableMetas", reflect.TypeOf((*MockSDK)(nil).GetTableMetas), arg0)
|
||||
}
|
||||
|
||||
// GetTotalSize mocks base method.
|
||||
func (m *MockSDK) GetTotalSize(arg0 context.Context) int64 {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetTotalSize", arg0)
|
||||
ret0, _ := ret[0].(int64)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetTotalSize indicates an expected call of GetTotalSize.
|
||||
func (mr *MockSDKMockRecorder) GetTotalSize(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTotalSize", reflect.TypeOf((*MockSDK)(nil).GetTotalSize), arg0)
|
||||
}
|
||||
|
||||
// SubmitJob mocks base method.
|
||||
func (m *MockSDK) SubmitJob(arg0 context.Context, arg1 string) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SubmitJob", arg0, arg1)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SubmitJob indicates an expected call of SubmitJob.
|
||||
func (mr *MockSDKMockRecorder) SubmitJob(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubmitJob", reflect.TypeOf((*MockSDK)(nil).SubmitJob), arg0, arg1)
|
||||
}
|
||||
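For reference, a minimal sketch of how these generated mocks might be wired into a test. The mock package import path and the go.uber.org/mock/gomock controller API are assumptions based on the mockgen target in this change, not part of the diff itself:

    func TestSubmitJobFlow(t *testing.T) {
        // NewController registers a cleanup that verifies all expectations.
        ctrl := gomock.NewController(t)
        sdk := mock.NewMockSDK(ctrl)
        // Expect exactly one SubmitJob call and stub its result.
        sdk.EXPECT().SubmitJob(gomock.Any(), gomock.Any()).Return(int64(42), nil)

        id, err := sdk.SubmitJob(context.Background(), "IMPORT INTO ...")
        require.NoError(t, err)
        require.Equal(t, int64(42), id)
    }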
119
pkg/importsdk/model.go
Normal file
@ -0,0 +1,119 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package importsdk

import (
    "time"

    "github.com/pingcap/tidb/pkg/lightning/config"
    "github.com/pingcap/tidb/pkg/lightning/mydump"
)

// TableMeta contains metadata for a table to be imported
type TableMeta struct {
    Database     string
    Table        string
    DataFiles    []DataFileMeta
    TotalSize    int64  // In bytes
    WildcardPath string // Wildcard pattern that matches only this table's data files
    SchemaFile   string // Path to the table schema file, if available
}

// DataFileMeta contains metadata for a data file
type DataFileMeta struct {
    Path        string
    Size        int64
    Format      mydump.SourceType
    Compression mydump.Compression
}

// ImportOptions wraps the options for the IMPORT INTO statement.
// It reuses structures from executor/importer where possible.
type ImportOptions struct {
    Format                string
    CSVConfig             *config.CSVConfig
    Thread                int
    DiskQuota             string
    MaxWriteSpeed         string
    SplitFile             bool
    RecordErrors          int64
    Detached              bool
    CloudStorageURI       string
    GroupKey              string
    SkipRows              int
    CharacterSet          string
    ChecksumTable         string
    DisableTiKVImportMode bool
    DisablePrecheck       bool
    ResourceParameters    string
}

// GroupStatus represents the aggregated status for a group of import jobs.
type GroupStatus struct {
    GroupKey           string
    TotalJobs          int64
    Pending            int64
    Running            int64
    Completed          int64
    Failed             int64
    Cancelled          int64
    FirstJobCreateTime time.Time
    LastJobUpdateTime  time.Time
}

// JobStatus represents the status of an import job.
type JobStatus struct {
    JobID          int64
    GroupKey       string
    DataSource     string
    TargetTable    string
    TableID        int64
    Phase          string
    Status         string
    SourceFileSize string
    ImportedRows   int64
    ResultMessage  string
    CreateTime     time.Time
    StartTime      time.Time
    EndTime        time.Time
    CreatedBy      string
    UpdateTime     time.Time
    Step           string
    ProcessedSize  string
    TotalSize      string
    Percent        string
    Speed          string
    ETA            string
}

// IsFinished returns true if the job finished successfully.
func (s *JobStatus) IsFinished() bool {
    return s.Status == "finished"
}

// IsFailed returns true if the job failed.
func (s *JobStatus) IsFailed() bool {
    return s.Status == "failed"
}

// IsCancelled returns true if the job was cancelled.
func (s *JobStatus) IsCancelled() bool {
    return s.Status == "cancelled"
}

// IsCompleted returns true if the job is in a terminal state.
func (s *JobStatus) IsCompleted() bool {
    return s.IsFinished() || s.IsFailed() || s.IsCancelled()
}
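These terminal-state helpers make a polling loop straightforward. A minimal sketch; the wrapper function, the 5-second interval, and the error wording are illustrative only:

    // waitForJob polls until the job reaches a terminal state.
    func waitForJob(ctx context.Context, sdk importsdk.SDK, jobID int64) error {
        for {
            st, err := sdk.GetJobStatus(ctx, jobID)
            if err != nil {
                return err
            }
            if st.IsCompleted() {
                if st.IsFailed() || st.IsCancelled() {
                    return fmt.Errorf("import job %d ended with status %q: %s", st.JobID, st.Status, st.ResultMessage)
                }
                return nil // finished successfully
            }
            select {
            case <-ctx.Done():
                return ctx.Err()
            case <-time.After(5 * time.Second):
            }
        }
    }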
82
pkg/importsdk/model_test.go
Normal file
@ -0,0 +1,82 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package importsdk

import (
    "testing"

    "github.com/stretchr/testify/require"
)

func TestJobStatus(t *testing.T) {
    tests := []struct {
        status      string
        isFinished  bool
        isFailed    bool
        isCancelled bool
        isCompleted bool
    }{
        {
            status:      "finished",
            isFinished:  true,
            isFailed:    false,
            isCancelled: false,
            isCompleted: true,
        },
        {
            status:      "failed",
            isFinished:  false,
            isFailed:    true,
            isCancelled: false,
            isCompleted: true,
        },
        {
            status:      "cancelled",
            isFinished:  false,
            isFailed:    false,
            isCancelled: true,
            isCompleted: true,
        },
        {
            status:      "running",
            isFinished:  false,
            isFailed:    false,
            isCancelled: false,
            isCompleted: false,
        },
        {
            status:      "pending",
            isFinished:  false,
            isFailed:    false,
            isCancelled: false,
            isCompleted: false,
        },
        {
            status:      "unknown",
            isFinished:  false,
            isFailed:    false,
            isCancelled: false,
            isCompleted: false,
        },
    }

    for _, tt := range tests {
        job := &JobStatus{Status: tt.status}
        require.Equal(t, tt.isFinished, job.IsFinished(), "status: %s", tt.status)
        require.Equal(t, tt.isFailed, job.IsFailed(), "status: %s", tt.status)
        require.Equal(t, tt.isCancelled, job.IsCancelled(), "status: %s", tt.status)
        require.Equal(t, tt.isCompleted, job.IsCompleted(), "status: %s", tt.status)
    }
}
175
pkg/importsdk/pattern.go
Normal file
@ -0,0 +1,175 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package importsdk

import (
    "path/filepath"
    "strings"

    "github.com/pingcap/errors"
    "github.com/pingcap/tidb/pkg/lightning/mydump"
)

// generateWildcardPath creates a wildcard pattern path that matches only this table's files
func generateWildcardPath(
    files []mydump.FileInfo,
    allFiles map[string]mydump.FileInfo,
) (string, error) {
    if len(files) == 0 {
        return "", errors.Annotate(ErrNoTableDataFiles, "cannot generate wildcard pattern because the table has no data files")
    }

    tableFiles := make(map[string]struct{}, len(files))
    for _, df := range files {
        tableFiles[df.FileMeta.Path] = struct{}{}
    }

    // If there's only one file, we can just return its path
    if len(files) == 1 {
        return files[0].FileMeta.Path, nil
    }

    // Try the Mydumper-specific pattern first
    p := generateMydumperPattern(files[0])
    if p != "" && isValidPattern(p, tableFiles, allFiles) {
        return p, nil
    }

    // Fall back to a generic prefix/suffix pattern
    paths := make([]string, 0, len(files))
    for _, file := range files {
        paths = append(paths, file.FileMeta.Path)
    }
    p = generatePrefixSuffixPattern(paths)
    if p != "" && isValidPattern(p, tableFiles, allFiles) {
        return p, nil
    }
    return "", errors.Annotate(ErrWildcardNotSpecific, "failed to find a wildcard that matches all and only the table's files")
}

// isValidPattern checks if a wildcard pattern matches only the table's files
func isValidPattern(pattern string, tableFiles map[string]struct{}, allFiles map[string]mydump.FileInfo) bool {
    if pattern == "" {
        return false
    }

    for path := range allFiles {
        isMatch, err := filepath.Match(pattern, path)
        if err != nil {
            return false // Invalid pattern
        }
        _, isTableFile := tableFiles[path]

        // If the pattern matches a file that's not from our table, it's invalid
        if isMatch && !isTableFile {
            return false
        }

        // If the pattern doesn't match one of our table's files, it's also invalid
        if !isMatch && isTableFile {
            return false
        }
    }

    return true
}

// generateMydumperPattern generates a wildcard pattern for Mydumper-formatted data files
// belonging to a specific table, based on their naming convention.
// It returns a pattern string that matches all data files for the table, or an empty string if not applicable.
func generateMydumperPattern(file mydump.FileInfo) string {
    dbName, tableName := file.TableName.Schema, file.TableName.Name
    if dbName == "" || tableName == "" {
        return ""
    }

    // compute dirPrefix and basename
    full := file.FileMeta.Path
    dirPrefix, name := "", full
    if idx := strings.LastIndex(full, "/"); idx >= 0 {
        dirPrefix = full[:idx+1]
        name = full[idx+1:]
    }

    // compression ext from filename when compression exists (last suffix like .gz/.zst)
    compExt := ""
    if file.FileMeta.Compression != mydump.CompressionNone {
        compExt = filepath.Ext(name)
    }

    // data ext after stripping compression ext
    base := strings.TrimSuffix(name, compExt)
    dataExt := filepath.Ext(base)
    return dirPrefix + dbName + "." + tableName + ".*" + dataExt + compExt
}

// longestCommonPrefix finds the longest string that is a prefix of all strings in the slice
func longestCommonPrefix(strs []string) string {
    if len(strs) == 0 {
        return ""
    }

    prefix := strs[0]
    for _, s := range strs[1:] {
        i := 0
        for i < len(prefix) && i < len(s) && prefix[i] == s[i] {
            i++
        }
        prefix = prefix[:i]
        if prefix == "" {
            break
        }
    }

    return prefix
}

// longestCommonSuffix finds the longest string that is a suffix of all strings in the slice, starting after the given prefix length
func longestCommonSuffix(strs []string, prefixLen int) string {
    if len(strs) == 0 {
        return ""
    }

    suffix := strs[0][prefixLen:]
    for _, s := range strs[1:] {
        remaining := s[prefixLen:]
        i := 0
        for i < len(suffix) && i < len(remaining) && suffix[len(suffix)-i-1] == remaining[len(remaining)-i-1] {
            i++
        }
        suffix = suffix[len(suffix)-i:]
        if suffix == "" {
            break
        }
    }

    return suffix
}

// generatePrefixSuffixPattern returns a wildcard pattern that matches all and only the given paths
// by finding the longest common prefix and suffix among them, and placing a '*' wildcard in between.
func generatePrefixSuffixPattern(paths []string) string {
    if len(paths) == 0 {
        return ""
    }
    if len(paths) == 1 {
        return paths[0]
    }

    prefix := longestCommonPrefix(paths)
    suffix := longestCommonSuffix(paths, len(prefix))

    return prefix + "*" + suffix
}
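One subtlety worth noting: isValidPattern relies on the standard library's filepath.Match, whose '*' never matches a path separator, so a generated pattern can only collapse files that sit at the same directory depth. A quick illustration of that stdlib behavior (not part of this change):

    ok, _ := filepath.Match("db.tb.*.sql", "db.tb.0001.sql") // true
    ok, _ = filepath.Match("dump/*.sql", "dump/a.sql")       // true
    ok, _ = filepath.Match("*.sql", "dump/a.sql")            // false: '*' does not cross '/'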
208
pkg/importsdk/pattern_test.go
Normal file
@ -0,0 +1,208 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package importsdk

import (
    "testing"

    "github.com/pingcap/tidb/pkg/lightning/log"
    "github.com/pingcap/tidb/pkg/lightning/mydump"
    filter "github.com/pingcap/tidb/pkg/util/table-filter"
    "github.com/stretchr/testify/require"
)

func TestLongestCommonPrefix(t *testing.T) {
    strs := []string{"s3://bucket/foo/bar/baz1", "s3://bucket/foo/bar/baz2", "s3://bucket/foo/bar/baz"}
    p := longestCommonPrefix(strs)
    require.Equal(t, "s3://bucket/foo/bar/baz", p)

    // no common prefix
    require.Equal(t, "", longestCommonPrefix([]string{"a", "b"}))

    // empty inputs
    require.Equal(t, "", longestCommonPrefix(nil))
    require.Equal(t, "", longestCommonPrefix([]string{}))
}

func TestLongestCommonSuffix(t *testing.T) {
    strs := []string{"abcXYZ", "defXYZ", "XYZ"}
    s := longestCommonSuffix(strs, 0)
    require.Equal(t, "XYZ", s)

    // no common suffix
    require.Equal(t, "", longestCommonSuffix([]string{"a", "b"}, 0))

    // empty inputs
    require.Equal(t, "", longestCommonSuffix(nil, 0))
    require.Equal(t, "", longestCommonSuffix([]string{}, 0))

    // same prefix
    require.Equal(t, "", longestCommonSuffix([]string{"abc", "abc"}, 3))
    require.Equal(t, "f", longestCommonSuffix([]string{"abcdf", "abcef"}, 3))
}

func TestGeneratePrefixSuffixPattern(t *testing.T) {
    paths := []string{"pre_middle_suf", "pre_most_suf"}
    pattern := generatePrefixSuffixPattern(paths)
    // common prefix "pre_m", suffix "_suf"
    require.Equal(t, "pre_m*_suf", pattern)

    // empty inputs
    require.Equal(t, "", generatePrefixSuffixPattern(nil))
    require.Equal(t, "", generatePrefixSuffixPattern([]string{}))

    // only one file
    require.Equal(t, "pre_middle_suf", generatePrefixSuffixPattern([]string{"pre_middle_suf"}))

    // no common prefix/suffix
    paths2 := []string{"foo", "bar"}
    require.Equal(t, "*", generatePrefixSuffixPattern(paths2))

    // overlapping prefix/suffix
    paths3 := []string{"aaabaaa", "aaa"}
    require.Equal(t, "aaa*", generatePrefixSuffixPattern(paths3))
}

func generateFileMetas(t *testing.T, paths []string) []mydump.FileInfo {
    t.Helper()

    files := make([]mydump.FileInfo, 0, len(paths))
    fileRouter, err := mydump.NewDefaultFileRouter(log.L())
    require.NoError(t, err)
    for _, p := range paths {
        res, err := fileRouter.Route(p)
        require.NoError(t, err)
        files = append(files, mydump.FileInfo{
            TableName: res.Table,
            FileMeta: mydump.SourceFileMeta{
                Path:        p,
                Type:        res.Type,
                Compression: res.Compression,
                SortKey:     res.Key,
            },
        })
    }
    return files
}

func TestGenerateMydumperPattern(t *testing.T) {
    paths := []string{"db.tb.0001.sql", "db.tb.0002.sql"}
    p := generateMydumperPattern(generateFileMetas(t, paths)[0])
    require.Equal(t, "db.tb.*.sql", p)

    paths2 := []string{"s3://bucket/dir/db.tb.0001.sql", "s3://bucket/dir/db.tb.0002.sql"}
    p2 := generateMydumperPattern(generateFileMetas(t, paths2)[0])
    require.Equal(t, "s3://bucket/dir/db.tb.*.sql", p2)

    // not mydumper pattern
    require.Equal(t, "", generateMydumperPattern(mydump.FileInfo{
        TableName: filter.Table{},
    }))
}

func TestValidatePattern(t *testing.T) {
    tableFiles := map[string]struct{}{
        "a.txt": {}, "b.txt": {},
    }
    // only table files in allFiles
    smallAll := map[string]mydump.FileInfo{
        "a.txt": {}, "b.txt": {},
    }
    require.True(t, isValidPattern("*.txt", tableFiles, smallAll))

    // allFiles includes an extra file => invalid
    fullAll := map[string]mydump.FileInfo{
        "a.txt": {}, "b.txt": {}, "c.txt": {},
    }
    require.False(t, isValidPattern("*.txt", tableFiles, fullAll))

    // If pattern doesn't match our table's file, it's also invalid
    require.False(t, isValidPattern("*.csv", tableFiles, smallAll))

    // empty pattern => invalid
    require.False(t, isValidPattern("", tableFiles, smallAll))
}

func TestGenerateWildcardPath(t *testing.T) {
    // Helper to create allFiles map
    createAllFiles := func(paths []string) map[string]mydump.FileInfo {
        allFiles := make(map[string]mydump.FileInfo)
        for _, p := range paths {
            allFiles[p] = mydump.FileInfo{
                FileMeta: mydump.SourceFileMeta{Path: p},
            }
        }
        return allFiles
    }

    // No files
    files1 := []mydump.FileInfo{}
    allFiles1 := createAllFiles([]string{})
    _, err := generateWildcardPath(files1, allFiles1)
    require.Error(t, err)
    require.Contains(t, err.Error(), "no data files for table")

    // Single file
    files2 := generateFileMetas(t, []string{"db.tb.0001.sql"})
    allFiles2 := createAllFiles([]string{"db.tb.0001.sql"})
    path2, err := generateWildcardPath(files2, allFiles2)
    require.NoError(t, err)
    require.Equal(t, "db.tb.0001.sql", path2)

    // Mydumper pattern succeeds
    files3 := generateFileMetas(t, []string{"db.tb.0001.sql.gz", "db.tb.0002.sql.gz"})
    allFiles3 := createAllFiles([]string{"db.tb.0001.sql.gz", "db.tb.0002.sql.gz"})
    path3, err := generateWildcardPath(files3, allFiles3)
    require.NoError(t, err)
    require.Equal(t, "db.tb.*.sql.gz", path3)

    // Mydumper pattern fails, fallback to prefix/suffix succeeds
    files4 := []mydump.FileInfo{
        {
            TableName: filter.Table{Schema: "db", Name: "tb"},
            FileMeta: mydump.SourceFileMeta{
                Path:        "a.sql",
                Type:        mydump.SourceTypeSQL,
                Compression: mydump.CompressionNone,
            },
        },
        {
            TableName: filter.Table{Schema: "db", Name: "tb"},
            FileMeta: mydump.SourceFileMeta{
                Path:        "b.sql",
                Type:        mydump.SourceTypeSQL,
                Compression: mydump.CompressionNone,
            },
        },
    }
    allFiles4 := map[string]mydump.FileInfo{
        files4[0].FileMeta.Path: files4[0],
        files4[1].FileMeta.Path: files4[1],
    }
    path4, err := generateWildcardPath(files4, allFiles4)
    require.NoError(t, err)
    require.Equal(t, "*.sql", path4)

    allFiles4["db-schema.sql"] = mydump.FileInfo{
        FileMeta: mydump.SourceFileMeta{
            Path:        "db-schema.sql",
            Type:        mydump.SourceTypeSQL,
            Compression: mydump.CompressionNone,
        },
    }
    _, err = generateWildcardPath(files4, allFiles4)
    require.Error(t, err)
    require.Contains(t, err.Error(), "cannot generate a unique wildcard pattern")
}
@ -17,61 +17,21 @@ package importsdk

import (
    "context"
    "database/sql"
    "path/filepath"
    "strings"

    "github.com/pingcap/errors"
    "github.com/pingcap/tidb/br/pkg/storage"
    "github.com/pingcap/tidb/pkg/lightning/config"
    "github.com/pingcap/tidb/pkg/lightning/log"
    "github.com/pingcap/tidb/pkg/lightning/mydump"
    "github.com/pingcap/tidb/pkg/parser/mysql"
)

// SDK defines the interface for cloud import services
type SDK interface {
    // CreateSchemasAndTables creates all database schemas and tables from source path
    CreateSchemasAndTables(ctx context.Context) error

    // GetTableMetas returns metadata for all tables in the source path
    GetTableMetas(ctx context.Context) ([]*TableMeta, error)

    // GetTableMetaByName returns metadata for a specific table
    GetTableMetaByName(ctx context.Context, schema, table string) (*TableMeta, error)

    // GetTotalSize returns the cumulative size (in bytes) of all data files under the source path
    GetTotalSize(ctx context.Context) int64

    // Close releases resources used by the SDK
    FileScanner
    JobManager
    SQLGenerator
    Close() error
}
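After this refactor the SDK is a pure composition of the three capability interfaces plus Close, so a caller can depend on only the slice it needs. A small sketch; the function name and the exact split of methods across FileScanner and SQLGenerator are assumptions inferred from the embedding above:

    // printImportPlan needs only scanning and SQL generation, not job management.
    func printImportPlan(ctx context.Context, scan importsdk.FileScanner, gen importsdk.SQLGenerator, opts *importsdk.ImportOptions) error {
        metas, err := scan.GetTableMetas(ctx)
        if err != nil {
            return err
        }
        for _, m := range metas {
            stmt, err := gen.GenerateImportSQL(m, opts)
            if err != nil {
                return err
            }
            fmt.Println(stmt)
        }
        return nil
    }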
// TableMeta contains metadata for a table to be imported
type TableMeta struct {
    Database     string
    Table        string
    DataFiles    []DataFileMeta
    TotalSize    int64  // In bytes
    WildcardPath string // Wildcard pattern that matches only this table's data files
    SchemaFile   string // Path to the table schema file, if available
}

// DataFileMeta contains metadata for a data file
type DataFileMeta struct {
    Path        string
    Size        int64
    Format      mydump.SourceType
    Compression mydump.Compression
}

// ImportSDK implements SDK interface
type ImportSDK struct {
    sourcePath string
    db         *sql.DB
    store      storage.ExternalStorage
    loader     *mydump.MDLoader
    logger     log.Logger
    config     *sdkConfig
// importSDK implements SDK interface
type importSDK struct {
    FileScanner
    JobManager
    SQLGenerator
}

// NewImportSDK creates a new CloudImportSDK instance
@ -81,410 +41,22 @@ func NewImportSDK(ctx context.Context, sourcePath string, db *sql.DB, options ..
        opt(cfg)
    }

    u, err := storage.ParseBackend(sourcePath, nil)
    scanner, err := NewFileScanner(ctx, sourcePath, db, cfg)
    if err != nil {
        return nil, errors.Annotatef(err, "failed to parse storage backend URL (source=%s). Please verify the URL format and credentials", sourcePath)
    }
    store, err := storage.New(ctx, u, &storage.ExternalStorageOptions{})
    if err != nil {
        return nil, errors.Annotatef(err, "failed to create external storage (source=%s). Check network/connectivity and permissions", sourcePath)
        return nil, err
    }

    ldrCfg := mydump.LoaderConfig{
        SourceURL:        sourcePath,
        Filter:           cfg.filter,
        FileRouters:      cfg.fileRouteRules,
        DefaultFileRules: len(cfg.fileRouteRules) == 0,
        CharacterSet:     cfg.charset,
    }
    jobManager := NewJobManager(db)
    sqlGenerator := NewSQLGenerator()

    loader, err := mydump.NewLoaderWithStore(ctx, ldrCfg, store)
    if err != nil {
        return nil, errors.Annotatef(err, "failed to create MyDump loader (source=%s, charset=%s, filter=%v). Please check dump layout and router rules", sourcePath, cfg.charset, cfg.filter)
    }

    return &ImportSDK{
        sourcePath: sourcePath,
        db:         db,
        store:      store,
        loader:     loader,
        logger:     cfg.logger,
        config:     cfg,
    return &importSDK{
        FileScanner:  scanner,
        JobManager:   jobManager,
        SQLGenerator: sqlGenerator,
    }, nil
}
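A minimal construction sketch against the refactored entry point; the source URI, credentials, and DB wiring are illustrative only:

    // db is a *sql.DB connected to the target TiDB cluster.
    sdk, err := importsdk.NewImportSDK(ctx, "s3://bucket/dump", db,
        importsdk.WithConcurrency(8),
        importsdk.WithCharset("utf8"),
    )
    if err != nil {
        return err
    }
    defer sdk.Close()

    // Create target schemas/tables before generating import statements.
    if err := sdk.CreateSchemasAndTables(ctx); err != nil {
        return err
    }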
// SDKOption customizes the SDK configuration
type SDKOption func(*sdkConfig)

type sdkConfig struct {
    // Loader options
    concurrency    int
    sqlMode        mysql.SQLMode
    fileRouteRules []*config.FileRouteRule
    filter         []string
    charset        string

    // General options
    logger log.Logger
// Close releases resources used by the SDK
func (sdk *importSDK) Close() error {
    return sdk.FileScanner.Close()
}

func defaultSDKConfig() *sdkConfig {
    return &sdkConfig{
        concurrency: 4,
        filter:      config.GetDefaultFilter(),
        logger:      log.L(),
        charset:     "auto",
    }
}

// WithConcurrency sets the number of concurrent DB/Table creation workers.
func WithConcurrency(n int) SDKOption {
    return func(cfg *sdkConfig) {
        if n > 0 {
            cfg.concurrency = n
        }
    }
}

// WithLogger specifies a custom logger
func WithLogger(logger log.Logger) SDKOption {
    return func(cfg *sdkConfig) {
        cfg.logger = logger
    }
}

// WithSQLMode specifies the SQL mode for schema parsing
func WithSQLMode(mode mysql.SQLMode) SDKOption {
    return func(cfg *sdkConfig) {
        cfg.sqlMode = mode
    }
}

// WithFilter specifies a filter for the loader
func WithFilter(filter []string) SDKOption {
    return func(cfg *sdkConfig) {
        cfg.filter = filter
    }
}

// WithFileRouters specifies custom file routing rules
func WithFileRouters(routers []*config.FileRouteRule) SDKOption {
    return func(cfg *sdkConfig) {
        cfg.fileRouteRules = routers
    }
}

// WithCharset specifies the character set for import (default "auto").
func WithCharset(cs string) SDKOption {
    return func(cfg *sdkConfig) {
        if cs != "" {
            cfg.charset = cs
        }
    }
}

// CreateSchemasAndTables implements the CloudImportSDK interface
func (sdk *ImportSDK) CreateSchemasAndTables(ctx context.Context) error {
    dbMetas := sdk.loader.GetDatabases()
    if len(dbMetas) == 0 {
        return errors.Annotatef(ErrNoDatabasesFound, "source=%s. Ensure the path contains valid dump files (*.sql, *.csv, *.parquet, etc.) and filter rules are correct", sdk.sourcePath)
    }

    // Create all schemas and tables
    importer := mydump.NewSchemaImporter(
        sdk.logger,
        sdk.config.sqlMode,
        sdk.db,
        sdk.store,
        sdk.config.concurrency,
    )

    err := importer.Run(ctx, dbMetas)
    if err != nil {
        return errors.Annotatef(err, "creating schemas and tables failed (source=%s, db_count=%d, concurrency=%d)", sdk.sourcePath, len(dbMetas), sdk.config.concurrency)
    }

    return nil
}

// GetTableMetas implements the CloudImportSDK interface
func (sdk *ImportSDK) GetTableMetas(context.Context) ([]*TableMeta, error) {
    dbMetas := sdk.loader.GetDatabases()
    allFiles := sdk.loader.GetAllFiles()
    var results []*TableMeta
    for _, dbMeta := range dbMetas {
        for _, tblMeta := range dbMeta.Tables {
            tableMeta, err := sdk.buildTableMeta(dbMeta, tblMeta, allFiles)
            if err != nil {
                return nil, errors.Wrapf(err, "failed to build metadata for table %s.%s",
                    dbMeta.Name, tblMeta.Name)
            }
            results = append(results, tableMeta)
        }
    }

    return results, nil
}

// GetTableMetaByName implements CloudImportSDK interface
func (sdk *ImportSDK) GetTableMetaByName(_ context.Context, schema, table string) (*TableMeta, error) {
    dbMetas := sdk.loader.GetDatabases()
    allFiles := sdk.loader.GetAllFiles()
    // Find the specific table
    for _, dbMeta := range dbMetas {
        if dbMeta.Name != schema {
            continue
        }

        for _, tblMeta := range dbMeta.Tables {
            if tblMeta.Name != table {
                continue
            }

            return sdk.buildTableMeta(dbMeta, tblMeta, allFiles)
        }

        return nil, errors.Annotatef(ErrTableNotFound, "schema=%s, table=%s", schema, table)
    }

    return nil, errors.Annotatef(ErrSchemaNotFound, "schema=%s", schema)
}

// GetTotalSize implements CloudImportSDK interface
func (sdk *ImportSDK) GetTotalSize(ctx context.Context) int64 {
    var total int64
    dbMetas := sdk.loader.GetDatabases()
    for _, dbMeta := range dbMetas {
        for _, tblMeta := range dbMeta.Tables {
            total += tblMeta.TotalSize
        }
    }
    return total
}

// buildTableMeta creates a TableMeta from database and table metadata
func (sdk *ImportSDK) buildTableMeta(
    dbMeta *mydump.MDDatabaseMeta,
    tblMeta *mydump.MDTableMeta,
    allDataFiles map[string]mydump.FileInfo,
) (*TableMeta, error) {
    tableMeta := &TableMeta{
        Database:   dbMeta.Name,
        Table:      tblMeta.Name,
        DataFiles:  make([]DataFileMeta, 0, len(tblMeta.DataFiles)),
        SchemaFile: tblMeta.SchemaFile.FileMeta.Path,
    }

    // Process data files
    dataFiles, totalSize := processDataFiles(tblMeta.DataFiles)
    tableMeta.DataFiles = dataFiles
    tableMeta.TotalSize = totalSize

    wildcard, err := generateWildcardPath(tblMeta.DataFiles, allDataFiles)
    if err != nil {
        return nil, errors.Annotatef(err, "failed to build wildcard for table=%s.%s", dbMeta.Name, tblMeta.Name)
    }
    tableMeta.WildcardPath = strings.TrimSuffix(sdk.store.URI(), "/") + "/" + wildcard

    return tableMeta, nil
}
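For example, with store.URI() = "s3://bucket/dump" and a generated wildcard of "db.tb.*.sql", buildTableMeta sets WildcardPath to "s3://bucket/dump/db.tb.*.sql", which is the shape of source clause an IMPORT INTO statement can consume.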
// Close implements CloudImportSDK interface
func (sdk *ImportSDK) Close() error {
    // close external storage
    if sdk.store != nil {
        sdk.store.Close()
    }
    return nil
}

// processDataFiles converts mydump data files to DataFileMeta and calculates total size
func processDataFiles(files []mydump.FileInfo) ([]DataFileMeta, int64) {
    dataFiles := make([]DataFileMeta, 0, len(files))
    var totalSize int64

    for _, dataFile := range files {
        fileMeta := createDataFileMeta(dataFile)
        dataFiles = append(dataFiles, fileMeta)
        totalSize += dataFile.FileMeta.RealSize
    }

    return dataFiles, totalSize
}

// createDataFileMeta creates a DataFileMeta from a mydump.DataFile
func createDataFileMeta(file mydump.FileInfo) DataFileMeta {
    return DataFileMeta{
        Path:        file.FileMeta.Path,
        Size:        file.FileMeta.RealSize,
        Format:      file.FileMeta.Type,
        Compression: file.FileMeta.Compression,
    }
}

// generateWildcardPath creates a wildcard pattern path that matches only this table's files
func generateWildcardPath(
    files []mydump.FileInfo,
    allFiles map[string]mydump.FileInfo,
) (string, error) {
    tableFiles := make(map[string]struct{}, len(files))
    for _, df := range files {
        tableFiles[df.FileMeta.Path] = struct{}{}
    }

    if len(files) == 0 {
        return "", errors.Annotate(ErrNoTableDataFiles, "cannot generate wildcard pattern because the table has no data files")
    }

    // If there's only one file, we can just return its path
    if len(files) == 1 {
        return files[0].FileMeta.Path, nil
    }

    // Try Mydumper-specific pattern first
    p := generateMydumperPattern(files[0])
    if p != "" && isValidPattern(p, tableFiles, allFiles) {
        return p, nil
    }

    // Fallback to generic prefix/suffix pattern
    paths := make([]string, 0, len(files))
    for _, file := range files {
        paths = append(paths, file.FileMeta.Path)
    }
    p = generatePrefixSuffixPattern(paths)
    if p != "" && isValidPattern(p, tableFiles, allFiles) {
        return p, nil
    }
    return "", errors.Annotatef(ErrWildcardNotSpecific, "failed to find a wildcard that matches all and only the table's files.")
}

// isValidPattern checks if a wildcard pattern matches only the table's files
func isValidPattern(pattern string, tableFiles map[string]struct{}, allFiles map[string]mydump.FileInfo) bool {
    if pattern == "" {
        return false
    }

    for path := range allFiles {
        isMatch, err := filepath.Match(pattern, path)
        if err != nil {
            return false // Invalid pattern
        }
        _, isTableFile := tableFiles[path]

        // If pattern matches a file that's not from our table, it's invalid
        if isMatch && !isTableFile {
            return false
        }

        // If pattern doesn't match our table's file, it's also invalid
        if !isMatch && isTableFile {
            return false
        }
    }

    return true
}

// generateMydumperPattern generates a wildcard pattern for Mydumper-formatted data files
// belonging to a specific table, based on their naming convention.
// It returns a pattern string that matches all data files for the table, or an empty string if not applicable.
func generateMydumperPattern(file mydump.FileInfo) string {
    dbName, tableName := file.TableName.Schema, file.TableName.Name
    if dbName == "" || tableName == "" {
        return ""
    }

    // compute dirPrefix and basename
    full := file.FileMeta.Path
    dirPrefix, name := "", full
    if idx := strings.LastIndex(full, "/"); idx >= 0 {
        dirPrefix = full[:idx+1]
        name = full[idx+1:]
    }

    // compression ext from filename when compression exists (last suffix like .gz/.zst)
    compExt := ""
    if file.FileMeta.Compression != mydump.CompressionNone {
        compExt = filepath.Ext(name)
    }

    // data ext after stripping compression ext
    base := strings.TrimSuffix(name, compExt)
    dataExt := filepath.Ext(base)
    return dirPrefix + dbName + "." + tableName + ".*" + dataExt + compExt
}

// longestCommonPrefix finds the longest string that is a prefix of all strings in the slice
func longestCommonPrefix(strs []string) string {
    if len(strs) == 0 {
        return ""
    }

    prefix := strs[0]
    for _, s := range strs[1:] {
        i := 0
        for i < len(prefix) && i < len(s) && prefix[i] == s[i] {
            i++
        }
        prefix = prefix[:i]
        if prefix == "" {
            break
        }
    }

    return prefix
}

// longestCommonSuffix finds the longest string that is a suffix of all strings in the slice, starting after the given prefix length
func longestCommonSuffix(strs []string, prefixLen int) string {
    if len(strs) == 0 {
        return ""
    }

    suffix := strs[0][prefixLen:]
    for _, s := range strs[1:] {
        remaining := s[prefixLen:]
        i := 0
        for i < len(suffix) && i < len(remaining) && suffix[len(suffix)-i-1] == remaining[len(remaining)-i-1] {
            i++
        }
        suffix = suffix[len(suffix)-i:]
        if suffix == "" {
            break
        }
    }

    return suffix
}

// generatePrefixSuffixPattern returns a wildcard pattern that matches all and only the given paths
// by finding the longest common prefix and suffix among them, and placing a '*' wildcard in between.
func generatePrefixSuffixPattern(paths []string) string {
    if len(paths) == 0 {
        return ""
    }
    if len(paths) == 1 {
        return paths[0]
    }

    prefix := longestCommonPrefix(paths)
    suffix := longestCommonSuffix(paths, len(prefix))

    return prefix + "*" + suffix
}

// TODO: add error code and doc for cloud sdk
// Sentinel errors to categorize common failure scenarios for clearer user messages.
var (
    // ErrNoDatabasesFound indicates that the dump source contains no recognizable databases.
    ErrNoDatabasesFound = errors.New("no databases found in the source path")
    // ErrSchemaNotFound indicates the target schema doesn't exist in the dump source.
    ErrSchemaNotFound = errors.New("schema not found")
    // ErrTableNotFound indicates the target table doesn't exist in the dump source.
    ErrTableNotFound = errors.New("table not found")
    // ErrNoTableDataFiles indicates a table has zero data files and thus cannot proceed.
    ErrNoTableDataFiles = errors.New("no data files for table")
    // ErrWildcardNotSpecific indicates a wildcard cannot uniquely match the table's files.
    ErrWildcardNotSpecific = errors.New("cannot generate a unique wildcard pattern for the table's data files")
)
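Because the package annotates these sentinels with context before returning them, callers can still branch on the underlying cause. A small sketch using the pingcap/errors cause-chain convention (the handling branch is illustrative):

    meta, err := sdk.GetTableMetaByName(ctx, "db1", "tb1")
    if err != nil {
        // errors.Cause unwraps the Annotate/Annotatef context added above.
        if errors.Cause(err) == importsdk.ErrTableNotFound {
            // e.g. surface a friendlier message, or skip this table
        }
        return err
    }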
@ -27,8 +27,6 @@ import (
    "github.com/pingcap/tidb/pkg/lightning/log"
    "github.com/pingcap/tidb/pkg/lightning/mydump"
    "github.com/pingcap/tidb/pkg/parser/mysql"
    filter "github.com/pingcap/tidb/pkg/util/table-filter"
    "github.com/stretchr/testify/require"
    "github.com/stretchr/testify/suite"
)
@ -355,219 +353,82 @@ func (s *mockGCSSuite) TestOnlyDataFiles() {
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
func TestLongestCommonPrefix(t *testing.T) {
|
||||
strs := []string{"s3://bucket/foo/bar/baz1", "s3://bucket/foo/bar/baz2", "s3://bucket/foo/bar/baz"}
|
||||
p := longestCommonPrefix(strs)
|
||||
require.Equal(t, "s3://bucket/foo/bar/baz", p)
|
||||
|
||||
// no common prefix
|
||||
require.Equal(t, "", longestCommonPrefix([]string{"a", "b"}))
|
||||
|
||||
// empty inputs
|
||||
require.Equal(t, "", longestCommonPrefix(nil))
|
||||
require.Equal(t, "", longestCommonPrefix([]string{}))
|
||||
func (s *mockGCSSuite) TestScanLimitation() {
|
||||
s.server.CreateBucketWithOpts(fakestorage.CreateBucketOpts{Name: "limitation"})
|
||||
s.server.CreateObject(fakestorage.Object{
|
||||
ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "limitation", Name: "db2.tb2.001.csv"},
|
||||
Content: []byte("5,e\n6,f\n"),
|
||||
})
|
||||
s.server.CreateObject(fakestorage.Object{
|
||||
ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "limitation", Name: "db2.tb2.002.csv"},
|
||||
Content: []byte("7,g\n8,h\n"),
|
||||
})
|
||||
db, _, err := sqlmock.New()
|
||||
s.NoError(err)
|
||||
defer db.Close()
|
||||
importSDK, err := NewImportSDK(
|
||||
context.Background(),
|
||||
fmt.Sprintf("gs://limitation?endpoint=%s&access-key=aaaaaa&secret-access-key=bbbbbb", gcsEndpoint),
|
||||
db,
|
||||
WithCharset("utf8"),
|
||||
WithConcurrency(8),
|
||||
WithFilter([]string{"*.*"}),
|
||||
WithSQLMode(mysql.ModeANSIQuotes),
|
||||
WithLogger(log.L()),
|
||||
WithSkipInvalidFiles(true),
|
||||
WithMaxScanFiles(1),
|
||||
)
|
||||
s.NoError(err)
|
||||
defer importSDK.Close()
|
||||
metas, err := importSDK.GetTableMetas(context.Background())
|
||||
s.NoError(err)
|
||||
s.Len(metas, 1)
|
||||
s.Len(metas[0].DataFiles, 1)
|
||||
}
|
||||
|
||||
func TestLongestCommonSuffix(t *testing.T) {
|
||||
strs := []string{"abcXYZ", "defXYZ", "XYZ"}
|
||||
s := longestCommonSuffix(strs, 0)
|
||||
require.Equal(t, "XYZ", s)
|
||||
|
||||
// no common suffix
|
||||
require.Equal(t, "", longestCommonSuffix([]string{"a", "b"}, 0))
|
||||
|
||||
// empty inputs
|
||||
require.Equal(t, "", longestCommonSuffix(nil, 0))
|
||||
require.Equal(t, "", longestCommonSuffix([]string{}, 0))
|
||||
|
||||
// same prefix
|
||||
require.Equal(t, "", longestCommonSuffix([]string{"abc", "abc"}, 3))
|
||||
require.Equal(t, "f", longestCommonSuffix([]string{"abcdf", "abcef"}, 3))
|
||||
}
|
||||
|
||||
func TestGeneratePrefixSuffixPattern(t *testing.T) {
|
||||
paths := []string{"pre_middle_suf", "pre_most_suf"}
|
||||
pattern := generatePrefixSuffixPattern(paths)
|
||||
// common prefix "pre_m", suffix "_suf"
|
||||
require.Equal(t, "pre_m*_suf", pattern)
|
||||
|
||||
// empty inputs
|
||||
require.Equal(t, "", generatePrefixSuffixPattern(nil))
|
||||
require.Equal(t, "", generatePrefixSuffixPattern([]string{}))
|
||||
|
||||
// only one file
|
||||
require.Equal(t, "pre_middle_suf", generatePrefixSuffixPattern([]string{"pre_middle_suf"}))
|
||||
|
||||
// no common prefix/suffix
|
||||
paths2 := []string{"foo", "bar"}
|
||||
require.Equal(t, "*", generatePrefixSuffixPattern(paths2))
|
||||
|
||||
// overlapping prefix/suffix
|
||||
paths3 := []string{"aaabaaa", "aaa"}
|
||||
require.Equal(t, "aaa*", generatePrefixSuffixPattern(paths3))
|
||||
}
|
||||
|
||||
func generateFileMetas(t *testing.T, paths []string) []mydump.FileInfo {
|
||||
t.Helper()
|
||||
|
||||
files := make([]mydump.FileInfo, 0, len(paths))
|
||||
fileRouter, err := mydump.NewDefaultFileRouter(log.L())
|
||||
require.NoError(t, err)
|
||||
for _, p := range paths {
|
||||
res, err := fileRouter.Route(p)
|
||||
require.NoError(t, err)
|
||||
files = append(files, mydump.FileInfo{
|
||||
TableName: res.Table,
|
||||
FileMeta: mydump.SourceFileMeta{
|
||||
Path: p,
|
||||
Type: res.Type,
|
||||
Compression: res.Compression,
|
||||
SortKey: res.Key,
|
||||
},
|
||||
func (s *mockGCSSuite) TestCreateTableMetaByName() {
|
||||
for i := 0; i < 2; i++ {
|
||||
s.server.CreateObject(fakestorage.Object{
|
||||
ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "specific-table-test", Name: fmt.Sprintf("db%d-schema-create.sql", i)},
|
||||
Content: []byte(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS db%d;\n", i)),
|
||||
})
|
||||
}
|
||||
return files
|
||||
}
|
||||
|
||||
func TestGenerateMydumperPattern(t *testing.T) {
|
||||
paths := []string{"db.tb.0001.sql", "db.tb.0002.sql"}
|
||||
p := generateMydumperPattern(generateFileMetas(t, paths)[0])
|
||||
require.Equal(t, "db.tb.*.sql", p)
|
||||
|
||||
paths2 := []string{"s3://bucket/dir/db.tb.0001.sql", "s3://bucket/dir/db.tb.0002.sql"}
|
||||
p2 := generateMydumperPattern(generateFileMetas(t, paths2)[0])
|
||||
require.Equal(t, "s3://bucket/dir/db.tb.*.sql", p2)
|
||||
|
||||
// not mydumper pattern
|
||||
require.Equal(t, "", generateMydumperPattern(mydump.FileInfo{
|
||||
TableName: filter.Table{},
|
||||
}))
|
||||
}
|
||||
|
||||
func TestCreateDataFileMeta(t *testing.T) {
|
||||
fi := mydump.FileInfo{
|
||||
TableName: filter.Table{
|
||||
Schema: "db",
|
||||
Name: "table",
|
||||
},
|
||||
FileMeta: mydump.SourceFileMeta{
|
||||
Path: "s3://bucket/path/to/f",
|
||||
FileSize: 123,
|
||||
Type: mydump.SourceTypeCSV,
|
||||
Compression: mydump.CompressionGZ,
|
||||
RealSize: 456,
|
||||
},
|
||||
}
|
||||
df := createDataFileMeta(fi)
|
||||
require.Equal(t, "s3://bucket/path/to/f", df.Path)
|
||||
require.Equal(t, int64(456), df.Size)
|
||||
require.Equal(t, mydump.SourceTypeCSV, df.Format)
|
||||
require.Equal(t, mydump.CompressionGZ, df.Compression)
|
||||
}
|
||||
|
||||
func TestProcessDataFiles(t *testing.T) {
|
||||
files := []mydump.FileInfo{
|
||||
{FileMeta: mydump.SourceFileMeta{Path: "s3://bucket/a", RealSize: 10}},
|
||||
{FileMeta: mydump.SourceFileMeta{Path: "s3://bucket/b", RealSize: 20}},
|
||||
}
|
||||
dfm, total := processDataFiles(files)
|
||||
require.Len(t, dfm, 2)
|
||||
require.Equal(t, int64(30), total)
|
||||
require.Equal(t, "s3://bucket/a", dfm[0].Path)
|
||||
require.Equal(t, "s3://bucket/b", dfm[1].Path)
|
||||
}
|
||||
|
||||
func TestValidatePattern(t *testing.T) {
|
||||
tableFiles := map[string]struct{}{
|
||||
"a.txt": {}, "b.txt": {},
|
||||
}
|
||||
// only table files in allFiles
|
||||
smallAll := map[string]mydump.FileInfo{
|
||||
"a.txt": {}, "b.txt": {},
|
||||
}
|
||||
require.True(t, isValidPattern("*.txt", tableFiles, smallAll))
|
||||
|
||||
// allFiles includes an extra file => invalid
|
||||
fullAll := map[string]mydump.FileInfo{
|
||||
"a.txt": {}, "b.txt": {}, "c.txt": {},
|
||||
}
|
||||
require.False(t, isValidPattern("*.txt", tableFiles, fullAll))
|
||||
|
||||
// If pattern doesn't match our table's file, it's also invalid
|
||||
require.False(t, isValidPattern("*.csv", tableFiles, smallAll))
|
||||
|
||||
// empty pattern => invalid
|
||||
require.False(t, isValidPattern("", tableFiles, smallAll))
|
||||
}
|
||||
|
||||
func TestGenerateWildcardPath(t *testing.T) {
|
||||
// Helper to create allFiles map
|
||||
createAllFiles := func(paths []string) map[string]mydump.FileInfo {
|
||||
allFiles := make(map[string]mydump.FileInfo)
|
||||
for _, p := range paths {
|
||||
allFiles[p] = mydump.FileInfo{
|
||||
FileMeta: mydump.SourceFileMeta{Path: p},
|
||||
}
|
||||
for j := 0; j != 2; j++ {
|
||||
tableName := fmt.Sprintf("db%d.tb%d", i, j)
|
||||
s.server.CreateObject(fakestorage.Object{
|
||||
ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "specific-table-test", Name: fmt.Sprintf("%s-schema.sql", tableName)},
|
||||
Content: []byte(fmt.Sprintf("CREATE TABLE IF NOT EXISTS db%d.tb%d (a INT, b VARCHAR(10));\n", i, j)),
|
||||
})
|
||||
s.server.CreateObject(fakestorage.Object{
|
||||
ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "specific-table-test", Name: fmt.Sprintf("%s.001.sql", tableName)},
|
||||
Content: []byte(fmt.Sprintf("INSERT INTO db%d.tb%d VALUES (1,'a'),(2,'b');\n", i, j)),
|
||||
})
|
||||
s.server.CreateObject(fakestorage.Object{
|
||||
ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "specific-table-test", Name: fmt.Sprintf("%s.002.sql", tableName)},
|
||||
Content: []byte(fmt.Sprintf("INSERT INTO db%d.tb%d VALUES (3,'c'),(4,'d');\n", i, j)),
|
||||
})
|
||||
}
|
||||
return allFiles
|
||||
}
|
||||
|
||||
// No files
|
||||
files1 := []mydump.FileInfo{}
|
||||
allFiles1 := createAllFiles([]string{})
|
||||
_, err := generateWildcardPath(files1, allFiles1)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "no data files for table")
|
||||
db, mock, err := sqlmock.New()
|
||||
s.NoError(err)
|
||||
defer db.Close()
|
||||

	// Single file
	files2 := generateFileMetas(t, []string{"db.tb.0001.sql"})
	allFiles2 := createAllFiles([]string{"db.tb.0001.sql"})
	path2, err := generateWildcardPath(files2, allFiles2)
	require.NoError(t, err)
	require.Equal(t, "db.tb.0001.sql", path2)

	// Mydumper pattern succeeds
	files3 := generateFileMetas(t, []string{"db.tb.0001.sql.gz", "db.tb.0002.sql.gz"})
	allFiles3 := createAllFiles([]string{"db.tb.0001.sql.gz", "db.tb.0002.sql.gz"})
	path3, err := generateWildcardPath(files3, allFiles3)
	require.NoError(t, err)
	require.Equal(t, "db.tb.*.sql.gz", path3)

	// Mydumper pattern fails, fallback to prefix/suffix succeeds
	files4 := []mydump.FileInfo{
		{
			TableName: filter.Table{Schema: "db", Name: "tb"},
			FileMeta: mydump.SourceFileMeta{
				Path:        "a.sql",
				Type:        mydump.SourceTypeSQL,
				Compression: mydump.CompressionNone,
			},
		},
		{
			TableName: filter.Table{Schema: "db", Name: "tb"},
			FileMeta: mydump.SourceFileMeta{
				Path:        "b.sql",
				Type:        mydump.SourceTypeSQL,
				Compression: mydump.CompressionNone,
			},
		},
	}
	allFiles4 := map[string]mydump.FileInfo{
		files4[0].FileMeta.Path: files4[0],
		files4[1].FileMeta.Path: files4[1],
	}
	path4, err := generateWildcardPath(files4, allFiles4)
	require.NoError(t, err)
	require.Equal(t, "*.sql", path4)

	// An extra file that also matches "*.sql" makes the pattern ambiguous
	allFiles4["db-schema.sql"] = mydump.FileInfo{
		FileMeta: mydump.SourceFileMeta{
			Path:        "db-schema.sql",
			Type:        mydump.SourceTypeSQL,
			Compression: mydump.CompressionNone,
		},
	}
	_, err = generateWildcardPath(files4, allFiles4)
	require.Error(t, err)
	require.Contains(t, err.Error(), "cannot generate a unique wildcard pattern")
}
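
For quick reference, the outcomes exercised above can be summarized as follows (a descriptive sketch of the test cases, not additional API):

// db.tb.0001.sql.gz + db.tb.0002.sql.gz -> "db.tb.*.sql.gz"  (mydumper-style numbering)
// a.sql + b.sql                         -> "*.sql"           (prefix/suffix fallback)
// a.sql + b.sql + db-schema.sql         -> error: "*.sql" would also match the schema file
// no data files                         -> error ("no data files for table")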

161
pkg/importsdk/sql_generator.go
Normal file
@ -0,0 +1,161 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package importsdk

import (
	"fmt"
	"net/url"
	"strings"

	"github.com/pingcap/tidb/pkg/lightning/config"
)

// SQLGenerator defines the interface for generating IMPORT INTO SQL
type SQLGenerator interface {
	GenerateImportSQL(tableMeta *TableMeta, options *ImportOptions) (string, error)
}

type sqlGenerator struct{}

// NewSQLGenerator creates a new SQLGenerator
func NewSQLGenerator() SQLGenerator {
	return &sqlGenerator{}
}

// GenerateImportSQL generates the IMPORT INTO SQL statement
func (g *sqlGenerator) GenerateImportSQL(tableMeta *TableMeta, options *ImportOptions) (string, error) {
	var sb strings.Builder
	sb.WriteString("IMPORT INTO ")
	sb.WriteString(escapeIdentifier(tableMeta.Database))
	sb.WriteString(".")
	sb.WriteString(escapeIdentifier(tableMeta.Table))

	// Append resource parameters (e.g. credentials) to the source URI's query string.
	path := tableMeta.WildcardPath
	if options.ResourceParameters != "" {
		u, err := url.Parse(path)
		if err == nil {
			if u.RawQuery != "" {
				u.RawQuery += "&" + options.ResourceParameters
			} else {
				u.RawQuery = options.ResourceParameters
			}
			path = u.String()
		}
	}

	sb.WriteString(" FROM '")
	sb.WriteString(path)
	sb.WriteString("'")

	if options.Format != "" {
		sb.WriteString(" FORMAT '")
		sb.WriteString(options.Format)
		sb.WriteString("'")
	}

	opts, err := g.buildOptions(options)
	if err != nil {
		return "", err
	}
	if len(opts) > 0 {
		sb.WriteString(" WITH ")
		sb.WriteString(strings.Join(opts, ", "))
	}

	return sb.String(), nil
}

func (g *sqlGenerator) buildOptions(options *ImportOptions) ([]string, error) {
	var opts []string
	if options.Thread > 0 {
		opts = append(opts, fmt.Sprintf("THREAD=%d", options.Thread))
	}
	if options.DiskQuota != "" {
		opts = append(opts, fmt.Sprintf("DISK_QUOTA='%s'", options.DiskQuota))
	}
	if options.MaxWriteSpeed != "" {
		opts = append(opts, fmt.Sprintf("MAX_WRITE_SPEED='%s'", options.MaxWriteSpeed))
	}
	if options.SplitFile {
		opts = append(opts, "SPLIT_FILE")
	}
	if options.RecordErrors > 0 {
		opts = append(opts, fmt.Sprintf("RECORD_ERRORS=%d", options.RecordErrors))
	}
	if options.Detached {
		opts = append(opts, "DETACHED")
	}
	if options.CloudStorageURI != "" {
		opts = append(opts, fmt.Sprintf("CLOUD_STORAGE_URI='%s'", options.CloudStorageURI))
	}
	if options.GroupKey != "" {
		opts = append(opts, fmt.Sprintf("GROUP_KEY='%s'", escapeString(options.GroupKey)))
	}
	if options.SkipRows > 0 {
		opts = append(opts, fmt.Sprintf("SKIP_ROWS=%d", options.SkipRows))
	}
	if options.CharacterSet != "" {
		opts = append(opts, fmt.Sprintf("CHARACTER_SET='%s'", escapeString(options.CharacterSet)))
	}
	if options.ChecksumTable != "" {
		opts = append(opts, fmt.Sprintf("CHECKSUM_TABLE='%s'", escapeString(options.ChecksumTable)))
	}
	if options.DisableTiKVImportMode {
		opts = append(opts, "DISABLE_TIKV_IMPORT_MODE")
	}
	if options.DisablePrecheck {
		opts = append(opts, "DISABLE_PRECHECK")
	}

	if options.CSVConfig != nil && options.Format == "csv" {
		csvOpts, err := g.buildCSVOptions(options.CSVConfig)
		if err != nil {
			return nil, err
		}
		opts = append(opts, csvOpts...)
	}
	return opts, nil
}

func (g *sqlGenerator) buildCSVOptions(csvConfig *config.CSVConfig) ([]string, error) {
	var opts []string
	if csvConfig.FieldsTerminatedBy != "" {
		opts = append(opts, fmt.Sprintf("FIELDS_TERMINATED_BY='%s'", escapeString(csvConfig.FieldsTerminatedBy)))
	}
	if csvConfig.FieldsEnclosedBy != "" {
		opts = append(opts, fmt.Sprintf("FIELDS_ENCLOSED_BY='%s'", escapeString(csvConfig.FieldsEnclosedBy)))
	}
	if csvConfig.FieldsEscapedBy != "" {
		opts = append(opts, fmt.Sprintf("FIELDS_ESCAPED_BY='%s'", escapeString(csvConfig.FieldsEscapedBy)))
	}
	if csvConfig.LinesTerminatedBy != "" {
		opts = append(opts, fmt.Sprintf("LINES_TERMINATED_BY='%s'", escapeString(csvConfig.LinesTerminatedBy)))
	}
	if len(csvConfig.FieldNullDefinedBy) > 0 {
		if len(csvConfig.FieldNullDefinedBy) > 1 {
			return nil, ErrMultipleFieldsDefinedNullBy
		}
		opts = append(opts, fmt.Sprintf("FIELDS_DEFINED_NULL_BY='%s'", escapeString(csvConfig.FieldNullDefinedBy[0])))
	}
	return opts, nil
}

// escapeIdentifier wraps s in backticks, doubling any embedded backticks.
func escapeIdentifier(s string) string {
	return "`" + strings.ReplaceAll(s, "`", "``") + "`"
}

// escapeString doubles backslashes and single quotes for use inside SQL string literals.
func escapeString(s string) string {
	return strings.ReplaceAll(strings.ReplaceAll(s, "\\", "\\\\"), "'", "''")
}
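
A minimal usage sketch of the generator defined in this file; the table and option values are illustrative only:

gen := NewSQLGenerator()
sql, err := gen.GenerateImportSQL(
	&TableMeta{Database: "db", Table: "t", WildcardPath: "s3://bucket/*.csv"},
	&ImportOptions{Format: "csv", Thread: 4, Detached: true},
)
// err is nil and sql is:
//   IMPORT INTO `db`.`t` FROM 's3://bucket/*.csv' FORMAT 'csv' WITH THREAD=4, DETACHED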

173
pkg/importsdk/sql_generator_test.go
Normal file
@ -0,0 +1,173 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package importsdk

import (
	"testing"

	"github.com/pingcap/tidb/pkg/lightning/config"
	"github.com/stretchr/testify/require"
)

func TestGenerateImportSQL(t *testing.T) {
	gen := NewSQLGenerator()

	defaultTableMeta := &TableMeta{
		Database:     "test_db",
		Table:        "test_table",
		WildcardPath: "s3://bucket/path/*.csv",
	}

	tests := []struct {
		name          string
		tableMeta     *TableMeta
		options       *ImportOptions
		expected      string
		expectedError string
	}{
		{
			name: "Basic",
			options: &ImportOptions{
				Format: "csv",
			},
			expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv' FORMAT 'csv'",
		},
		{
			name: "With S3 Credentials",
			options: &ImportOptions{
				Format:             "csv",
				ResourceParameters: "access-key=ak&endpoint=http%3A%2F%2Fminio%3A9000&secret-access-key=sk",
			},
			expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv?access-key=ak&endpoint=http%3A%2F%2Fminio%3A9000&secret-access-key=sk' FORMAT 'csv'",
		},
		{
			name: "With Options",
			options: &ImportOptions{
				Format:        "csv",
				Thread:        4,
				Detached:      true,
				MaxWriteSpeed: "100MiB",
			},
			expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv' FORMAT 'csv' WITH THREAD=4, MAX_WRITE_SPEED='100MiB', DETACHED",
		},
		{
			name: "All Options",
			options: &ImportOptions{
				Format:                "csv",
				Thread:                8,
				DiskQuota:             "100GiB",
				MaxWriteSpeed:         "200MiB",
				SplitFile:             true,
				RecordErrors:          100,
				Detached:              true,
				CloudStorageURI:       "s3://bucket/storage",
				GroupKey:              "group1",
				SkipRows:              1,
				CharacterSet:          "utf8mb4",
				ChecksumTable:         "test_db.checksum_table",
				DisableTiKVImportMode: true,
				DisablePrecheck:       true,
			},
			expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv' FORMAT 'csv' WITH THREAD=8, DISK_QUOTA='100GiB', MAX_WRITE_SPEED='200MiB', SPLIT_FILE, RECORD_ERRORS=100, DETACHED, CLOUD_STORAGE_URI='s3://bucket/storage', GROUP_KEY='group1', SKIP_ROWS=1, CHARACTER_SET='utf8mb4', CHECKSUM_TABLE='test_db.checksum_table', DISABLE_TIKV_IMPORT_MODE, DISABLE_PRECHECK",
		},
		{
			name: "With CSV Config",
			options: &ImportOptions{
				Format: "csv",
				CSVConfig: &config.CSVConfig{
					FieldsTerminatedBy: ",",
					FieldsEnclosedBy:   "\"",
					FieldsEscapedBy:    "\\",
					LinesTerminatedBy:  "\n",
					FieldNullDefinedBy: []string{"NULL"},
				},
			},
			expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv' FORMAT 'csv' WITH FIELDS_TERMINATED_BY=',', FIELDS_ENCLOSED_BY='\"', FIELDS_ESCAPED_BY='\\\\', LINES_TERMINATED_BY='\n', FIELDS_DEFINED_NULL_BY='NULL'",
		},
		{
			name: "With Cloud Storage URI",
			options: &ImportOptions{
				Format:          "parquet",
				CloudStorageURI: "s3://bucket/storage",
			},
			expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv' FORMAT 'parquet' WITH CLOUD_STORAGE_URI='s3://bucket/storage'",
		},
		{
			name: "Multiple Null Defined By",
			options: &ImportOptions{
				Format: "csv",
				CSVConfig: &config.CSVConfig{
					FieldNullDefinedBy: []string{"NULL", "\\N"},
				},
			},
			expectedError: "IMPORT INTO only supports one FIELDS_DEFINED_NULL_BY value",
		},
		{
			name: "Resource Parameters Append",
			tableMeta: &TableMeta{
				Database:     "test_db",
				Table:        "test_table",
				WildcardPath: "s3://bucket/path/*.csv?foo=bar",
			},
			options: &ImportOptions{
				Format:             "csv",
				ResourceParameters: "access-key=ak",
			},
			expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv?foo=bar&access-key=ak' FORMAT 'csv'",
		},
		{
			name: "Escaping",
			options: &ImportOptions{
				Format:       "csv",
				GroupKey:     "group'1",
				CharacterSet: "utf8'mb4",
				CSVConfig: &config.CSVConfig{
					FieldsTerminatedBy: "'",
					FieldsEnclosedBy:   "\\",
				},
			},
			expected: "IMPORT INTO `test_db`.`test_table` FROM 's3://bucket/path/*.csv' FORMAT 'csv' WITH GROUP_KEY='group''1', CHARACTER_SET='utf8''mb4', FIELDS_TERMINATED_BY='''', FIELDS_ENCLOSED_BY='\\\\'",
		},
		{
			name: "Identifier Escaping",
			tableMeta: &TableMeta{
				Database:     "test`db",
				Table:        "test`table",
				WildcardPath: "s3://bucket/path/*.csv",
			},
			options: &ImportOptions{
				Format: "csv",
			},
			expected: "IMPORT INTO `test``db`.`test``table` FROM 's3://bucket/path/*.csv' FORMAT 'csv'",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tm := tt.tableMeta
			if tm == nil {
				tm = defaultTableMeta
			}
			sql, err := gen.GenerateImportSQL(tm, tt.options)
			if tt.expectedError != "" {
				require.Error(t, err)
				require.Contains(t, err.Error(), tt.expectedError)
			} else {
				require.NoError(t, err)
				require.Equal(t, tt.expected, sql)
			}
		})
	}
}
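
For reference, the two escaping helpers behave as follows (values taken from the "Escaping" cases above; illustrative comments, not additional tests):

escapeIdentifier("test`db") // => `test``db`  (wrap in backticks, double embedded backticks)
escapeString("group'1")     // => group''1    (double single quotes)
escapeString(`a\b`)         // => a\\b        (double backslashes first)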

@ -4,6 +4,7 @@ go_library(
    name = "testkit",
    srcs = [
        "asynctestkit.go",
        "db_driver.go",
        "dbtestkit.go",
        "mocksessionmanager.go",
        "mockstore.go",
@ -71,8 +72,12 @@ go_library(
go_test(
    name = "testkit_test",
    timeout = "short",
    srcs = ["testkit_test.go"],
    srcs = [
        "db_driver_test.go",
        "testkit_test.go",
    ],
    embed = [":testkit"],
    flaky = True,
    shard_count = 2,
    deps = ["@com_github_stretchr_testify//require"],
)

308
pkg/testkit/db_driver.go
Normal file
@ -0,0 +1,308 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package testkit

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"fmt"
	"io"
	"strconv"
	"sync"
	"sync/atomic"

	"github.com/pingcap/errors"
	"github.com/pingcap/tidb/pkg/expression"
	"github.com/pingcap/tidb/pkg/parser/mysql"
	"github.com/pingcap/tidb/pkg/session/sessionapi"
	"github.com/pingcap/tidb/pkg/types"
	"github.com/pingcap/tidb/pkg/util/chunk"
	"github.com/pingcap/tidb/pkg/util/sqlexec"
)

var (
	tkMapMu sync.RWMutex
	tkMap   = make(map[string]*TestKit)
	tkIDSeq int64
)

func init() {
	sql.Register("testkit", &testKitDriver{})
}

// CreateMockDB creates a *sql.DB that uses the TestKit's store to create sessions.
func CreateMockDB(tk *TestKit) *sql.DB {
	id := strconv.FormatInt(atomic.AddInt64(&tkIDSeq, 1), 10)
	tkMapMu.Lock()
	tkMap[id] = tk
	tkMapMu.Unlock()

	db, err := sql.Open("testkit", id)
	if err != nil {
		panic(err)
	}
	return db
}

type testKitDriver struct{}

// Open creates a fresh session from the registered TestKit's store for each connection.
func (d *testKitDriver) Open(name string) (driver.Conn, error) {
	tkMapMu.RLock()
	tk, ok := tkMap[name]
	tkMapMu.RUnlock()
	if !ok {
		return nil, fmt.Errorf("testkit not found for %s", name)
	}
	se := NewSession(tk.t, tk.store)
	return &testKitConn{se: se}, nil
}

type testKitConn struct {
	se sessionapi.Session
}

func (c *testKitConn) Prepare(query string) (driver.Stmt, error) {
	return &testKitStmt{c: c, query: query}, nil
}

func (c *testKitConn) Close() error {
	c.se.Close()
	return nil
}

func (c *testKitConn) Begin() (driver.Tx, error) {
	// Run BEGIN through ExecContext; Exec always returns driver.ErrSkip.
	if _, err := c.ExecContext(context.Background(), "BEGIN", nil); err != nil {
		return nil, err
	}
	return &testKitTxn{c: c}, nil
}

// Exec defers to ExecContext by returning driver.ErrSkip.
func (c *testKitConn) Exec(query string, args []driver.Value) (driver.Result, error) {
	return nil, driver.ErrSkip
}

func (c *testKitConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
	qArgs := make([]any, len(args))
	for i, a := range args {
		qArgs[i] = a.Value
	}

	rs, err := c.execute(ctx, query, qArgs)
	if err != nil {
		return nil, err
	}
	if rs != nil {
		if err := rs.Close(); err != nil {
			return nil, err
		}
	}
	return &testKitResult{
		lastInsertID: int64(c.se.LastInsertID()),
		rowsAffected: int64(c.se.AffectedRows()),
	}, nil
}

func (c *testKitConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
	qArgs := make([]any, len(args))
	for i, a := range args {
		qArgs[i] = a.Value
	}

	rs, err := c.execute(ctx, query, qArgs)
	if err != nil {
		return nil, err
	}
	if rs == nil {
		return nil, nil
	}
	return &testKitRows{rs: rs}, nil
}

func (c *testKitConn) execute(ctx context.Context, sql string, args []any) (sqlexec.RecordSet, error) {
	// Set the command value to ComQuery so that the process info is updated correctly.
	c.se.SetCommandValue(mysql.ComQuery)
	defer c.se.SetCommandValue(mysql.ComSleep)

	if len(args) == 0 {
		rss, err := c.se.Execute(ctx, sql)
		if err != nil {
			return nil, err
		}
		if len(rss) == 0 {
			return nil, nil
		}
		return rss[0], nil
	}

	stmtID, _, _, err := c.se.PrepareStmt(sql)
	if err != nil {
		return nil, errors.Trace(err)
	}
	params := expression.Args2Expressions4Test(args...)
	rs, err := c.se.ExecutePreparedStmt(ctx, stmtID, params)
	if err != nil {
		return rs, errors.Trace(err)
	}
	err = c.se.DropPreparedStmt(stmtID)
	if err != nil {
		return rs, errors.Trace(err)
	}
	return rs, nil
}

type testKitStmt struct {
	c     *testKitConn
	query string
}

func (s *testKitStmt) Close() error {
	return nil
}

func (s *testKitStmt) NumInput() int {
	return -1
}

// Exec defers to ExecContext by returning driver.ErrSkip.
func (s *testKitStmt) Exec(args []driver.Value) (driver.Result, error) {
	return nil, driver.ErrSkip
}

func (s *testKitStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
	return s.c.ExecContext(ctx, s.query, args)
}

// Query defers to QueryContext by returning driver.ErrSkip.
func (s *testKitStmt) Query(args []driver.Value) (driver.Rows, error) {
	return nil, driver.ErrSkip
}

func (s *testKitStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	return s.c.QueryContext(ctx, s.query, args)
}

type testKitResult struct {
	lastInsertID int64
	rowsAffected int64
}

func (r *testKitResult) LastInsertId() (int64, error) {
	return r.lastInsertID, nil
}

func (r *testKitResult) RowsAffected() (int64, error) {
	return r.rowsAffected, nil
}

type testKitRows struct {
	rs    sqlexec.RecordSet
	chunk *chunk.Chunk
	it    *chunk.Iterator4Chunk
}

func (r *testKitRows) Columns() []string {
	fields := r.rs.Fields()
	cols := make([]string, len(fields))
	for i, f := range fields {
		cols[i] = f.Column.Name.O
	}
	return cols
}

func (r *testKitRows) Close() error {
	return r.rs.Close()
}

func (r *testKitRows) Next(dest []driver.Value) error {
	if r.chunk == nil {
		r.chunk = r.rs.NewChunk(nil)
	}

	var row chunk.Row
	if r.it == nil {
		err := r.rs.Next(context.Background(), r.chunk)
		if err != nil {
			return err
		}
		if r.chunk.NumRows() == 0 {
			return io.EOF
		}
		r.it = chunk.NewIterator4Chunk(r.chunk)
		row = r.it.Begin()
	} else {
		row = r.it.Next()
		if row.IsEmpty() {
			// The current chunk is drained; fetch the next one.
			err := r.rs.Next(context.Background(), r.chunk)
			if err != nil {
				return err
			}
			if r.chunk.NumRows() == 0 {
				return io.EOF
			}
			r.it = chunk.NewIterator4Chunk(r.chunk)
			row = r.it.Begin()
		}
	}

	for i := range row.Len() {
		d := row.GetDatum(i, &r.rs.Fields()[i].Column.FieldType)
		// Handle NULL
		if d.IsNull() {
			dest[i] = nil
			continue
		}
		// driver.Value allows int64, float64, bool, []byte, string, and time.Time;
		// pass primitives through and convert TiDB-specific types to strings.
		v := d.GetValue()
		switch x := v.(type) {
		case []byte:
			dest[i] = x
		case string:
			dest[i] = x
		case int64:
			dest[i] = x
		case uint64:
			dest[i] = x
		case float64:
			dest[i] = x
		case float32:
			dest[i] = x
		case types.Time:
			dest[i] = x.String()
		case types.Duration:
			dest[i] = x.String()
		case *types.MyDecimal:
			dest[i] = x.String()
		default:
			dest[i] = x
		}
	}
	return nil
}

type testKitTxn struct {
	c *testKitConn
}

func (t *testKitTxn) Commit() error {
	_, err := t.c.ExecContext(context.Background(), "COMMIT", nil)
	return err
}

func (t *testKitTxn) Rollback() error {
	_, err := t.c.ExecContext(context.Background(), "ROLLBACK", nil)
	return err
}
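
A minimal sketch of how this driver is meant to be used; it mirrors db_driver_test.go below. Each sql.Conn opened through the returned *sql.DB is backed by a fresh session created from the TestKit's store:

store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
db := testkit.CreateMockDB(tk)
defer db.Close()

var n int
err := db.QueryRow("select 1").Scan(&n) // runs through the "testkit" driver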

79
pkg/testkit/db_driver_test.go
Normal file
@ -0,0 +1,79 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package testkit_test

import (
	"testing"

	"github.com/pingcap/tidb/pkg/testkit"
	"github.com/stretchr/testify/require"
)

func TestMockDB(t *testing.T) {
	store := testkit.CreateMockStore(t)

	tk := testkit.NewTestKit(t, store)

	db := testkit.CreateMockDB(tk)
	defer db.Close()

	var err error
	_, err = db.Exec("use test")
	require.NoError(t, err)
	_, err = db.Exec("create table t (id int, v varchar(255))")
	require.NoError(t, err)
	_, err = db.Exec("insert into t values (1, 'a'), (2, 'b')")
	require.NoError(t, err)

	// Test Query
	rows, err := db.Query("select * from t order by id")
	require.NoError(t, err)
	defer rows.Close()

	var id int
	var v string
	require.True(t, rows.Next())
	require.NoError(t, rows.Scan(&id, &v))
	require.Equal(t, 1, id)
	require.Equal(t, "a", v)

	require.True(t, rows.Next())
	require.NoError(t, rows.Scan(&id, &v))
	require.Equal(t, 2, id)
	require.Equal(t, "b", v)

	require.False(t, rows.Next())
	require.NoError(t, rows.Err())

	// Test Exec
	res, err := db.Exec("insert into t values (3, 'c')")
	require.NoError(t, err)
	affected, err := res.RowsAffected()
	require.NoError(t, err)
	require.Equal(t, int64(1), affected)

	// Verify with TestKit
	tk.MustExec("use test")
	tk.MustQuery("select * from t where id = 3").Check(testkit.Rows("3 c"))

	// Test Prepare
	stmt, err := db.Prepare("select v from t where id = ?")
	require.NoError(t, err)
	defer stmt.Close()

	row := stmt.QueryRow(2)
	require.NoError(t, row.Scan(&v))
	require.Equal(t, "b", v)
}

@ -11,6 +11,7 @@ go_test(
        "one_parquet_test.go",
        "parquet_test.go",
        "precheck_test.go",
        "sdk_test.go",
        "util_test.go",
    ],
    embedsrcs = [
@ -35,6 +36,7 @@ go_test(
        "//pkg/domain",
        "//pkg/domain/infosync",
        "//pkg/executor/importer",
        "//pkg/importsdk",
        "//pkg/infoschema",
        "//pkg/kv",
        "//pkg/lightning/backend/local",

180
tests/realtikvtest/importintotest/sdk_test.go
Normal file
@ -0,0 +1,180 @@
// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package importintotest

import (
	"context"
	"os"
	"path/filepath"
	"time"

	"github.com/pingcap/tidb/pkg/importsdk"
	"github.com/pingcap/tidb/pkg/testkit"
	"github.com/pingcap/tidb/pkg/testkit/testfailpoint"
)

func (s *mockGCSSuite) TestImportSDK() {
	ctx := context.Background()

	// 1. Prepare source data (local files)
	tmpDir := s.T().TempDir()

	// Create schema files for CreateSchemasAndTables
	// DB schema
	err := os.WriteFile(filepath.Join(tmpDir, "importsdk_test-schema-create.sql"), []byte("CREATE DATABASE importsdk_test;"), 0644)
	s.NoError(err)
	// Table schema
	err = os.WriteFile(filepath.Join(tmpDir, "importsdk_test.t-schema.sql"), []byte("CREATE TABLE t (id int, v varchar(255));"), 0644)
	s.NoError(err)

	// Data file in mydumper format: {schema}.{table}.{seq}.csv
	fileName := "importsdk_test.t.001.csv"
	content := []byte("1,test1\n2,test2")
	err = os.WriteFile(filepath.Join(tmpDir, fileName), content, 0644)
	s.NoError(err)

	// 2. Prepare DB connection
	// Ensure clean state
	s.tk.MustExec("DROP DATABASE IF EXISTS importsdk_test")

	// Create a *sql.DB using the testkit driver
	db := testkit.CreateMockDB(s.tk)
	defer db.Close()

	// 3. Initialize SDK with a local file path
	sdk, err := importsdk.NewImportSDK(ctx, "file://"+tmpDir, db)
	s.NoError(err)
	defer sdk.Close()

	// 4. Test FileScanner.CreateSchemasAndTables
	err = sdk.CreateSchemasAndTables(ctx)
	s.NoError(err)
	// Verify DB and table exist
	s.tk.MustExec("USE importsdk_test")
	s.tk.MustExec("SHOW CREATE TABLE t")

	// 4.1 Test FileScanner.CreateSchemaAndTableByName
	s.tk.MustExec("DROP TABLE t")
	err = sdk.CreateSchemaAndTableByName(ctx, "importsdk_test", "t")
	s.NoError(err)
	s.tk.MustExec("SHOW CREATE TABLE t")

	// 5. Test FileScanner.GetTableMetas
	metas, err := sdk.GetTableMetas(ctx)
	s.NoError(err)
	s.Len(metas, 1)
	s.Equal("importsdk_test", metas[0].Database)
	s.Equal("t", metas[0].Table)
	s.Equal(int64(len(content)), metas[0].TotalSize)

	// 6. Test FileScanner.GetTableMetaByName
	meta, err := sdk.GetTableMetaByName(ctx, "importsdk_test", "t")
	s.NoError(err)
	s.Equal("importsdk_test", meta.Database)
	s.Equal("t", meta.Table)

	// 7. Test FileScanner.GetTotalSize
	totalSize := sdk.GetTotalSize(ctx)
	s.Equal(int64(len(content)), totalSize)

	// 8. Test SQLGenerator.GenerateImportSQL
	opts := &importsdk.ImportOptions{
		Thread:   4,
		Detached: true,
	}
	sql, err := sdk.GenerateImportSQL(meta, opts)
	s.NoError(err)
	s.Contains(sql, "IMPORT INTO `importsdk_test`.`t` FROM")
	s.Contains(sql, "THREAD=4")
	s.Contains(sql, "DETACHED")

	// 9. Test JobManager.SubmitJob
	jobID, err := sdk.SubmitJob(ctx, sql)
	s.NoError(err)
	s.Greater(jobID, int64(0))

	// 10. Test JobManager.GetJobStatus
	status, err := sdk.GetJobStatus(ctx, jobID)
	s.NoError(err)
	s.Equal(jobID, status.JobID)
	s.Equal("`importsdk_test`.`t`", status.TargetTable)

	// Wait for the job to finish
	s.Eventually(func() bool {
		status, err := sdk.GetJobStatus(ctx, jobID)
		s.NoError(err)
		return status.Status == "finished"
	}, 30*time.Second, 500*time.Millisecond)

	// Verify data
	s.tk.MustQuery("SELECT * FROM importsdk_test.t").Check(testkit.Rows("1 test1", "2 test2"))

	// 11. Test JobManager.CancelJob with a failpoint
	s.tk.MustExec("TRUNCATE TABLE t")
	testfailpoint.Enable(s.T(), "github.com/pingcap/tidb/pkg/disttask/importinto/syncBeforeJobStarted", "pause")

	jobID2, err := sdk.SubmitJob(ctx, sql)
	s.NoError(err)

	// Cancel the job
	err = sdk.CancelJob(ctx, jobID2)
	s.NoError(err)

	// Unpause
	testfailpoint.Disable(s.T(), "github.com/pingcap/tidb/pkg/disttask/importinto/syncBeforeJobStarted")

	// Verify the status is cancelled
	s.Eventually(func() bool {
		status, err := sdk.GetJobStatus(ctx, jobID2)
		s.NoError(err)
		return status.Status == "cancelled"
	}, 30*time.Second, 500*time.Millisecond)

	// 12. Test JobManager.GetGroupSummary
	// Submit a job with a GroupKey
	groupKey := "test_group_key"
	optsWithGroup := &importsdk.ImportOptions{
		Thread:   4,
		Detached: true,
		GroupKey: groupKey,
	}
	sqlWithGroup, err := sdk.GenerateImportSQL(meta, optsWithGroup)
	s.NoError(err)
	jobID3, err := sdk.SubmitJob(ctx, sqlWithGroup)
	s.NoError(err)
	s.Greater(jobID3, int64(0))

	// Wait for the job to finish
	s.Eventually(func() bool {
		status, err := sdk.GetJobStatus(ctx, jobID3)
		s.NoError(err)
		return status.Status == "finished"
	}, 30*time.Second, 500*time.Millisecond)

	// Get the group summary
	groupSummary, err := sdk.GetGroupSummary(ctx, groupKey)
	s.NoError(err)
	s.Equal(groupKey, groupSummary.GroupKey)
	s.Equal(int64(1), groupSummary.TotalJobs)
	s.Equal(int64(1), groupSummary.Completed)

	// 13. Test JobManager.GetJobsByGroup
	jobs, err := sdk.GetJobsByGroup(ctx, groupKey)
	s.NoError(err)
	s.Len(jobs, 1)
	s.Equal(jobID3, jobs[0].JobID)
	s.Equal("finished", jobs[0].Status)
}
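
Condensed, the end-to-end SDK flow exercised by this test looks roughly like the following (error handling elided; the source URI is illustrative):

sdk, _ := importsdk.NewImportSDK(ctx, "file:///path/to/dump", db)
defer sdk.Close()

_ = sdk.CreateSchemasAndTables(ctx)               // FileScanner: replay *-schema.sql files
meta, _ := sdk.GetTableMetaByName(ctx, "db", "t") // FileScanner: locate a table's data files
sql, _ := sdk.GenerateImportSQL(meta, &importsdk.ImportOptions{Detached: true}) // SQLGenerator
jobID, _ := sdk.SubmitJob(ctx, sql)               // JobManager: run IMPORT INTO
status, _ := sdk.GetJobStatus(ctx, jobID)         // JobManager: poll until finished
_ = status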