diff --git a/executor/lockstats/BUILD.bazel b/executor/lockstats/BUILD.bazel index 1357a96ee1..4d05fa72d6 100644 --- a/executor/lockstats/BUILD.bazel +++ b/executor/lockstats/BUILD.bazel @@ -14,7 +14,7 @@ go_library( "//infoschema", "//parser/ast", "//parser/model", - "//statistics/handle/lockstats", + "//statistics/handle/util", "//table/tables", "//util/chunk", "@com_github_pingcap_errors//:errors", diff --git a/executor/lockstats/lock_stats_executor.go b/executor/lockstats/lock_stats_executor.go index 0e797387ae..39660c05ae 100644 --- a/executor/lockstats/lock_stats_executor.go +++ b/executor/lockstats/lock_stats_executor.go @@ -24,7 +24,7 @@ import ( "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/model" - "github.com/pingcap/tidb/statistics/handle/lockstats" + "github.com/pingcap/tidb/statistics/handle/util" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/util/chunk" ) @@ -135,11 +135,11 @@ func populatePartitionIDAndNames( func populateTableAndPartitionIDs( tables []*ast.TableName, is infoschema.InfoSchema, -) (map[int64]*lockstats.TableInfo, error) { +) (map[int64]*util.StatsLockTable, error) { if len(tables) == 0 { return nil, errors.New("table list should not be empty") } - tableWithPartitions := make(map[int64]*lockstats.TableInfo, len(tables)) + tableWithPartitions := make(map[int64]*util.StatsLockTable, len(tables)) for _, table := range tables { tbl, err := is.TableByName(table.Schema, table.Name) @@ -147,7 +147,7 @@ func populateTableAndPartitionIDs( return nil, err } tid := tbl.Meta().ID - tableWithPartitions[tid] = &lockstats.TableInfo{ + tableWithPartitions[tid] = &util.StatsLockTable{ FullName: fmt.Sprintf("%s.%s", table.Schema.L, table.Name.L), } diff --git a/statistics/handle/BUILD.bazel b/statistics/handle/BUILD.bazel index d8e873392d..6ac1103d13 100644 --- a/statistics/handle/BUILD.bazel +++ b/statistics/handle/BUILD.bazel @@ -8,7 +8,6 @@ go_library( "dump.go", "handle.go", "handle_hist.go", - "lock_stats_handler.go", "update.go", ], importpath = "github.com/pingcap/tidb/statistics/handle", diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index 5bad1a33d7..a0d95941e3 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -33,6 +33,7 @@ import ( "github.com/pingcap/tidb/statistics/handle/extstats" "github.com/pingcap/tidb/statistics/handle/globalstats" "github.com/pingcap/tidb/statistics/handle/history" + "github.com/pingcap/tidb/statistics/handle/lockstats" handle_metrics "github.com/pingcap/tidb/statistics/handle/metrics" "github.com/pingcap/tidb/statistics/handle/storage" "github.com/pingcap/tidb/statistics/handle/usage" @@ -70,6 +71,9 @@ type Handle struct { // StatsAnalyze is used to handle auto-analyze and manage analyze jobs. util.StatsAnalyze + // StatsLock is used to manage locked stats. + util.StatsLock + // This gpool is used to reuse goroutine in the mergeGlobalStatsTopN. gpool *gp.Pool @@ -120,6 +124,7 @@ func NewHandle(_, initStatsCtx sessionctx.Context, lease time.Duration, pool uti InitStatsDone: make(chan struct{}), TableInfoGetter: util.NewTableInfoGetter(), StatsAnalyze: autoanalyze.NewStatsAnalyze(pool), + StatsLock: lockstats.NewStatsLock(pool), } handle.StatsGC = storage.NewStatsGC(pool, lease, handle.TableInfoGetter, handle.MarkExtendedStatsDeleted) diff --git a/statistics/handle/lock_stats_handler.go b/statistics/handle/lock_stats_handler.go deleted file mode 100644 index c4df6829ef..0000000000 --- a/statistics/handle/lock_stats_handler.go +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package handle - -import ( - "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/statistics/handle/lockstats" - "github.com/pingcap/tidb/util/sqlexec" -) - -// LockTables add locked tables id to store. -// - tables: tables that will be locked. -// Return the message of skipped tables and error. -func (h *Handle) LockTables(tables map[int64]*lockstats.TableInfo) (skipped string, err error) { - err = h.callWithSCtx(func(sctx sessionctx.Context) error { - skipped, err = lockstats.AddLockedTables(sctx.(sqlexec.RestrictedSQLExecutor), tables) - return err - }) - return -} - -// LockPartitions add locked partitions id to store. -// If the whole table is locked, then skip all partitions of the table. -// - tid: table id of which will be locked. -// - tableName: table name of which will be locked. -// - pidNames: partition ids of which will be locked. -// Return the message of skipped tables and error. -// Note: If the whole table is locked, then skip all partitions of the table. -func (h *Handle) LockPartitions( - tid int64, - tableName string, - pidNames map[int64]string, -) (skipped string, err error) { - err = h.callWithSCtx(func(sctx sessionctx.Context) error { - skipped, err = lockstats.AddLockedPartitions(sctx.(sqlexec.RestrictedSQLExecutor), tid, tableName, pidNames) - return err - }) - return -} - -// RemoveLockedTables remove tables from table locked records. -// - tables: tables of which will be unlocked. -// Return the message of skipped tables and error. -func (h *Handle) RemoveLockedTables(tables map[int64]*lockstats.TableInfo) (skipped string, err error) { - err = h.callWithSCtx(func(sctx sessionctx.Context) error { - skipped, err = lockstats.RemoveLockedTables(sctx.(sqlexec.RestrictedSQLExecutor), tables) - return err - }) - return -} - -// RemoveLockedPartitions remove partitions from table locked records. -// - tid: table id of which will be unlocked. -// - tableName: table name of which will be unlocked. -// - pidNames: partition ids of which will be unlocked. -// Note: If the whole table is locked, then skip all partitions of the table. -func (h *Handle) RemoveLockedPartitions( - tid int64, - tableName string, - pidNames map[int64]string, -) (skipped string, err error) { - err = h.callWithSCtx(func(sctx sessionctx.Context) error { - skipped, err = lockstats.RemoveLockedPartitions(sctx.(sqlexec.RestrictedSQLExecutor), tid, tableName, pidNames) - return err - }) - return -} - -// GetLockedTables returns the locked status of the given tables. -// Note: This function query locked tables from store, so please try to batch the query. -func (h *Handle) GetLockedTables(tableIDs ...int64) (map[int64]struct{}, error) { - tableLocked, err := h.queryLockedTables() - if err != nil { - return nil, err - } - - return lockstats.GetLockedTables(tableLocked, tableIDs...), nil -} - -// queryLockedTables query locked tables from store. -func (h *Handle) queryLockedTables() (tables map[int64]struct{}, err error) { - err = h.callWithSCtx(func(sctx sessionctx.Context) error { - tables, err = lockstats.QueryLockedTables(sctx.(sqlexec.RestrictedSQLExecutor)) - return err - }) - return -} - -// GetTableLockedAndClearForTest for unit test only. -func (h *Handle) GetTableLockedAndClearForTest() (map[int64]struct{}, error) { - return h.queryLockedTables() -} diff --git a/statistics/handle/lockstats/BUILD.bazel b/statistics/handle/lockstats/BUILD.bazel index 578fed91cf..9336d82c38 100644 --- a/statistics/handle/lockstats/BUILD.bazel +++ b/statistics/handle/lockstats/BUILD.bazel @@ -10,7 +10,7 @@ go_library( importpath = "github.com/pingcap/tidb/statistics/handle/lockstats", visibility = ["//visibility:public"], deps = [ - "//parser/terror", + "//sessionctx", "//statistics/handle/cache", "//statistics/handle/util", "//util/logutil", @@ -34,8 +34,11 @@ go_test( deps = [ "//kv", "//parser/mysql", + "//sessionctx", + "//statistics/handle/util", "//types", "//util/chunk", + "//util/mock", "//util/sqlexec/mock", "@com_github_pingcap_errors//:errors", "@com_github_stretchr_testify//require", diff --git a/statistics/handle/lockstats/lock_stats.go b/statistics/handle/lockstats/lock_stats.go index 8d7717dfcc..d15101d086 100644 --- a/statistics/handle/lockstats/lock_stats.go +++ b/statistics/handle/lockstats/lock_stats.go @@ -15,11 +15,11 @@ package lockstats import ( - "context" "fmt" "slices" "strings" + "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/statistics/handle/util" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/sqlexec" @@ -35,6 +35,99 @@ const ( insertSQL = "INSERT INTO mysql.stats_table_locked (table_id) VALUES (%?) ON DUPLICATE KEY UPDATE table_id = %?" ) +// statsLockImpl implements the util.StatsLock interface. +type statsLockImpl struct { + pool util.SessionPool +} + +// NewStatsLock creates a new StatsLock. +func NewStatsLock(pool util.SessionPool) util.StatsLock { + return &statsLockImpl{pool: pool} +} + +// LockTables add locked tables id to store. +// - tables: tables that will be locked. +// Return the message of skipped tables and error. +func (sl *statsLockImpl) LockTables(tables map[int64]*util.StatsLockTable) (skipped string, err error) { + err = util.CallWithSCtx(sl.pool, func(sctx sessionctx.Context) error { + skipped, err = AddLockedTables(sctx, tables) + return err + }, util.FlagWrapTxn) + return +} + +// LockPartitions add locked partitions id to store. +// If the whole table is locked, then skip all partitions of the table. +// - tid: table id of which will be locked. +// - tableName: table name of which will be locked. +// - pidNames: partition ids of which will be locked. +// Return the message of skipped tables and error. +// Note: If the whole table is locked, then skip all partitions of the table. +func (sl *statsLockImpl) LockPartitions( + tid int64, + tableName string, + pidNames map[int64]string, +) (skipped string, err error) { + err = util.CallWithSCtx(sl.pool, func(sctx sessionctx.Context) error { + skipped, err = AddLockedPartitions(sctx, tid, tableName, pidNames) + return err + }, util.FlagWrapTxn) + return +} + +// RemoveLockedTables remove tables from table locked records. +// - tables: tables of which will be unlocked. +// Return the message of skipped tables and error. +func (sl *statsLockImpl) RemoveLockedTables(tables map[int64]*util.StatsLockTable) (skipped string, err error) { + err = util.CallWithSCtx(sl.pool, func(sctx sessionctx.Context) error { + skipped, err = RemoveLockedTables(sctx, tables) + return err + }, util.FlagWrapTxn) + return +} + +// RemoveLockedPartitions remove partitions from table locked records. +// - tid: table id of which will be unlocked. +// - tableName: table name of which will be unlocked. +// - pidNames: partition ids of which will be unlocked. +// Note: If the whole table is locked, then skip all partitions of the table. +func (sl *statsLockImpl) RemoveLockedPartitions( + tid int64, + tableName string, + pidNames map[int64]string, +) (skipped string, err error) { + err = util.CallWithSCtx(sl.pool, func(sctx sessionctx.Context) error { + skipped, err = RemoveLockedPartitions(sctx, tid, tableName, pidNames) + return err + }, util.FlagWrapTxn) + return +} + +// queryLockedTables query locked tables from store. +func (sl *statsLockImpl) queryLockedTables() (tables map[int64]struct{}, err error) { + err = util.CallWithSCtx(sl.pool, func(sctx sessionctx.Context) error { + tables, err = QueryLockedTables(sctx) + return err + }) + return +} + +// GetLockedTables returns the locked status of the given tables. +// Note: This function query locked tables from store, so please try to batch the query. +func (sl *statsLockImpl) GetLockedTables(tableIDs ...int64) (map[int64]struct{}, error) { + tableLocked, err := sl.queryLockedTables() + if err != nil { + return nil, err + } + + return GetLockedTables(tableLocked, tableIDs...), nil +} + +// GetTableLockedAndClearForTest for unit test only. +func (sl *statsLockImpl) GetTableLockedAndClearForTest() (map[int64]struct{}, error) { + return sl.queryLockedTables() +} + var ( // Stats logger. statsLogger = logutil.BgLogger().With(zap.String("category", "stats")) @@ -42,33 +135,16 @@ var ( useCurrentSession = []sqlexec.OptionFuncAlias{sqlexec.ExecOptionUseCurSession} ) -// TableInfo is the table info of which will be locked. -type TableInfo struct { - PartitionInfo map[int64]string - // schema name + table name. - FullName string -} - // AddLockedTables add locked tables id to store. // - exec: sql executor. // - tables: tables that will be locked. // Return the message of skipped tables and error. func AddLockedTables( - exec sqlexec.RestrictedSQLExecutor, - tables map[int64]*TableInfo, + sctx sessionctx.Context, + tables map[int64]*util.StatsLockTable, ) (string, error) { - ctx := util.StatsCtx(context.Background()) - err := startTransaction(ctx, exec) - if err != nil { - return "", err - } - defer func() { - // Commit transaction. - err = finishTransaction(ctx, exec, err) - }() - // Load tables to check duplicate before insert. - lockedTables, err := QueryLockedTables(exec) + lockedTables, err := QueryLockedTables(sctx) if err != nil { return "", err } @@ -89,7 +165,7 @@ func AddLockedTables( lockedTablesAndPartitions := GetLockedTables(lockedTables, ids...) for tid, table := range tables { if _, ok := lockedTablesAndPartitions[tid]; !ok { - if err := insertIntoStatsTableLocked(ctx, exec, tid); err != nil { + if err := insertIntoStatsTableLocked(sctx, tid); err != nil { return "", err } } else { @@ -98,7 +174,7 @@ func AddLockedTables( for pid := range table.PartitionInfo { if _, ok := lockedTablesAndPartitions[pid]; !ok { - if err := insertIntoStatsTableLocked(ctx, exec, pid); err != nil { + if err := insertIntoStatsTableLocked(sctx, pid); err != nil { return "", err } } @@ -118,23 +194,13 @@ func AddLockedTables( // - pidNames: partition ids of which will be locked. // Return the message of skipped tables and error. func AddLockedPartitions( - exec sqlexec.RestrictedSQLExecutor, + sctx sessionctx.Context, tid int64, tableName string, pidNames map[int64]string, ) (string, error) { - ctx := util.StatsCtx(context.Background()) - err := startTransaction(ctx, exec) - if err != nil { - return "", err - } - defer func() { - // Commit transaction. - err = finishTransaction(ctx, exec, err) - }() - // Load tables to check duplicate before insert. - lockedTables, err := QueryLockedTables(exec) + lockedTables, err := QueryLockedTables(sctx) if err != nil { return "", err } @@ -165,7 +231,7 @@ func AddLockedPartitions( lockedPartitions := GetLockedTables(lockedTables, pids...) for _, pid := range pids { if _, ok := lockedPartitions[pid]; !ok { - if err := insertIntoStatsTableLocked(ctx, exec, pid); err != nil { + if err := insertIntoStatsTableLocked(sctx, pid); err != nil { return "", err } } else { @@ -223,12 +289,8 @@ func generateStableSkippedPartitionsMessage(ids []int64, tableName string, skipp return "" } -func insertIntoStatsTableLocked(ctx context.Context, exec sqlexec.RestrictedSQLExecutor, tid int64) error { - _, _, err := exec.ExecRestrictedSQL( - ctx, - useCurrentSession, - insertSQL, tid, tid, - ) +func insertIntoStatsTableLocked(sctx sessionctx.Context, tid int64) error { + _, _, err := util.ExecRows(sctx, insertSQL, tid, tid) if err != nil { logutil.BgLogger().Error("error occurred when insert mysql.stats_table_locked", zap.String("category", "stats"), zap.Error(err)) return err diff --git a/statistics/handle/lockstats/lock_stats_test.go b/statistics/handle/lockstats/lock_stats_test.go index 7b90f596a5..ba4315c728 100644 --- a/statistics/handle/lockstats/lock_stats_test.go +++ b/statistics/handle/lockstats/lock_stats_test.go @@ -15,11 +15,11 @@ package lockstats import ( - "context" "testing" "github.com/pingcap/errors" "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/statistics/handle/util" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/sqlexec/mock" @@ -150,19 +150,18 @@ func TestGenerateSkippedPartitionsMessage(t *testing.T) { } func TestInsertIntoStatsTableLocked(t *testing.T) { - ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() exec := mock.NewMockRestrictedSQLExecutor(ctrl) // Executed SQL should be: exec.EXPECT().ExecRestrictedSQL( - gomock.Eq(ctx), - useCurrentSession, + util.StatsCtx, + util.UseCurrentSessionOpt, gomock.Eq(insertSQL), gomock.Eq([]interface{}{int64(1), int64(1)}), ) - err := insertIntoStatsTableLocked(ctx, exec, 1) + err := insertIntoStatsTableLocked(wrapAsSCtx(exec), 1) require.NoError(t, err) // Error should be returned when ExecRestrictedSQL returns error. @@ -173,7 +172,7 @@ func TestInsertIntoStatsTableLocked(t *testing.T) { gomock.Any(), ).Return(nil, nil, errors.New("test error")) - err = insertIntoStatsTableLocked(ctx, exec, 1) + err = insertIntoStatsTableLocked(wrapAsSCtx(exec), 1) require.Equal(t, "test error", err.Error()) } @@ -182,49 +181,37 @@ func TestAddLockedTables(t *testing.T) { defer ctrl.Finish() exec := mock.NewMockRestrictedSQLExecutor(ctrl) - // Executed SQL should be: - exec.EXPECT().ExecRestrictedSQL( - gomock.All(&ctxMatcher{}), - useCurrentSession, - gomock.Eq("BEGIN PESSIMISTIC"), - ) // Return table 1 is locked. c := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}, 1) c.AppendInt64(0, int64(1)) rows := []chunk.Row{c.GetRow(0)} exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, selectSQL, ).Return(rows, nil, nil) exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, insertSQL, gomock.Eq([]interface{}{int64(2), int64(2)}), ) exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, insertSQL, gomock.Eq([]interface{}{int64(3), int64(3)}), ) exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, insertSQL, gomock.Eq([]interface{}{int64(4), int64(4)}), ) - exec.EXPECT().ExecRestrictedSQL( - gomock.All(&ctxMatcher{}), - useCurrentSession, - "COMMIT", - ) - - tables := map[int64]*TableInfo{ + tables := map[int64]*util.StatsLockTable{ 1: { FullName: "test.t1", PartitionInfo: map[int64]string{ @@ -240,7 +227,7 @@ func TestAddLockedTables(t *testing.T) { } msg, err := AddLockedTables( - exec, + wrapAsSCtx(exec), tables, ) require.NoError(t, err) @@ -252,41 +239,28 @@ func TestAddLockedPartitions(t *testing.T) { defer ctrl.Finish() exec := mock.NewMockRestrictedSQLExecutor(ctrl) - // Executed SQL should be: - exec.EXPECT().ExecRestrictedSQL( - gomock.All(&ctxMatcher{}), - useCurrentSession, - gomock.Eq("BEGIN PESSIMISTIC"), - ) - // No table is locked. exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, selectSQL, ).Return(nil, nil, nil) exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, insertSQL, gomock.Eq([]interface{}{int64(2), int64(2)}), ) exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, insertSQL, gomock.Eq([]interface{}{int64(3), int64(3)}), ) - exec.EXPECT().ExecRestrictedSQL( - gomock.All(&ctxMatcher{}), - useCurrentSession, - "COMMIT", - ) - msg, err := AddLockedPartitions( - exec, + wrapAsSCtx(exec), 1, "test.t1", map[int64]string{ @@ -303,31 +277,18 @@ func TestAddLockedPartitionsFailed(t *testing.T) { defer ctrl.Finish() exec := mock.NewMockRestrictedSQLExecutor(ctrl) - // Executed SQL should be: - exec.EXPECT().ExecRestrictedSQL( - gomock.All(&ctxMatcher{}), - useCurrentSession, - gomock.Eq("BEGIN PESSIMISTIC"), - ) - // Return table 1 is locked. c := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}, 1) c.AppendInt64(0, int64(1)) rows := []chunk.Row{c.GetRow(0)} exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, selectSQL, ).Return(rows, nil, nil) - exec.EXPECT().ExecRestrictedSQL( - gomock.All(&ctxMatcher{}), - useCurrentSession, - "COMMIT", - ) - msg, err := AddLockedPartitions( - exec, + wrapAsSCtx(exec), 1, "test.t1", map[int64]string{ diff --git a/statistics/handle/lockstats/query_lock.go b/statistics/handle/lockstats/query_lock.go index 5bccffb185..5b3b6c56ea 100644 --- a/statistics/handle/lockstats/query_lock.go +++ b/statistics/handle/lockstats/query_lock.go @@ -15,21 +15,16 @@ package lockstats import ( - "context" - - "github.com/pingcap/errors" - "github.com/pingcap/tidb/parser/terror" + "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/statistics/handle/util" - "github.com/pingcap/tidb/util/sqlexec" ) const selectSQL = "SELECT table_id FROM mysql.stats_table_locked" // QueryLockedTables loads locked tables from mysql.stats_table_locked. // Return it as a map for fast query. -func QueryLockedTables(exec sqlexec.RestrictedSQLExecutor) (map[int64]struct{}, error) { - ctx := util.StatsCtx(context.Background()) - rows, _, err := exec.ExecRestrictedSQL(ctx, useCurrentSession, selectSQL) +func QueryLockedTables(sctx sessionctx.Context) (map[int64]struct{}, error) { + rows, _, err := util.ExecRows(sctx, selectSQL) if err != nil { return nil, err } @@ -56,19 +51,3 @@ func GetLockedTables(tableLocked map[int64]struct{}, tableIDs ...int64) map[int6 return lockedTables } - -func startTransaction(ctx context.Context, exec sqlexec.RestrictedSQLExecutor) error { - _, _, err := exec.ExecRestrictedSQL(ctx, useCurrentSession, "BEGIN PESSIMISTIC") - return errors.Trace(err) -} - -// finishTransaction will execute `commit` when error is nil, otherwise `rollback`. -func finishTransaction(ctx context.Context, exec sqlexec.RestrictedSQLExecutor, err error) error { - if err == nil { - _, _, err = exec.ExecRestrictedSQL(ctx, useCurrentSession, "COMMIT") - } else { - _, _, err1 := exec.ExecRestrictedSQL(ctx, useCurrentSession, "ROLLBACK") - terror.Log(errors.Trace(err1)) - } - return errors.Trace(err) -} diff --git a/statistics/handle/lockstats/query_lock_test.go b/statistics/handle/lockstats/query_lock_test.go index d0063fb075..eac173b4e6 100644 --- a/statistics/handle/lockstats/query_lock_test.go +++ b/statistics/handle/lockstats/query_lock_test.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/parser/mysql" + statsutil "github.com/pingcap/tidb/statistics/handle/util" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/sqlexec/mock" @@ -128,10 +129,10 @@ func executeQueryLockedTables(exec *mock.MockRestrictedSQLExecutor, numRows int, if wantErr { exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + statsutil.UseCurrentSessionOpt, selectSQL, ).Return(nil, nil, errors.New("error")) - return QueryLockedTables(exec) + return QueryLockedTables(wrapAsSCtx(exec)) } c := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}, numRows) @@ -144,9 +145,9 @@ func executeQueryLockedTables(exec *mock.MockRestrictedSQLExecutor, numRows int, } exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + statsutil.UseCurrentSessionOpt, selectSQL, ).Return(rows, nil, nil) - return QueryLockedTables(exec) + return QueryLockedTables(wrapAsSCtx(exec)) } diff --git a/statistics/handle/lockstats/unlock_stats.go b/statistics/handle/lockstats/unlock_stats.go index 8e8deb798e..ee3caccbed 100644 --- a/statistics/handle/lockstats/unlock_stats.go +++ b/statistics/handle/lockstats/unlock_stats.go @@ -15,12 +15,10 @@ package lockstats import ( - "context" - "github.com/pingcap/errors" + "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/statistics/handle/cache" "github.com/pingcap/tidb/statistics/handle/util" - "github.com/pingcap/tidb/util/sqlexec" "go.uber.org/zap" ) @@ -37,21 +35,11 @@ const ( // - tables: tables of which will be unlocked. // Return the message of skipped tables and error. func RemoveLockedTables( - exec sqlexec.RestrictedSQLExecutor, - tables map[int64]*TableInfo, + sctx sessionctx.Context, + tables map[int64]*util.StatsLockTable, ) (string, error) { - ctx := util.StatsCtx(context.Background()) - err := startTransaction(ctx, exec) - if err != nil { - return "", err - } - defer func() { - // Commit or rollback the transaction. - err = finishTransaction(ctx, exec, err) - }() - // Load tables to check locked before delete. - lockedTables, err := QueryLockedTables(exec) + lockedTables, err := QueryLockedTables(sctx) if err != nil { return "", err } @@ -75,7 +63,7 @@ func RemoveLockedTables( skippedTables = append(skippedTables, table.FullName) continue } - if err := updateStatsAndUnlockTable(ctx, exec, tid); err != nil { + if err := updateStatsAndUnlockTable(sctx, tid); err != nil { return "", err } @@ -84,7 +72,7 @@ func RemoveLockedTables( if _, ok := lockedTablesAndPartitions[pid]; !ok { continue } - if err := updateStatsAndUnlockPartition(ctx, exec, pid, tid); err != nil { + if err := updateStatsAndUnlockPartition(sctx, pid, tid); err != nil { return "", err } } @@ -102,23 +90,13 @@ func RemoveLockedTables( // - pidNames: partition ids of which will be unlocked. // Return the message of skipped tables and error. func RemoveLockedPartitions( - exec sqlexec.RestrictedSQLExecutor, + sctx sessionctx.Context, tid int64, tableName string, pidNames map[int64]string, ) (string, error) { - ctx := util.StatsCtx(context.Background()) - err := startTransaction(ctx, exec) - if err != nil { - return "", err - } - defer func() { - // Commit or rollback the transaction. - err = finishTransaction(ctx, exec, err) - }() - // Load tables to check locked before delete. - lockedTables, err := QueryLockedTables(exec) + lockedTables, err := QueryLockedTables(sctx) if err != nil { return "", err } @@ -149,7 +127,7 @@ func RemoveLockedPartitions( skippedPartitions = append(skippedPartitions, pidNames[pid]) continue } - if err := updateStatsAndUnlockPartition(ctx, exec, pid, tid); err != nil { + if err := updateStatsAndUnlockPartition(sctx, pid, tid); err != nil { return "", err } } @@ -159,10 +137,8 @@ func RemoveLockedPartitions( return msg, err } -func updateDelta(ctx context.Context, exec sqlexec.RestrictedSQLExecutor, count, modifyCount int64, version uint64, tid int64) error { - if _, _, err := exec.ExecRestrictedSQL( - ctx, - useCurrentSession, +func updateDelta(sctx sessionctx.Context, count, modifyCount int64, version uint64, tid int64) error { + if _, _, err := util.ExecRows(sctx, updateDeltaSQL, version, count, count, modifyCount, tid, ); err != nil { @@ -172,44 +148,42 @@ func updateDelta(ctx context.Context, exec sqlexec.RestrictedSQLExecutor, count, return nil } -func updateStatsAndUnlockTable(ctx context.Context, exec sqlexec.RestrictedSQLExecutor, tid int64) error { - count, modifyCount, version, err := getStatsDeltaFromTableLocked(ctx, tid, exec) +func updateStatsAndUnlockTable(sctx sessionctx.Context, tid int64) error { + count, modifyCount, version, err := getStatsDeltaFromTableLocked(sctx, tid) if err != nil { return err } - if err := updateDelta(ctx, exec, count, modifyCount, version, tid); err != nil { + if err := updateDelta(sctx, count, modifyCount, version, tid); err != nil { return err } cache.TableRowStatsCache.Invalidate(tid) - _, _, err = exec.ExecRestrictedSQL( - ctx, - useCurrentSession, + _, _, err = util.ExecRows( + sctx, DeleteLockSQL, tid, ) return err } // updateStatsAndUnlockPartition also update the stats to the table level. -func updateStatsAndUnlockPartition(ctx context.Context, exec sqlexec.RestrictedSQLExecutor, partitionID int64, tid int64) error { - count, modifyCount, version, err := getStatsDeltaFromTableLocked(ctx, partitionID, exec) +func updateStatsAndUnlockPartition(sctx sessionctx.Context, partitionID int64, tid int64) error { + count, modifyCount, version, err := getStatsDeltaFromTableLocked(sctx, partitionID) if err != nil { return err } - if err := updateDelta(ctx, exec, count, modifyCount, version, partitionID); err != nil { + if err := updateDelta(sctx, count, modifyCount, version, partitionID); err != nil { return err } cache.TableRowStatsCache.Invalidate(partitionID) - if err := updateDelta(ctx, exec, count, modifyCount, version, tid); err != nil { + if err := updateDelta(sctx, count, modifyCount, version, tid); err != nil { return err } cache.TableRowStatsCache.Invalidate(tid) - _, _, err = exec.ExecRestrictedSQL( - ctx, - useCurrentSession, + _, _, err = util.ExecRows( + sctx, DeleteLockSQL, partitionID, ) @@ -217,10 +191,9 @@ func updateStatsAndUnlockPartition(ctx context.Context, exec sqlexec.RestrictedS } // getStatsDeltaFromTableLocked get count, modify_count and version for the given table from mysql.stats_table_locked. -func getStatsDeltaFromTableLocked(ctx context.Context, tableID int64, exec sqlexec.RestrictedSQLExecutor) (count, modifyCount int64, version uint64, err error) { - rows, _, err := exec.ExecRestrictedSQL( - ctx, - useCurrentSession, +func getStatsDeltaFromTableLocked(sctx sessionctx.Context, tableID int64) (count, modifyCount int64, version uint64, err error) { + rows, _, err := util.ExecRows( + sctx, selectDeltaSQL, tableID, ) if err != nil { diff --git a/statistics/handle/lockstats/unlock_stats_test.go b/statistics/handle/lockstats/unlock_stats_test.go index 32197fb640..306ffa3328 100644 --- a/statistics/handle/lockstats/unlock_stats_test.go +++ b/statistics/handle/lockstats/unlock_stats_test.go @@ -15,18 +15,26 @@ package lockstats import ( - "context" "testing" "github.com/pingcap/errors" "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/statistics/handle/util" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + mockctx "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/sqlexec/mock" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) +func wrapAsSCtx(exec *mock.MockRestrictedSQLExecutor) sessionctx.Context { + sctx := mockctx.NewContext() + sctx.SetValue(mock.MockRestrictedSQLExecutorKey{}, exec) + return sctx +} + func TestGetStatsDeltaFromTableLocked(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -70,15 +78,14 @@ func TestGetStatsDeltaFromTableLocked(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() exec.EXPECT().ExecRestrictedSQL( - ctx, - useCurrentSession, + util.StatsCtx, + util.UseCurrentSessionOpt, selectDeltaSQL, gomock.Eq([]interface{}{int64(1)}), ).Return(tt.execResult, nil, tt.execError) - count, modifyCount, version, err := getStatsDeltaFromTableLocked(ctx, 1, exec) + count, modifyCount, version, err := getStatsDeltaFromTableLocked(wrapAsSCtx(exec), 1) if tt.execError != nil { require.Equal(t, tt.execError.Error(), err.Error()) } else { @@ -130,37 +137,36 @@ func TestUpdateStatsAndUnlockTable(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() exec.EXPECT().ExecRestrictedSQL( - ctx, - useCurrentSession, + util.StatsCtx, + util.UseCurrentSessionOpt, selectDeltaSQL, gomock.Eq([]interface{}{tt.tableID}), ).Return([]chunk.Row{createStatsDeltaRow(1, 1, 1000)}, nil, nil) if tt.execError == nil { exec.EXPECT().ExecRestrictedSQL( - ctx, - useCurrentSession, + util.StatsCtx, + util.UseCurrentSessionOpt, updateDeltaSQL, gomock.Eq([]interface{}{uint64(1000), int64(1), int64(1), int64(1), int64(1)}), ).Return(nil, nil, nil) exec.EXPECT().ExecRestrictedSQL( - ctx, - useCurrentSession, + util.StatsCtx, + util.UseCurrentSessionOpt, DeleteLockSQL, gomock.Eq([]interface{}{tt.tableID}), ).Return(nil, nil, nil) } else { exec.EXPECT().ExecRestrictedSQL( - ctx, - useCurrentSession, + util.StatsCtx, + util.UseCurrentSessionOpt, updateDeltaSQL, gomock.Eq([]interface{}{uint64(1000), int64(1), int64(1), int64(1), int64(1)}), ).Return(nil, nil, tt.execError) } - err := updateStatsAndUnlockTable(ctx, exec, tt.tableID) + err := updateStatsAndUnlockTable(wrapAsSCtx(exec), tt.tableID) if tt.execError != nil { require.Equal(t, tt.execError.Error(), err.Error()) } else { @@ -175,13 +181,6 @@ func TestRemoveLockedTables(t *testing.T) { defer ctrl.Finish() exec := mock.NewMockRestrictedSQLExecutor(ctrl) - // Executed SQL should be: - exec.EXPECT().ExecRestrictedSQL( - gomock.All(&ctxMatcher{}), - useCurrentSession, - gomock.Eq("BEGIN PESSIMISTIC"), - ) - // Return table 1 and partition p1 are locked. table := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}, 1) table.AppendInt64(0, int64(1)) @@ -190,67 +189,61 @@ func TestRemoveLockedTables(t *testing.T) { rows := []chunk.Row{table.GetRow(0), partition.GetRow(0)} exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, selectSQL, ).Return(rows, nil, nil) // No rows returned for table 1, because the delta is only stored in partition p1. exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, selectDeltaSQL, gomock.Eq([]interface{}{int64(1)}), ).Return([]chunk.Row{}, nil, nil) exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, updateDeltaSQL, gomock.Eq([]interface{}{uint64(0), int64(0), int64(0), int64(0), int64(1)}), ).Return(nil, nil, nil) exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, DeleteLockSQL, gomock.Eq([]interface{}{int64(1)}), ).Return(nil, nil, nil) exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, selectDeltaSQL, gomock.Eq([]interface{}{int64(4)}), ).Return([]chunk.Row{createStatsDeltaRow(1, 1, 1000)}, nil, nil) exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, updateDeltaSQL, gomock.Eq([]interface{}{uint64(1000), int64(1), int64(1), int64(1), int64(4)}), ).Return(nil, nil, nil) // Patch the delta to table 1 from partition p1. exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, updateDeltaSQL, gomock.Eq([]interface{}{uint64(1000), int64(1), int64(1), int64(1), int64(1)}), ).Return(nil, nil, nil) exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, DeleteLockSQL, gomock.Eq([]interface{}{int64(4)}), ).Return(nil, nil, nil) - exec.EXPECT().ExecRestrictedSQL( - gomock.All(&ctxMatcher{}), - useCurrentSession, - "COMMIT", - ) - - tables := map[int64]*TableInfo{ + tables := map[int64]*util.StatsLockTable{ 1: { FullName: "test.t1", PartitionInfo: map[int64]string{ @@ -266,7 +259,7 @@ func TestRemoveLockedTables(t *testing.T) { } msg, err := RemoveLockedTables( - exec, + wrapAsSCtx(exec), tables, ) require.NoError(t, err) @@ -278,63 +271,50 @@ func TestRemoveLockedPartitions(t *testing.T) { defer ctrl.Finish() exec := mock.NewMockRestrictedSQLExecutor(ctrl) - // Executed SQL should be: - exec.EXPECT().ExecRestrictedSQL( - gomock.All(&ctxMatcher{}), - useCurrentSession, - gomock.Eq("BEGIN PESSIMISTIC"), - ) - // Return table 2 is locked. c := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}, 1) c.AppendInt64(0, int64(2)) rows := []chunk.Row{c.GetRow(0)} exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, selectSQL, ).Return(rows, nil, nil) exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, selectDeltaSQL, gomock.Eq([]interface{}{int64(2)}), ).Return([]chunk.Row{createStatsDeltaRow(1, 1, 1000)}, nil, nil) exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, updateDeltaSQL, gomock.Eq([]interface{}{uint64(1000), int64(1), int64(1), int64(1), int64(2)}), ).Return(nil, nil, nil) exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, updateDeltaSQL, gomock.Eq([]interface{}{uint64(1000), int64(1), int64(1), int64(1), int64(1)}), ).Return(nil, nil, nil) exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, DeleteLockSQL, gomock.Eq([]interface{}{int64(2)}), ).Return(nil, nil, nil) - exec.EXPECT().ExecRestrictedSQL( - gomock.All(&ctxMatcher{}), - useCurrentSession, - "COMMIT", - ) - pidAndNames := map[int64]string{ 2: "p1", } msg, err := RemoveLockedPartitions( - exec, + wrapAsSCtx(exec), 1, "test.t1", pidAndNames, @@ -348,35 +328,22 @@ func TestRemoveLockedPartitionsFailedIfTheWholeTableIsLocked(t *testing.T) { defer ctrl.Finish() exec := mock.NewMockRestrictedSQLExecutor(ctrl) - // Executed SQL should be: - exec.EXPECT().ExecRestrictedSQL( - gomock.All(&ctxMatcher{}), - useCurrentSession, - gomock.Eq("BEGIN PESSIMISTIC"), - ) - // Return table 2 is locked. c := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}, 1) c.AppendInt64(0, int64(1)) rows := []chunk.Row{c.GetRow(0)} exec.EXPECT().ExecRestrictedSQL( gomock.All(&ctxMatcher{}), - useCurrentSession, + util.UseCurrentSessionOpt, selectSQL, ).Return(rows, nil, nil) - exec.EXPECT().ExecRestrictedSQL( - gomock.All(&ctxMatcher{}), - useCurrentSession, - "COMMIT", - ) - pidAndNames := map[int64]string{ 2: "p1", } msg, err := RemoveLockedPartitions( - exec, + wrapAsSCtx(exec), 1, "test.t1", pidAndNames, diff --git a/statistics/handle/storage/save.go b/statistics/handle/storage/save.go index 5045a06ca2..9ae757f62d 100644 --- a/statistics/handle/storage/save.go +++ b/statistics/handle/storage/save.go @@ -15,7 +15,6 @@ package storage import ( - "context" "encoding/json" "fmt" "strings" @@ -125,7 +124,7 @@ func SaveTableStatsToStorage(sctx sessionctx.Context, results *statistics.AnalyzeResults, analyzeSnapshot bool) (statsVer uint64, err error) { needDumpFMS := results.TableID.IsPartitionTable() tableID := results.TableID.GetStatisticsID() - ctx := util.StatsCtx(context.Background()) + ctx := util.StatsCtx _, err = util.Exec(sctx, "begin pessimistic") if err != nil { return 0, err diff --git a/statistics/handle/util/BUILD.bazel b/statistics/handle/util/BUILD.bazel index 3c1fd5fdf5..f2e338ea68 100644 --- a/statistics/handle/util/BUILD.bazel +++ b/statistics/handle/util/BUILD.bazel @@ -21,7 +21,9 @@ go_library( "//table", "//types", "//util/chunk", + "//util/intest", "//util/sqlexec", + "//util/sqlexec/mock", "@com_github_ngaut_pools//:pools", "@com_github_pingcap_errors//:errors", "@com_github_tikv_client_go_v2//oracle", diff --git a/statistics/handle/util/interfaces.go b/statistics/handle/util/interfaces.go index 274223d49c..11ab78a601 100644 --- a/statistics/handle/util/interfaces.go +++ b/statistics/handle/util/interfaces.go @@ -109,7 +109,7 @@ type StatsAnalyze interface { // TODO: HandleAutoAnalyze } -// StatsCache is used to manage all table statistics in memory +// StatsCache is used to manage all table statistics in memory. type StatsCache interface { // Close closes this cache. Close() @@ -145,3 +145,54 @@ type StatsCache interface { // Replace replaces this cache. Replace(cache StatsCache) } + +// StatsLockTable is the table info of which will be locked. +type StatsLockTable struct { + PartitionInfo map[int64]string + // schema name + table name. + FullName string +} + +// StatsLock is used to manage locked stats. +type StatsLock interface { + // LockTables add locked tables id to store. + // - tables: tables that will be locked. + // Return the message of skipped tables and error. + LockTables(tables map[int64]*StatsLockTable) (skipped string, err error) + + // LockPartitions add locked partitions id to store. + // If the whole table is locked, then skip all partitions of the table. + // - tid: table id of which will be locked. + // - tableName: table name of which will be locked. + // - pidNames: partition ids of which will be locked. + // Return the message of skipped tables and error. + // Note: If the whole table is locked, then skip all partitions of the table. + LockPartitions( + tid int64, + tableName string, + pidNames map[int64]string, + ) (skipped string, err error) + + // RemoveLockedTables remove tables from table locked records. + // - tables: tables of which will be unlocked. + // Return the message of skipped tables and error. + RemoveLockedTables(tables map[int64]*StatsLockTable) (skipped string, err error) + + // RemoveLockedPartitions remove partitions from table locked records. + // - tid: table id of which will be unlocked. + // - tableName: table name of which will be unlocked. + // - pidNames: partition ids of which will be unlocked. + // Note: If the whole table is locked, then skip all partitions of the table. + RemoveLockedPartitions( + tid int64, + tableName string, + pidNames map[int64]string, + ) (skipped string, err error) + + // GetLockedTables returns the locked status of the given tables. + // Note: This function query locked tables from store, so please try to batch the query. + GetLockedTables(tableIDs ...int64) (map[int64]struct{}, error) + + // GetTableLockedAndClearForTest for unit test only. + GetTableLockedAndClearForTest() (map[int64]struct{}, error) +} diff --git a/statistics/handle/util/util.go b/statistics/handle/util/util.go index dd58144688..0fa0a7d953 100644 --- a/statistics/handle/util/util.go +++ b/statistics/handle/util/util.go @@ -27,27 +27,32 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/intest" "github.com/pingcap/tidb/util/sqlexec" + "github.com/pingcap/tidb/util/sqlexec/mock" "github.com/tikv/client-go/v2/oracle" ) +var ( + // UseCurrentSessionOpt to make sure the sql is executed in current session. + UseCurrentSessionOpt = []sqlexec.OptionFuncAlias{sqlexec.ExecOptionUseCurSession} + + // StatsCtx is used to mark the request is from stats module. + StatsCtx = kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) +) + // SessionPool is used to recycle sessionctx. type SessionPool interface { Get() (pools.Resource, error) Put(pools.Resource) } -// StatsCtx is used to mark the request is from stats module. -func StatsCtx(ctx context.Context) context.Context { - return kv.WithInternalSourceType(ctx, kv.InternalTxnStats) -} - // FinishTransaction will execute `commit` when error is nil, otherwise `rollback`. func FinishTransaction(sctx sessionctx.Context, err error) error { if err == nil { - _, err = Exec(sctx, "commit") + _, _, err = ExecRows(sctx, "commit") } else { - _, err1 := Exec(sctx, "rollback") + _, _, err1 := ExecRows(sctx, "rollback") terror.Log(errors.Trace(err1)) } return errors.Trace(err) @@ -150,7 +155,7 @@ func UpdateSCtxVarsForStats(sctx sessionctx.Context) error { // WrapTxn uses a transaction here can let different SQLs in this operation have the same data visibility. func WrapTxn(sctx sessionctx.Context, f func(sctx sessionctx.Context) error) (err error) { // TODO: check whether this sctx is already in a txn - if _, err := Exec(sctx, "begin"); err != nil { + if _, _, err := ExecRows(sctx, "begin"); err != nil { return err } defer func() { @@ -176,16 +181,23 @@ func Exec(sctx sessionctx.Context, sql string, args ...interface{}) (sqlexec.Rec return nil, errors.Errorf("invalid sql executor") } // TODO: use RestrictedSQLExecutor + ExecOptionUseCurSession instead of SQLExecutor - return sqlExec.ExecuteInternal(StatsCtx(context.Background()), sql, args...) + return sqlExec.ExecuteInternal(StatsCtx, sql, args...) } // ExecRows is a helper function to execute sql and return rows and fields. func ExecRows(sctx sessionctx.Context, sql string, args ...interface{}) (rows []chunk.Row, fields []*ast.ResultField, err error) { + if intest.InTest { + if v := sctx.Value(mock.MockRestrictedSQLExecutorKey{}); v != nil { + return v.(*mock.MockRestrictedSQLExecutor).ExecRestrictedSQL(StatsCtx, + UseCurrentSessionOpt, sql, args...) + } + } + sqlExec, ok := sctx.(sqlexec.RestrictedSQLExecutor) if !ok { return nil, nil, errors.Errorf("invalid sql executor") } - return sqlExec.ExecRestrictedSQL(StatsCtx(context.Background()), []sqlexec.OptionFuncAlias{sqlexec.ExecOptionUseCurSession}, sql, args...) + return sqlExec.ExecRestrictedSQL(StatsCtx, UseCurrentSessionOpt, sql, args...) } // ExecWithOpts is a helper function to execute sql and return rows and fields. @@ -194,7 +206,7 @@ func ExecWithOpts(sctx sessionctx.Context, opts []sqlexec.OptionFuncAlias, sql s if !ok { return nil, nil, errors.Errorf("invalid sql executor") } - return sqlExec.ExecRestrictedSQL(StatsCtx(context.Background()), opts, sql, args...) + return sqlExec.ExecRestrictedSQL(StatsCtx, opts, sql, args...) } // DurationToTS converts duration to timestamp. diff --git a/util/sqlexec/mock/restricted_sql_executor_mock.go b/util/sqlexec/mock/restricted_sql_executor_mock.go index 1321bf09fd..1ebf1079db 100644 --- a/util/sqlexec/mock/restricted_sql_executor_mock.go +++ b/util/sqlexec/mock/restricted_sql_executor_mock.go @@ -14,6 +14,14 @@ import ( gomock "go.uber.org/mock/gomock" ) +// MockRestrictedSQLExecutorKey is the key to represent MockRestrictedSQLExecutorMockRecorder in ctx. +type MockRestrictedSQLExecutorKey struct{} + +// String implements the string.Stringer interface. +func (k MockRestrictedSQLExecutorKey) String() string { + return "__MockRestrictedSQLExecutor" +} + // MockRestrictedSQLExecutor is a mock of RestrictedSQLExecutor interface. type MockRestrictedSQLExecutor struct { ctrl *gomock.Controller