diff --git a/br/pkg/storage/azblob.go b/br/pkg/storage/azblob.go
index 41d8fa88f5..c3d734ebe9 100644
--- a/br/pkg/storage/azblob.go
+++ b/br/pkg/storage/azblob.go
@@ -348,13 +348,13 @@ func (s *AzureBlobStorage) WalkDir(ctx context.Context, opt *WalkOption, fn func
 	if opt == nil {
 		opt = &WalkOption{}
 	}
-	if len(opt.ObjPrefix) != 0 {
-		return errors.New("azure storage not support ObjPrefix for now")
-	}
 	prefix := path.Join(s.options.Prefix, opt.SubDir)
 	if len(prefix) > 0 && !strings.HasSuffix(prefix, "/") {
 		prefix += "/"
 	}
+	if len(opt.ObjPrefix) != 0 {
+		prefix += opt.ObjPrefix
+	}
 
 	listOption := &azblob.ContainerListBlobFlatSegmentOptions{Prefix: &prefix}
 	for {
diff --git a/br/pkg/storage/gcs.go b/br/pkg/storage/gcs.go
index 19d5f00810..db094cdc10 100644
--- a/br/pkg/storage/gcs.go
+++ b/br/pkg/storage/gcs.go
@@ -205,13 +205,14 @@ func (s *GCSStorage) WalkDir(ctx context.Context, opt *WalkOption, fn func(strin
 	if opt == nil {
 		opt = &WalkOption{}
 	}
-	if len(opt.ObjPrefix) != 0 {
-		return errors.New("gcs storage not support ObjPrefix for now")
-	}
 	prefix := path.Join(s.gcs.Prefix, opt.SubDir)
 	if len(prefix) > 0 && !strings.HasSuffix(prefix, "/") {
 		prefix += "/"
 	}
+	if len(opt.ObjPrefix) != 0 {
+		prefix += opt.ObjPrefix
+	}
+
 	query := &storage.Query{Prefix: prefix}
 	// only need each object's name and size
 	err := query.SetAttrSelection([]string{"Name", "Size"})
diff --git a/br/pkg/storage/storage.go b/br/pkg/storage/storage.go
index 279e4ff9fe..93fb56dcd6 100644
--- a/br/pkg/storage/storage.go
+++ b/br/pkg/storage/storage.go
@@ -33,8 +33,9 @@ const (
 type WalkOption struct {
 	// walk on SubDir of specify directory
 	SubDir string
-	// ObjPrefix used fo prefix search in storage.
-	// it can save lots of time when we want find specify prefix objects in storage.
+	// ObjPrefix is used for prefix search in storage. Note that only some
+	// storage types support it.
+	// It can save lots of time when we want to find objects with a specific prefix in storage.
 	// For example. we have 10000 .sst files and 10 backupmeta.(\d+) files.
 	// we can use ObjPrefix = "backupmeta" to retrieve all meta files quickly.
 	ObjPrefix string
diff --git a/executor/asyncloaddata/show_test.go b/executor/asyncloaddata/show_test.go
index 07b8363180..d77fd2566c 100644
--- a/executor/asyncloaddata/show_test.go
+++ b/executor/asyncloaddata/show_test.go
@@ -87,10 +87,17 @@ func (s *mockGCSSuite) TestInternalStatus() {
 	s.server.CreateObject(fakestorage.Object{
 		ObjectAttrs: fakestorage.ObjectAttrs{
 			BucketName: "test-tsv",
-			Name:       "t.tsv",
+			Name:       "t1.tsv",
 		},
-		Content: []byte(`1
-2`),
+		Content: []byte(`1`),
+	})
+
+	s.server.CreateObject(fakestorage.Object{
+		ObjectAttrs: fakestorage.ObjectAttrs{
+			BucketName: "test-tsv",
+			Name:       "t2.tsv",
+		},
+		Content: []byte(`2`),
 	})
 
 	ctx := context.Background()
@@ -118,7 +125,7 @@ func (s *mockGCSSuite) TestInternalStatus() {
 		expected := &JobInfo{
 			JobID:       id,
 			User:        "test-load@test-host",
-			DataSource:  fmt.Sprintf("gs://test-tsv/t.tsv?endpoint=%s", gcsEndpoint),
+			DataSource:  fmt.Sprintf("gs://test-tsv/t*.tsv?endpoint=%s", gcsEndpoint),
 			TableSchema: "load_tsv",
 			TableName:   "t",
 			ImportMode:  "logical",
@@ -141,7 +148,7 @@ func (s *mockGCSSuite) TestInternalStatus() {
 		// tk2 @ 0:08
 		info, err = GetJobInfo(ctx, tk2.Session(), id)
 		require.NoError(s.T(), err)
-		expected.Progress = `{"SourceFileSize":3,"LoadedFileSize":0,"LoadedRowCnt":1}`
+		expected.Progress = `{"SourceFileSize":2,"LoadedFileSize":0,"LoadedRowCnt":1}`
 		require.Equal(s.T(), expected, info)
 		// tk @ 0:09
 		// commit one task and sleep 3 seconds
@@ -149,7 +156,7 @@ func (s *mockGCSSuite) TestInternalStatus() {
 		// tk2 @ 0:11
 		info, err = GetJobInfo(ctx, tk2.Session(), id)
 		require.NoError(s.T(), err)
-		expected.Progress = `{"SourceFileSize":3,"LoadedFileSize":2,"LoadedRowCnt":2}`
+		expected.Progress = `{"SourceFileSize":2,"LoadedFileSize":1,"LoadedRowCnt":2}`
 		require.Equal(s.T(), expected, info)
 		// tk @ 0:12
 		// finish job
@@ -159,7 +166,7 @@ func (s *mockGCSSuite) TestInternalStatus() {
 		require.NoError(s.T(), err)
 		expected.Status = JobFinished
 		expected.StatusMessage = "Records: 2 Deleted: 0 Skipped: 0 Warnings: 0"
-		expected.Progress = `{"SourceFileSize":3,"LoadedFileSize":3,"LoadedRowCnt":2}`
+		expected.Progress = `{"SourceFileSize":2,"LoadedFileSize":2,"LoadedRowCnt":2}`
 		require.Equal(s.T(), expected, info)
 	}()
@@ -183,7 +190,7 @@ func (s *mockGCSSuite) TestInternalStatus() {
 	s.enableFailpoint("github.com/pingcap/tidb/executor/AfterStartJob", `sleep(3000)`)
 	s.enableFailpoint("github.com/pingcap/tidb/executor/AfterCommitOneTask", `sleep(3000)`)
 	s.tk.MustExec("SET SESSION tidb_dml_batch_size = 1;")
-	sql := fmt.Sprintf(`LOAD DATA INFILE 'gs://test-tsv/t.tsv?endpoint=%s'
+	sql := fmt.Sprintf(`LOAD DATA INFILE 'gs://test-tsv/t*.tsv?endpoint=%s'
 		INTO TABLE load_tsv.t;`, gcsEndpoint)
 	s.tk.MustExec(sql)
 	wg.Wait()
diff --git a/executor/load_data.go b/executor/load_data.go
index cf72281847..0b747f0cd0 100644
--- a/executor/load_data.go
+++ b/executor/load_data.go
@@ -50,6 +50,7 @@ import (
 	"github.com/pingcap/tidb/util/intest"
 	"github.com/pingcap/tidb/util/logutil"
 	"github.com/pingcap/tidb/util/sqlexec"
+	"github.com/pingcap/tidb/util/stringutil"
 	"go.uber.org/zap"
 	"golang.org/x/sync/errgroup"
 )
@@ -144,8 +145,8 @@ func (e *LoadDataExec) Next(ctx context.Context, req *chunk.Chunk) error {
 		if err != nil {
 			return ErrLoadDataInvalidURI.GenWithStackByArgs(err.Error())
 		}
-		var filename string
-		u.Path, filename = filepath.Split(u.Path)
+		path := strings.Trim(u.Path, "/")
+		u.Path = ""
 		b, err := storage.ParseBackendFromURL(u, nil)
 		if err != nil {
 			return ErrLoadDataInvalidURI.GenWithStackByArgs(getMsgFromBRError(err))
@@ -153,7 +154,12 @@ func (e *LoadDataExec) Next(ctx context.Context, req *chunk.Chunk) error {
 		if b.GetLocal() != nil {
 			return ErrLoadDataFromServerDisk.GenWithStackByArgs(e.loadDataWorker.Path)
 		}
-		return e.loadFromRemote(ctx, b, filename)
+		// try to find a pattern error in advance
+		_, err = filepath.Match(stringutil.EscapeGlobExceptAsterisk(path), "")
+		if err != nil {
+			return ErrLoadDataInvalidURI.GenWithStackByArgs("Glob pattern error: " + err.Error())
+		}
+		return e.loadFromRemote(ctx, b, path)
 	case ast.FileLocClient:
 		// let caller use handleQuerySpecial to read data in this connection
 		sctx := e.loadDataWorker.ctx
@@ -170,7 +176,7 @@ func (e *LoadDataExec) Next(ctx context.Context, req *chunk.Chunk) error {
 func (e *LoadDataExec) loadFromRemote(
 	ctx context.Context,
 	b *backup.StorageBackend,
-	filename string,
+	path string,
 ) error {
 	opt := &storage.ExternalStorageOptions{}
 	if intest.InTest {
@@ -180,17 +186,74 @@ func (e *LoadDataExec) loadFromRemote(
 	if err != nil {
 		return ErrLoadDataCantAccess
 	}
-	fileReader, err := s.Open(ctx, filename)
-	if err != nil {
-		return ErrLoadDataCantRead.GenWithStackByArgs(getMsgFromBRError(err), "Please check the INFILE path is correct")
-	}
-	defer fileReader.Close()
-	e.loadDataWorker.loadRemoteInfo = loadRemoteInfo{
-		store: s,
-		path:  filename,
+	idx := strings.IndexByte(path, '*')
+	// simple path when the INFILE represents one file
+	if idx == -1 {
+		opener := func(ctx context.Context) (io.ReadSeekCloser, error) {
+			fileReader, err2 := s.Open(ctx, path)
+			if err2 != nil {
+				return nil, ErrLoadDataCantRead.GenWithStackByArgs(getMsgFromBRError(err2), "Please check the INFILE path is correct")
+			}
+			return fileReader, nil
+		}
+
+		// try to read the file size to report progress. Don't fail the main load
+		// if this fails, to tolerate transient errors.
+		filesize := int64(-1)
+		reader, err2 := opener(ctx)
+		if err2 == nil {
+			size, err3 := reader.Seek(0, io.SeekEnd)
+			if err3 != nil {
+				logutil.Logger(ctx).Warn("failed to read file size by seek in LOAD DATA",
+					zap.Error(err3))
+			} else {
+				filesize = size
+			}
+			terror.Log(reader.Close())
+		}
+
+		return e.loadDataWorker.Load(ctx, []LoadDataReaderInfo{{
+			Opener: opener,
+			Remote: &loadRemoteInfo{store: s, path: path, size: filesize},
+		}})
 	}
-	return e.loadDataWorker.Load(ctx, fileReader)
+
+	// when the INFILE represents multiple files
+	readerInfos := make([]LoadDataReaderInfo, 0, 8)
+	commonPrefix := path[:idx]
+	// we only support '*'; to reuse the glob library, manually escape the path
+	escapedPath := stringutil.EscapeGlobExceptAsterisk(path)
+
+	err = s.WalkDir(ctx, &storage.WalkOption{ObjPrefix: commonPrefix},
+		func(remotePath string, size int64) error {
+			// pattern errors are already checked in LoadDataExec.Next
+			//nolint: errcheck
+			match, _ := filepath.Match(escapedPath, remotePath)
+			if !match {
+				return nil
+			}
+			readerInfos = append(readerInfos, LoadDataReaderInfo{
+				Opener: func(ctx context.Context) (io.ReadSeekCloser, error) {
+					fileReader, err2 := s.Open(ctx, remotePath)
+					if err2 != nil {
+						return nil, ErrLoadDataCantRead.GenWithStackByArgs(getMsgFromBRError(err2), "Please check the INFILE path is correct")
+					}
+					return fileReader, nil
+				},
+				Remote: &loadRemoteInfo{
+					store: s,
+					path:  remotePath,
+					size:  size,
+				},
+			})
+			return nil
+		})
+	if err != nil {
+		return err
+	}
+
+	return e.loadDataWorker.Load(ctx, readerInfos)
 }
 
 // Close implements the Executor Close interface.
@@ -223,6 +286,7 @@ type commitTask struct {
 type loadRemoteInfo struct {
 	store storage.ExternalStorage
 	path  string
+	size  int64
 }
 
 // LoadDataWorker does a LOAD DATA job.
@@ -266,7 +330,7 @@ type LoadDataWorker struct {
 	row             []types.Datum
 	rows            [][]types.Datum
 	commitTaskQueue chan commitTask
-	loadRemoteInfo  loadRemoteInfo
+	finishedSize    int64
 	progress        atomic.Pointer[asyncloaddata.Progress]
 	getSysSessionFn func() (sessionctx.Context, error)
 	putSysSessionFn func(context.Context, sessionctx.Context)
@@ -623,8 +687,20 @@ func (e *LoadDataWorker) reorderColumns(columnNames []string) error {
 // LoadDataReadBlockSize is exposed for test.
 var LoadDataReadBlockSize = int64(config.ReadBlockSize)
 
+// LoadDataReaderInfo provides information for a data reader of LOAD DATA.
+type LoadDataReaderInfo struct {
+	// Opener can be called when needed to get an io.ReadSeekCloser. It will
+	// only be called once.
+	Opener func(ctx context.Context) (io.ReadSeekCloser, error)
+	// Remote is not nil only when loading from cloud storage.
+	Remote *loadRemoteInfo
+}
+
 // Load reads from readerFn and do load data job.
-func (e *LoadDataWorker) Load(ctx context.Context, reader io.ReadSeekCloser) error {
+func (e *LoadDataWorker) Load(
+	ctx context.Context,
+	readerInfos []LoadDataReaderInfo,
+) error {
 	var (
 		jobID  int64
 		parser mydump.Parser
@@ -673,65 +749,24 @@ func (e *LoadDataWorker) Load(ctx context.Context, reader io.ReadSeekCloser) err
 	failpoint.Inject("AfterCreateLoadDataJob", nil)
 
-	switch strings.ToLower(e.format) {
-	case "":
-		// CSV-like
-		parser, err = mydump.NewCSVParser(
-			ctx,
-			e.GenerateCSVConfig(),
-			reader,
-			LoadDataReadBlockSize,
-			nil,
-			false,
-			// TODO: support charset conversion
-			nil)
-	case LoadDataFormatSQLDump:
-		parser = mydump.NewChunkParser(
-			ctx,
-			e.Ctx.GetSessionVars().SQLMode,
-			reader,
-			LoadDataReadBlockSize,
-			nil,
-		)
-	case LoadDataFormatParquet:
-		if e.loadRemoteInfo.store == nil {
-			return ErrLoadParquetFromLocal
-		}
-		parser, err = mydump.NewParquetParser(
-			ctx,
-			e.loadRemoteInfo.store,
-			reader,
-			e.loadRemoteInfo.path,
-		)
-	default:
-		return ErrLoadDataUnsupportedFormat.GenWithStackByArgs(e.format)
-	}
-	if err != nil {
-		return ErrLoadDataWrongFormatConfig.GenWithStack(err.Error())
-	}
-	parser.SetLogger(log.Logger{Logger: logutil.Logger(ctx)})
-
 	progress := asyncloaddata.Progress{
 		SourceFileSize: -1,
 		LoadedFileSize: 0,
 		LoadedRowCnt:   0,
 	}
-	if e.loadRemoteInfo.store != nil {
-		reader2, err2 := e.loadRemoteInfo.store.Open(ctx, e.loadRemoteInfo.path)
-		if err2 != nil {
-			logutil.Logger(ctx).Warn("open file failed, can not know file size", zap.Error(err2))
-		} else {
-			//nolint: errcheck
-			defer reader2.Close()
-
-			filesize, err3 := reader2.Seek(0, io.SeekEnd)
-			if err3 != nil {
-				logutil.Logger(ctx).Warn("seek file end failed, can not know file size", zap.Error(err2))
-			} else {
-				progress.SourceFileSize = filesize
-			}
+	totalFilesize := int64(0)
+	hasErr := false
+	for _, readerInfo := range readerInfos {
+		if readerInfo.Remote == nil {
+			logutil.Logger(ctx).Warn("can not get total file size when LOAD DATA from local file")
+			hasErr = true
+			break
 		}
+		totalFilesize += readerInfo.Remote.size
+	}
+	if !hasErr {
+		progress.SourceFileSize = totalFilesize
 	}
 	e.progress.Store(&progress)
@@ -747,6 +782,37 @@ func (e *LoadDataWorker) Load(ctx context.Context, reader io.ReadSeekCloser) err
 	// goroutine that the job is finished.
 	done := make(chan struct{})
 
+	// processStream goroutine.
+	group.Go(func() error {
+		for _, info := range readerInfos {
+			reader, err2 := info.Opener(ctx)
+			if err2 != nil {
+				return err2
+			}
+
+			parser, err2 = e.buildParser(ctx, reader, info.Remote)
+			if err2 != nil {
+				terror.Log(reader.Close())
+				return err2
+			}
+			err2 = e.processStream(groupCtx, parser, reader)
+			terror.Log(reader.Close())
+			if err2 != nil {
+				return err2
+			}
+		}
+
+		close(e.commitTaskQueue)
+		return nil
+	})
+	// commitWork goroutine.
+	group.Go(func() error {
+		err2 := e.commitWork(groupCtx)
+		if err2 == nil {
+			close(done)
+		}
+		return err2
+	})
 	// UpdateJobProgress goroutine.
 	group.Go(func() error {
 		ticker := time.NewTicker(time.Duration(asyncloaddata.HeartBeatInSec) * time.Second)
@@ -777,24 +843,79 @@ func (e *LoadDataWorker) Load(ctx context.Context, reader io.ReadSeekCloser) err
 			}
 		}
 	})
-	// processStream goroutine.
-	group.Go(func() error {
-		return e.processStream(groupCtx, parser, reader)
-	})
-	// commitWork goroutine.
-	group.Go(func() error {
-		err2 := e.commitWork(groupCtx)
-		if err2 == nil {
-			close(done)
-		}
-		return err2
-	})
 	err = group.Wait()
 	e.SetMessage()
 	return err
 }
 
+func (e *LoadDataWorker) buildParser(
+	ctx context.Context,
+	reader io.ReadSeekCloser,
+	remote *loadRemoteInfo,
+) (parser mydump.Parser, err error) {
+	switch strings.ToLower(e.format) {
+	case "":
+		// CSV-like
+		parser, err = mydump.NewCSVParser(
+			ctx,
+			e.GenerateCSVConfig(),
+			reader,
+			LoadDataReadBlockSize,
+			nil,
+			false,
+			// TODO: support charset conversion
+			nil)
+	case LoadDataFormatSQLDump:
+		parser = mydump.NewChunkParser(
+			ctx,
+			e.Ctx.GetSessionVars().SQLMode,
+			reader,
+			LoadDataReadBlockSize,
+			nil,
+		)
+	case LoadDataFormatParquet:
+		if remote == nil {
+			return nil, ErrLoadParquetFromLocal
+		}
+		parser, err = mydump.NewParquetParser(
+			ctx,
+			remote.store,
+			reader,
+			remote.path,
+		)
+	default:
+		return nil, ErrLoadDataUnsupportedFormat.GenWithStackByArgs(e.format)
+	}
+	if err != nil {
+		return nil, ErrLoadDataWrongFormatConfig.GenWithStack(err.Error())
+	}
+	parser.SetLogger(log.Logger{Logger: logutil.Logger(ctx)})
+
+	// handle IGNORE N LINES
+	ignoreOneLineFn := parser.ReadRow
+	if csvParser, ok := parser.(*mydump.CSVParser); ok {
+		ignoreOneLineFn = func() error {
+			_, _, err := csvParser.ReadUntilTerminator()
+			return err
+		}
+	}
+
+	ignoreLineCnt := e.IgnoreLines
+	for ignoreLineCnt > 0 {
+		err = ignoreOneLineFn()
+		if err != nil {
+			if errors.Cause(err) == io.EOF {
+				return parser, nil
+			}
+			return nil, err
+		}
+
+		ignoreLineCnt--
+	}
+	return parser, nil
+}
+
 // processStream process input stream from parser. When returns nil, it means
 // all data is read.
 func (e *LoadDataWorker) processStream(
@@ -815,23 +936,26 @@ func (e *LoadDataWorker) processStream(
 	checkKilled := time.NewTicker(30 * time.Second)
 	defer checkKilled.Stop()
 
-	loggedError := false
+	var (
+		loggedError     = false
+		currScannedSize = int64(0)
+	)
 	for {
 		// prepare batch and enqueue task
-		if err = e.ReadRows(ctx, parser); err != nil {
+		if err = e.ReadOneBatchRows(ctx, parser); err != nil {
 			return
 		}
 		if e.curBatchCnt == 0 {
-			close(e.commitTaskQueue)
+			e.finishedSize += currScannedSize
 			return
 		}
 
 	TrySendTask:
-		pos, err2 := seeker.Seek(0, io.SeekCurrent)
-		if err2 != nil && !loggedError {
+		currScannedSize, err = seeker.Seek(0, io.SeekCurrent)
+		if err != nil && !loggedError {
 			loggedError = true
 			logutil.Logger(ctx).Error(" LOAD DATA failed to read current file offset by seek",
-				zap.Error(err2))
+				zap.Error(err))
 		}
 		select {
 		case <-ctx.Done():
@@ -847,7 +971,7 @@ func (e *LoadDataWorker) processStream(
 			cnt:             e.curBatchCnt,
 			rows:            e.rows,
 			loadedRowCnt:    e.rowCount,
-			scannedFileSize: pos,
+			scannedFileSize: e.finishedSize + currScannedSize,
 		}:
 		}
 		// reset rows buffer, will reallocate buffer but NOT reuse
@@ -1018,31 +1142,12 @@ func (e *LoadDataWorker) addRecordLD(ctx context.Context, row []types.Datum) err
 	return nil
 }
 
-// ReadRows reads rows from parser. When parser's reader meet EOF, it will return
-// nil. For other errors it will return directly. When the rows batch is full it
-// will also return nil.
+// ReadOneBatchRows reads rows from parser. When the parser's reader meets EOF,
+// it will return nil. For other errors it will return directly. When the rows
+// batch is full it will also return nil.
 // The result rows are saved in e.rows and update some members, caller can check
 // if curBatchCnt == 0 to know if reached EOF.
-func (e *LoadDataWorker) ReadRows(ctx context.Context, parser mydump.Parser) error {
-	ignoreOneLineFn := parser.ReadRow
-	if csvParser, ok := parser.(*mydump.CSVParser); ok {
-		ignoreOneLineFn = func() error {
-			_, _, err := csvParser.ReadUntilTerminator()
-			return err
-		}
-	}
-
-	for e.IgnoreLines > 0 {
-		err := ignoreOneLineFn()
-		if err != nil {
-			if errors.Cause(err) == io.EOF {
-				return nil
-			}
-			return err
-		}
-
-		e.IgnoreLines--
-	}
+func (e *LoadDataWorker) ReadOneBatchRows(ctx context.Context, parser mydump.Parser) error {
 	for {
 		if err := parser.ReadRow(); err != nil {
 			if errors.Cause(err) == io.EOF {
diff --git a/executor/loadremotetest/BUILD.bazel b/executor/loadremotetest/BUILD.bazel
index d756e3d0f8..ab8884b12b 100644
--- a/executor/loadremotetest/BUILD.bazel
+++ b/executor/loadremotetest/BUILD.bazel
@@ -6,6 +6,7 @@ go_test(
     srcs = [
         "error_test.go",
         "main_test.go",
+        "multi_file_test.go",
         "one_csv_test.go",
         "one_parquet_test.go",
         "one_sqldump_test.go",
diff --git a/executor/loadremotetest/multi_file_test.go b/executor/loadremotetest/multi_file_test.go
new file mode 100644
index 0000000000..4522cd596d
--- /dev/null
+++ b/executor/loadremotetest/multi_file_test.go
@@ -0,0 +1,134 @@
+// Copyright 2023 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package loadremotetest
+
+import (
+	"bytes"
+	"fmt"
+	"strconv"
+
+	"github.com/fsouza/fake-gcs-server/fakestorage"
+	"github.com/pingcap/tidb/testkit"
+)
+
+func (s *mockGCSSuite) TestFilenameAsterisk() {
+	s.tk.MustExec("DROP DATABASE IF EXISTS multi_load;")
+	s.tk.MustExec("CREATE DATABASE multi_load;")
+	s.tk.MustExec("CREATE TABLE multi_load.t (i INT PRIMARY KEY, s varchar(32));")
+
+	s.server.CreateObject(fakestorage.Object{
+		ObjectAttrs: fakestorage.ObjectAttrs{
+			BucketName: "test-multi-load",
+			Name:       "db.tbl.001.tsv",
+		},
+		Content: []byte("1\ttest1\n" +
+			"2\ttest2"),
+	})
+	s.server.CreateObject(fakestorage.Object{
+		ObjectAttrs: fakestorage.ObjectAttrs{
+			BucketName: "test-multi-load",
+			Name:       "db.tbl.002.tsv",
+		},
+		Content: []byte("3\ttest3\n" +
+			"4\ttest4"),
+	})
+	s.server.CreateObject(fakestorage.Object{
+		ObjectAttrs: fakestorage.ObjectAttrs{
+			BucketName: "test-multi-load",
+			Name:       "db.tbl.003.tsv",
+		},
+		Content: []byte("5\ttest5\n" +
+			"6\ttest6"),
+	})
+	s.server.CreateObject(fakestorage.Object{
+		ObjectAttrs: fakestorage.ObjectAttrs{
+			BucketName: "test-multi-load",
+			Name:       "not.me.[1-9].tsv",
+		},
+		Content: []byte("7\ttest7\n" +
+			"8\ttest8"),
+	})
+	s.server.CreateObject(fakestorage.Object{
+		ObjectAttrs: fakestorage.ObjectAttrs{
+			BucketName: "not-me",
+			Name:       "db.tbl.001.tsv",
+		},
+		Content: []byte("9\ttest9\n" +
+			"10\ttest10"),
+	})
+
+	sql := fmt.Sprintf(`LOAD DATA INFILE 'gs://test-multi-load/db.tbl.*.tsv?endpoint=%s'
+		INTO TABLE multi_load.t;`, gcsEndpoint)
+	s.tk.MustExec(sql)
+	s.tk.MustQuery("SELECT * FROM multi_load.t;").Check(testkit.Rows(
+		"1 test1", "2 test2", "3 test3", "4 test4", "5 test5", "6 test6",
+	))
+
+	s.tk.MustExec("TRUNCATE TABLE multi_load.t;")
+	sql = fmt.Sprintf(`LOAD DATA INFILE 'gs://test-multi-load/db.tbl.*.tsv?endpoint=%s'
+		INTO TABLE multi_load.t IGNORE 1 LINES;`, gcsEndpoint)
+	s.tk.MustExec(sql)
+	s.tk.MustQuery("SELECT * FROM multi_load.t;").Check(testkit.Rows(
+		"2 test2", "4 test4", "6 test6",
+	))
+
+	// only '*' is supported in pattern matching
+	s.tk.MustExec("TRUNCATE TABLE multi_load.t;")
+	sql = fmt.Sprintf(`LOAD DATA INFILE 'gs://test-multi-load/not.me.[1-9].tsv?endpoint=%s'
+		INTO TABLE multi_load.t;`, gcsEndpoint)
+	s.tk.MustExec(sql)
+	s.tk.MustQuery("SELECT * FROM multi_load.t;").Check(testkit.Rows(
+		"7 test7", "8 test8",
+	))
+}
+
+func (s *mockGCSSuite) TestMultiBatchWithIgnoreLines() {
+	s.tk.MustExec("DROP DATABASE IF EXISTS multi_load;")
+	s.tk.MustExec("CREATE DATABASE multi_load;")
+	s.tk.MustExec("CREATE TABLE multi_load.t2 (i INT);")
+
+	// both start and end are inclusive
+	genData := func(start, end int) []byte {
+		buf := make([][]byte, 0, end-start+1)
+		for i := start; i <= end; i++ {
+			buf = append(buf, []byte(strconv.Itoa(i)))
+		}
+		return bytes.Join(buf, []byte("\n"))
+	}
+
+	s.server.CreateObject(fakestorage.Object{
+		ObjectAttrs: fakestorage.ObjectAttrs{
+			BucketName: "test-multi-load",
+			Name:       "multi-batch.001.tsv",
+		},
+		Content: genData(1, 10),
+	})
+	s.server.CreateObject(fakestorage.Object{
+		ObjectAttrs: fakestorage.ObjectAttrs{
+			BucketName: "test-multi-load",
+			Name:       "multi-batch.002.tsv",
+		},
+		Content: genData(11, 20),
+	})
+
+	s.tk.MustExec("SET SESSION tidb_dml_batch_size = 3;")
+	sql := fmt.Sprintf(`LOAD DATA INFILE 'gs://test-multi-load/multi-batch.*.tsv?endpoint=%s'
+		INTO TABLE multi_load.t2 IGNORE 2 LINES;`, gcsEndpoint)
+	s.tk.MustExec(sql)
+	s.tk.MustQuery("SELECT * FROM multi_load.t2;").Check(testkit.Rows(
+		"3", "4", "5", "6", "7", "8", "9", "10",
+		"13", "14", "15", "16", "17", "18", "19", "20",
+	))
+}
diff --git a/executor/writetest/write_test.go b/executor/writetest/write_test.go
index 861186e52b..d5813271dc 100644
--- a/executor/writetest/write_test.go
+++ b/executor/writetest/write_test.go
@@ -1898,7 +1898,13 @@ func checkCases(
 		nil)
 	require.NoError(t, err)
 
-	err1 := ld.ReadRows(context.Background(), parser)
+	for ld.IgnoreLines > 0 {
+		ld.IgnoreLines--
+		//nolint: errcheck
+		_ = parser.ReadRow()
+	}
+
+	err1 := ld.ReadOneBatchRows(context.Background(), parser)
 	require.NoError(t, err1)
 	err1 = ld.CheckAndInsertOneBatch(context.Background(), ld.GetRows(), ld.GetCurBatchCnt())
 	require.NoError(t, err1)
@@ -1939,7 +1945,7 @@ func TestLoadDataMissingColumn(t *testing.T) {
 		false,
 		nil)
 	require.NoError(t, err)
-	err = ld.ReadRows(context.Background(), parser)
+	err = ld.ReadOneBatchRows(context.Background(), parser)
 	require.NoError(t, err)
 	require.Len(t, ld.GetRows(), 0)
 	r := tk.MustQuery(selectSQL)
@@ -2065,7 +2071,7 @@ func TestLoadData(t *testing.T) {
 		false,
 		nil)
 	require.NoError(t, err)
-	err = ld.ReadRows(context.Background(), parser)
+	err = ld.ReadOneBatchRows(context.Background(), parser)
 	require.NoError(t, err)
 	err = ld.CheckAndInsertOneBatch(context.Background(), ld.GetRows(), ld.GetCurBatchCnt())
 	require.NoError(t, err)
@@ -2079,7 +2085,7 @@ func TestLoadData(t *testing.T) {
 		sc.IgnoreTruncate = originIgnoreTruncate
 	}()
 	sc.IgnoreTruncate = false
-	// fields and lines are default, ReadRows returns data is nil
+	// fields and lines are default, ReadOneBatchRows returns data is nil
 	tests := []testCase{
 		// In MySQL we have 4 warnings: 1*"Incorrect integer value: '' for column 'id' at row", 3*"Row 1 doesn't contain data for all columns"
 		{[]byte("\n"), []string{"1|||"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 2"},
@@ -2099,7 +2105,7 @@ func TestLoadData(t *testing.T) {
 	}
 	checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL)
 
-	// lines starting symbol is "" and terminated symbol length is 2, ReadRows returns data is nil
+	// lines starting symbol is "" and terminated symbol length is 2, ReadOneBatchRows returns data is nil
 	ld.LinesInfo.Terminated = "||"
 	tests = []testCase{
 		{[]byte("0\t2\t3\t4\t5||"), []string{"12|2|3|4"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 1"},
@@ -2114,7 +2120,7 @@ func TestLoadData(t *testing.T) {
 	}
 	checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL)
 
-	// fields and lines aren't default, ReadRows returns data is nil
+	// fields and lines aren't default, ReadOneBatchRows returns data is nil
 	ld.FieldsInfo.Terminated = "\\"
 	ld.LinesInfo.Starting = "xxx"
 	ld.LinesInfo.Terminated = "|!#^"
@@ -2155,7 +2161,7 @@ func TestLoadData(t *testing.T) {
 	checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL)
 
 	// TODO: not support it now
-	// lines starting symbol is the same as terminated symbol, ReadRows returns data is nil
+	// lines starting symbol is the same as terminated symbol, ReadOneBatchRows returns data is nil
 	//ld.LinesInfo.Terminated = "xxx"
 	//tests = []testCase{
 	//	// data1 = nil, data2 != nil
@@ -2178,7 +2184,7 @@ func TestLoadData(t *testing.T) {
 	//	{[]byte("xxx34\\2\\3\\4\\5xx"), []byte("xxxx35\\22\\33xxxxxx36\\222xxx"),
 	//		[]string{"34|2|3|4", "35|22|33|", "36|222||"}, nil, "Records: 3 Deleted: 0 Skipped: 0 Warnings: 0"},
 	//
-	//	// ReadRows returns data isn't nil
+	//	// ReadOneBatchRows returns data isn't nil
 	//	{[]byte("\\2\\3\\4xxxx"), nil, []byte("xxxx"), "Records: 0 Deleted: 0 Skipped: 0 Warnings: 0"},
 	//	{[]byte("\\2\\3\\4xxx"), nil, []string{"37|||"}, nil, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 1"},
 	//	{[]byte("\\2\\3\\4xxxxxx11\\22\\33\\44xxx"), nil,
@@ -2414,7 +2420,7 @@ func TestLoadDataIntoPartitionedTable(t *testing.T) {
 		nil)
 	require.NoError(t, err)
 
-	err = ld.ReadRows(context.Background(), parser)
+	err = ld.ReadOneBatchRows(context.Background(), parser)
 	require.NoError(t, err)
 	err = ld.CheckAndInsertOneBatch(context.Background(), ld.GetRows(), ld.GetCurBatchCnt())
 	require.NoError(t, err)
diff --git a/server/conn.go b/server/conn.go
index ea1b71d672..719afbe17d 100644
--- a/server/conn.go
+++ b/server/conn.go
@@ -1619,7 +1619,10 @@ func (cc *clientConn) handleLoadData(ctx context.Context, loadDataWorker *execut
 		}
 	}()
 
-	err = loadDataWorker.Load(ctx, executor.NewSimpleSeekerOnReadCloser(r))
+	err = loadDataWorker.Load(ctx, []executor.LoadDataReaderInfo{{
+		Opener: func(_ context.Context) (io.ReadSeekCloser, error) {
+			return executor.NewSimpleSeekerOnReadCloser(r), nil
+		}}})
 	_ = r.Close()
 	wg.Wait()
diff --git a/util/stringutil/string_util.go b/util/stringutil/string_util.go
index 7e572693fe..4d4fb90891 100644
--- a/util/stringutil/string_util.go
+++ b/util/stringutil/string_util.go
@@ -406,3 +406,17 @@ func ConvertPosInUtf8(str *string, pos int64) int64 {
 	preStrNum := utf8.RuneCountInString(preStr)
 	return int64(preStrNum + 1)
 }
+
+// EscapeGlobExceptAsterisk escapes '?', '[', ']' for a glob path pattern.
+func EscapeGlobExceptAsterisk(s string) string {
+	var buf strings.Builder
+	buf.Grow(len(s))
+	for _, c := range s {
+		switch c {
+		case '?', '[', ']':
+			buf.WriteByte('\\')
+		}
+		buf.WriteRune(c)
+	}
+	return buf.String()
+}
diff --git a/util/stringutil/string_util_test.go b/util/stringutil/string_util_test.go
index 7574b5e49e..0b8de86388 100644
--- a/util/stringutil/string_util_test.go
+++ b/util/stringutil/string_util_test.go
@@ -194,6 +194,18 @@ func TestBuildStringFromLabels(t *testing.T) {
 	}
 }
 
+func TestEscapeGlobExceptAsterisk(t *testing.T) {
+	cases := [][2]string{
+		{"123", "123"},
+		{"12*3", "12*3"},
+		{"12?", `12\?`},
+		{`[1-2]`, `\[1-2\]`},
+	}
+	for _, pair := range cases {
+		require.Equal(t, pair[1], EscapeGlobExceptAsterisk(pair[0]))
+	}
+}
+
 func BenchmarkDoMatch(b *testing.B) {
 	escape := byte('\\')
 	tbl := []struct {