From 4fbe02adaf23eaee14d462ec3d95053774c5b673 Mon Sep 17 00:00:00 2001 From: lance6716 Date: Tue, 3 Dec 2024 16:19:47 +0800 Subject: [PATCH] dumpling: use I_S to get table list for TiDB and add database to WHERE (#57894) close pingcap/tidb#57902 --- dumpling/export/dump.go | 20 +++++++---- dumpling/export/prepare_test.go | 14 ++++---- dumpling/export/sql.go | 61 ++++++++++++++++----------------- dumpling/export/sql_test.go | 4 +-- 4 files changed, 51 insertions(+), 48 deletions(-) diff --git a/dumpling/export/dump.go b/dumpling/export/dump.go index 8941b12a86..8174d32746 100644 --- a/dumpling/export/dump.go +++ b/dumpling/export/dump.go @@ -1150,15 +1150,21 @@ func prepareTableListToDump(tctx *tcontext.Context, conf *Config, db *sql.Conn) return nil } - ifSeqExists, err := CheckIfSeqExists(db) - if err != nil { - return err - } var listType listTableType - if ifSeqExists { - listType = listTableByShowFullTables + + // TiDB has optimized the performance of reading INFORMATION_SCHEMA.TABLES + if conf.ServerInfo.ServerType == version.ServerTypeTiDB { + listType = listTableByInfoSchema } else { - listType = getListTableTypeByConf(conf) + ifSeqExists, err := checkIfSeqExists(db) + if err != nil { + return err + } + if ifSeqExists { + listType = listTableByShowFullTables + } else { + listType = getListTableTypeByConf(conf) + } } if conf.SpecifiedTables { diff --git a/dumpling/export/prepare_test.go b/dumpling/export/prepare_test.go index ebfb2ee7ef..bc745396d0 100644 --- a/dumpling/export/prepare_test.go +++ b/dumpling/export/prepare_test.go @@ -79,19 +79,19 @@ func TestListAllTables(t *testing.T) { AppendViews("db3", "t6", "t7", "t8") dbNames := make([]databaseName, 0, len(data)) - rows := sqlmock.NewRows([]string{"TABLE_SCHEMA", "TABLE_NAME", "TABLE_TYPE", "AVG_ROW_LENGTH"}) for dbName, tableInfos := range data { dbNames = append(dbNames, dbName) + query := "SELECT TABLE_NAME,TABLE_TYPE,AVG_ROW_LENGTH FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA=\\? AND \\(TABLE_TYPE='BASE TABLE'\\)" + rows := sqlmock.NewRows([]string{"TABLE_NAME", "TABLE_TYPE", "AVG_ROW_LENGTH"}) for _, tbInfo := range tableInfos { if tbInfo.Type == TableTypeView { continue } - rows.AddRow(dbName, tbInfo.Name, tbInfo.Type.String(), tbInfo.AvgRowLength) + rows.AddRow(tbInfo.Name, tbInfo.Type.String(), tbInfo.AvgRowLength) } + mock.ExpectQuery(query).WithArgs(dbName).WillReturnRows(rows) } - query := "SELECT TABLE_SCHEMA,TABLE_NAME,TABLE_TYPE,AVG_ROW_LENGTH FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE'" - mock.ExpectQuery(query).WillReturnRows(rows) tables, err := ListAllDatabasesTables(tctx, conn, dbNames, listTableByInfoSchema, TableTypeBase) require.NoError(t, err) @@ -108,9 +108,9 @@ func TestListAllTables(t *testing.T) { data = NewDatabaseTables(). AppendTables("db", []string{"t1"}, []uint64{1}). AppendViews("db", "t2") - query = "SELECT TABLE_SCHEMA,TABLE_NAME,TABLE_TYPE,AVG_ROW_LENGTH FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' OR TABLE_TYPE='VIEW'" - mock.ExpectQuery(query).WillReturnRows(sqlmock.NewRows([]string{"TABLE_SCHEMA", "TABLE_NAME", "TABLE_TYPE", "AVG_ROW_LENGTH"}). - AddRow("db", "t1", TableTypeBaseStr, 1).AddRow("db", "t2", TableTypeViewStr, nil)) + query := "SELECT TABLE_NAME,TABLE_TYPE,AVG_ROW_LENGTH FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA=\\? AND \\(TABLE_TYPE='BASE TABLE' OR TABLE_TYPE='VIEW'\\)" + mock.ExpectQuery(query).WithArgs("db").WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME", "TABLE_TYPE", "AVG_ROW_LENGTH"}). + AddRow("t1", TableTypeBaseStr, 1).AddRow("t2", TableTypeViewStr, nil)) tables, err = ListAllDatabasesTables(tctx, conn, []string{"db"}, listTableByInfoSchema, TableTypeBase, TableTypeView) require.NoError(t, err) require.Len(t, tables, 1) diff --git a/dumpling/export/sql.go b/dumpling/export/sql.go index 3fc937c617..67ee1dcc30 100644 --- a/dumpling/export/sql.go +++ b/dumpling/export/sql.go @@ -368,10 +368,10 @@ func ListAllDatabasesTables(tctx *tcontext.Context, db *sql.Conn, databaseNames listType listTableType, tableTypes ...TableType) (DatabaseTables, error) { // revive:disable-line:flag-parameter dbTables := DatabaseTables{} var ( - schema, table, tableTypeStr string - tableType TableType - avgRowLength uint64 - err error + table, tableTypeStr string + tableType TableType + avgRowLength uint64 + err error ) tableTypeConditions := make([]string, len(tableTypes)) @@ -380,38 +380,35 @@ func ListAllDatabasesTables(tctx *tcontext.Context, db *sql.Conn, databaseNames } switch listType { case listTableByInfoSchema: - query := fmt.Sprintf("SELECT TABLE_SCHEMA,TABLE_NAME,TABLE_TYPE,AVG_ROW_LENGTH FROM INFORMATION_SCHEMA.TABLES WHERE %s", strings.Join(tableTypeConditions, " OR ")) for _, schema := range databaseNames { + query := fmt.Sprintf("SELECT TABLE_NAME,TABLE_TYPE,AVG_ROW_LENGTH FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA=? AND (%s)", strings.Join(tableTypeConditions, " OR ")) dbTables[schema] = make([]*TableInfo, 0) - } - if err = simpleQueryWithArgs(tctx, db, func(rows *sql.Rows) error { - var ( - sqlAvgRowLength sql.NullInt64 - err2 error - ) - if err2 = rows.Scan(&schema, &table, &tableTypeStr, &sqlAvgRowLength); err != nil { - return errors.Trace(err2) - } - tableType, err2 = ParseTableType(tableTypeStr) - if err2 != nil { - return errors.Trace(err2) - } + if err = simpleQueryWithArgs(tctx, db, func(rows *sql.Rows) error { + var ( + sqlAvgRowLength sql.NullInt64 + err2 error + ) + if err2 = rows.Scan(&table, &tableTypeStr, &sqlAvgRowLength); err != nil { + return errors.Trace(err2) + } + tableType, err2 = ParseTableType(tableTypeStr) + if err2 != nil { + return errors.Trace(err2) + } - if sqlAvgRowLength.Valid { - avgRowLength = uint64(sqlAvgRowLength.Int64) - } else { - avgRowLength = 0 - } - // only append tables to schemas in databaseNames - if _, ok := dbTables[schema]; ok { + if sqlAvgRowLength.Valid { + avgRowLength = uint64(sqlAvgRowLength.Int64) + } else { + avgRowLength = 0 + } dbTables[schema] = append(dbTables[schema], &TableInfo{table, avgRowLength, tableType}) + return nil + }, query, schema); err != nil { + return nil, errors.Annotatef(err, "sql: %s", query) } - return nil - }, query); err != nil { - return nil, errors.Annotatef(err, "sql: %s", query) } case listTableByShowFullTables: - for _, schema = range databaseNames { + for _, schema := range databaseNames { dbTables[schema] = make([]*TableInfo, 0) query := fmt.Sprintf("SHOW FULL TABLES FROM `%s` WHERE %s", escapeString(schema), strings.Join(tableTypeConditions, " OR ")) @@ -437,7 +434,7 @@ func ListAllDatabasesTables(tctx *tcontext.Context, db *sql.Conn, databaseNames for _, tableType = range tableTypes { selectedTableType[tableType] = struct{}{} } - for _, schema = range databaseNames { + for _, schema := range databaseNames { dbTables[schema] = make([]*TableInfo, 0) query := fmt.Sprintf(queryTemplate, escapeString(schema)) rows, err := db.QueryContext(tctx, query) @@ -922,8 +919,8 @@ func CheckTiDBWithTiKV(db *sql.DB) (bool, error) { return count > 0, nil } -// CheckIfSeqExists use sql to check whether sequence exists -func CheckIfSeqExists(db *sql.Conn) (bool, error) { +// checkIfSeqExists use sql to check whether sequence exists +func checkIfSeqExists(db *sql.Conn) (bool, error) { var count int const query = "SELECT COUNT(1) as c FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='SEQUENCE'" row := db.QueryRowContext(context.Background(), query) diff --git a/dumpling/export/sql_test.go b/dumpling/export/sql_test.go index e5772566f7..f612688a92 100644 --- a/dumpling/export/sql_test.go +++ b/dumpling/export/sql_test.go @@ -1815,7 +1815,7 @@ func TestCheckIfSeqExists(t *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"c"}). AddRow("1")) - exists, err := CheckIfSeqExists(conn) + exists, err := checkIfSeqExists(conn) require.NoError(t, err) require.Equal(t, true, exists) @@ -1823,7 +1823,7 @@ func TestCheckIfSeqExists(t *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"c"}). AddRow("0")) - exists, err = CheckIfSeqExists(conn) + exists, err = checkIfSeqExists(conn) require.NoError(t, err) require.Equal(t, false, exists) }