diff --git a/session/session.go b/session/session.go index 47ea068659..dfb6853cfb 100644 --- a/session/session.go +++ b/session/session.go @@ -813,120 +813,6 @@ func (s *session) sysSessionPool() sessionPool { return domain.GetDomain(s).SysSessionPool() } -// ExecRestrictedSQL implements RestrictedSQLExecutor interface. -// This is used for executing some restricted sql statements, usually executed during a normal statement execution. -// Unlike normal Exec, it doesn't reset statement status, doesn't commit or rollback the current transaction -// and doesn't write binlog. -func (s *session) ExecRestrictedSQL(sql string) ([]chunk.Row, []*ast.ResultField, error) { - return s.ExecRestrictedSQLWithContext(context.TODO(), sql) -} - -// ExecRestrictedSQLWithContext implements RestrictedSQLExecutor interface. -func (s *session) ExecRestrictedSQLWithContext(ctx context.Context, sql string, opts ...sqlexec.OptionFuncAlias) ( - []chunk.Row, []*ast.ResultField, error) { - var execOption sqlexec.ExecOption - for _, opt := range opts { - opt(&execOption) - } - // Use special session to execute the sql. - tmp, err := s.sysSessionPool().Get() - if err != nil { - return nil, nil, err - } - defer s.sysSessionPool().Put(tmp) - se := tmp.(*session) - // The special session will share the `InspectionTableCache` with current session - // if the current session in inspection mode. - if cache := s.sessionVars.InspectionTableCache; cache != nil { - se.sessionVars.InspectionTableCache = cache - defer func() { se.sessionVars.InspectionTableCache = nil }() - } - if ok := s.sessionVars.OptimizerUseInvisibleIndexes; ok { - se.sessionVars.OptimizerUseInvisibleIndexes = true - defer func() { se.sessionVars.OptimizerUseInvisibleIndexes = false }() - } - prePruneMode := se.sessionVars.PartitionPruneMode.Load() - defer func() { - if !execOption.IgnoreWarning { - if se != nil && se.GetSessionVars().StmtCtx.WarningCount() > 0 { - warnings := se.GetSessionVars().StmtCtx.GetWarnings() - s.GetSessionVars().StmtCtx.AppendWarnings(warnings) - } - } - se.sessionVars.PartitionPruneMode.Store(prePruneMode) - }() - - if execOption.SnapshotTS != 0 { - se.sessionVars.SnapshotInfoschema, err = domain.GetDomain(s).GetSnapshotInfoSchema(execOption.SnapshotTS) - if err != nil { - return nil, nil, err - } - if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, strconv.FormatUint(execOption.SnapshotTS, 10)); err != nil { - return nil, nil, err - } - defer func() { - if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, ""); err != nil { - logutil.BgLogger().Error("set tidbSnapshot error", zap.Error(err)) - } - se.sessionVars.SnapshotInfoschema = nil - }() - } - - if execOption.AnalyzeVer != 0 { - oldStatsVer := se.GetSessionVars().AnalyzeVersion - se.GetSessionVars().AnalyzeVersion = execOption.AnalyzeVer - defer func() { - se.GetSessionVars().AnalyzeVersion = oldStatsVer - }() - } - - // for analyze stmt we need let worker session follow user session that executing stmt. - se.sessionVars.PartitionPruneMode.Store(s.sessionVars.PartitionPruneMode.Load()) - metrics.SessionRestrictedSQLCounter.Inc() - - return execRestrictedSQL(ctx, se, sql) -} - -// ExecRestrictedSQLWithSnapshot implements RestrictedSQLExecutor interface. -// This is used for executing some restricted sql statements with snapshot. -// If current session sets the snapshot timestamp, then execute with this snapshot timestamp. -// Otherwise, execute with the current transaction start timestamp if the transaction is valid. -func (s *session) ExecRestrictedSQLWithSnapshot(sql string) ([]chunk.Row, []*ast.ResultField, error) { - var snapshot uint64 - txn, err := s.Txn(false) - if err != nil { - return nil, nil, err - } - if txn.Valid() { - snapshot = s.txn.StartTS() - } - if s.sessionVars.SnapshotTS != 0 { - snapshot = s.sessionVars.SnapshotTS - } - return s.ExecRestrictedSQLWithContext(context.TODO(), sql, sqlexec.ExecOptionWithSnapshot(snapshot)) -} - -func execRestrictedSQL(ctx context.Context, se *session, sql string) ([]chunk.Row, []*ast.ResultField, error) { - ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) - startTime := time.Now() - rs, err := se.ExecuteInternal(ctx, sql) - if rs != nil { - defer terror.Call(rs.Close) - } - if err != nil || rs == nil { - return nil, nil, err - } - - // Execute all recordset, take out the first one as result. - rows, err := drainRecordSet(ctx, se, rs) - if err != nil { - return nil, nil, err - } - - metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal).Observe(time.Since(startTime).Seconds()) - return rows, rs.Fields(), err -} - func createSessionFunc(store kv.Storage) pools.Factory { return func() (pools.Resource, error) { se, err := createSession(store) diff --git a/session/session_test.go b/session/session_test.go index 9bc6bb3dfd..ee0a683044 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -870,13 +870,6 @@ func (s *testSessionSuite) TestDatabase(c *C) { tk.MustExec("drop schema if exists xxx") } -func (s *testSessionSuite) TestExecRestrictedSQL(c *C) { - tk := testkit.NewTestKitWithInit(c, s.store) - r, _, err := tk.Se.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL("select 1;") - c.Assert(err, IsNil) - c.Assert(len(r), Equals, 1) -} - // TestInTrans . See https://dev.mysql.com/doc/internals/en/status-flags.html func (s *testSessionSuite) TestInTrans(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) diff --git a/session/tidb_test.go b/session/tidb_test.go index eff7406332..80da89191d 100644 --- a/session/tidb_test.go +++ b/session/tidb_test.go @@ -23,6 +23,7 @@ import ( "time" . "github.com/pingcap/check" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/auth" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/domain" @@ -81,17 +82,23 @@ func (s *testMainSuite) TestSysSessionPoolGoroutineLeak(c *C) { se, err := createSession(store) c.Assert(err, IsNil) + count := 200 + stmts := make([]ast.StmtNode, count) + for i := 0; i < count; i++ { + stmt, err := se.ParseWithParams(context.Background(), "select * from mysql.user limit 1") + c.Assert(err, IsNil) + stmts[i] = stmt + } // Test an issue that sysSessionPool doesn't call session's Close, cause // asyncGetTSWorker goroutine leak. - count := 200 var wg sync.WaitGroup wg.Add(count) for i := 0; i < count; i++ { - go func(se *session) { - _, _, err := se.ExecRestrictedSQL("select * from mysql.user limit 1") + go func(se *session, stmt ast.StmtNode) { + _, _, err := se.ExecRestrictedStmt(context.Background(), stmt) c.Assert(err, IsNil) wg.Done() - }(se) + }(se, stmts[i]) } wg.Wait() } diff --git a/util/sqlexec/restricted_sql_executor.go b/util/sqlexec/restricted_sql_executor.go index 58cccce4ea..e350118a77 100644 --- a/util/sqlexec/restricted_sql_executor.go +++ b/util/sqlexec/restricted_sql_executor.go @@ -31,21 +31,8 @@ import ( // And in the same time, we do not want this interface becomes a general way to run sql statement. // We hope this could be used with some restrictions such as only allowing system tables as target, // do not allowing recursion call. -// For more information please refer to the comments in session.ExecRestrictedSQL(). // This is implemented in session.go. type RestrictedSQLExecutor interface { - // ExecRestrictedSQL run sql statement in ctx with some restriction. - ExecRestrictedSQL(sql string) ([]chunk.Row, []*ast.ResultField, error) - // ExecRestrictedSQLWithContext run sql statement in ctx with some restriction. - ExecRestrictedSQLWithContext(ctx context.Context, sql string, opts ...OptionFuncAlias) ([]chunk.Row, []*ast.ResultField, error) - // ExecRestrictedSQLWithSnapshot run sql statement in ctx with some restriction and with snapshot. - // If current session sets the snapshot timestamp, then execute with this snapshot timestamp. - // Otherwise, execute with the current transaction start timestamp if the transaction is valid. - ExecRestrictedSQLWithSnapshot(sql string) ([]chunk.Row, []*ast.ResultField, error) - - // The above methods are all deprecated. - // After the refactor finish, they will be removed. - // ParseWithParams is the parameterized version of Parse: it will try to prevent injection under utf8mb4. // It works like printf() in c, there are following format specifiers: // 1. %?: automatic conversion by the type of arguments. E.g. []string -> ('s1','s2'..) @@ -60,32 +47,32 @@ type RestrictedSQLExecutor interface { ExecRestrictedStmt(ctx context.Context, stmt ast.StmtNode, opts ...OptionFuncAlias) ([]chunk.Row, []*ast.ResultField, error) } -// ExecOption is a struct defined for ExecRestrictedSQLWithContext option. +// ExecOption is a struct defined for ExecRestrictedStmt option. type ExecOption struct { IgnoreWarning bool SnapshotTS uint64 AnalyzeVer int } -// OptionFuncAlias is defined for the optional paramater of ExecRestrictedSQLWithContext. +// OptionFuncAlias is defined for the optional paramater of ExecRestrictedStmt. type OptionFuncAlias = func(option *ExecOption) -// ExecOptionIgnoreWarning tells ExecRestrictedSQLWithContext to ignore the warnings. +// ExecOptionIgnoreWarning tells ExecRestrictedStmt to ignore the warnings. var ExecOptionIgnoreWarning OptionFuncAlias = func(option *ExecOption) { option.IgnoreWarning = true } -// ExecOptionAnalyzeVer1 tells ExecRestrictedSQLWithContext to collect statistics with version1. +// ExecOptionAnalyzeVer1 tells ExecRestrictedStmt to collect statistics with version1. var ExecOptionAnalyzeVer1 OptionFuncAlias = func(option *ExecOption) { option.AnalyzeVer = 1 } -// ExecOptionAnalyzeVer2 tells ExecRestrictedSQLWithContext to collect statistics with version2. +// ExecOptionAnalyzeVer2 tells ExecRestrictedStmt to collect statistics with version2. var ExecOptionAnalyzeVer2 OptionFuncAlias = func(option *ExecOption) { option.AnalyzeVer = 2 } -// ExecOptionWithSnapshot tells ExecRestrictedSQLWithContext to use a snapshot. +// ExecOptionWithSnapshot tells ExecRestrictedStmt to use a snapshot. func ExecOptionWithSnapshot(snapshot uint64) OptionFuncAlias { return func(option *ExecOption) { option.SnapshotTS = snapshot