From e44277d8fec357ee0b2a2ebdb16fa3a9ffaf5620 Mon Sep 17 00:00:00 2001 From: djshow832 Date: Thu, 23 Jun 2022 12:46:36 +0800 Subject: [PATCH] sessionctx: support encoding and decoding session contexts (#35648) close pingcap/tidb#35573 --- session/session.go | 25 +- sessionctx/sessionstates/session_states.go | 33 ++- .../sessionstates/session_states_test.go | 231 +++++++++++++++++- sessionctx/variable/sequence_state.go | 20 ++ sessionctx/variable/session.go | 54 ++-- 5 files changed, 332 insertions(+), 31 deletions(-) diff --git a/session/session.go b/session/session.go index 8d03ba0870..c5c1ead4c6 100644 --- a/session/session.go +++ b/session/session.go @@ -2069,17 +2069,20 @@ func runStmt(ctx context.Context, se *session, s sqlexec.Statement) (rs sqlexec. sessVars := se.sessionVars // Record diagnostic information for DML statements - if _, ok := s.(*executor.ExecStmt).StmtNode.(ast.DMLNode); ok { - defer func() { - sessVars.LastQueryInfo = variable.QueryInfo{ - TxnScope: sessVars.CheckAndGetTxnScope(), - StartTS: sessVars.TxnCtx.StartTS, - ForUpdateTS: sessVars.TxnCtx.GetForUpdateTS(), - } - if err != nil { - sessVars.LastQueryInfo.ErrMsg = err.Error() - } - }() + if stmt, ok := s.(*executor.ExecStmt).StmtNode.(ast.DMLNode); ok { + // Keep the previous queryInfo for `show session_states` because the statement needs to encode it. + if showStmt, ok := stmt.(*ast.ShowStmt); !ok || showStmt.Tp != ast.ShowSessionStates { + defer func() { + sessVars.LastQueryInfo = sessionstates.QueryInfo{ + TxnScope: sessVars.CheckAndGetTxnScope(), + StartTS: sessVars.TxnCtx.StartTS, + ForUpdateTS: sessVars.TxnCtx.GetForUpdateTS(), + } + if err != nil { + sessVars.LastQueryInfo.ErrMsg = err.Error() + } + }() + } } // Save origTxnCtx here to avoid it reset in the transaction retry. diff --git a/sessionctx/sessionstates/session_states.go b/sessionctx/sessionstates/session_states.go index 312cf891ec..baf876ff87 100644 --- a/sessionctx/sessionstates/session_states.go +++ b/sessionctx/sessionstates/session_states.go @@ -15,14 +15,41 @@ package sessionstates import ( + "time" + ptypes "github.com/pingcap/tidb/parser/types" "github.com/pingcap/tidb/types" ) +// QueryInfo represents the information of last executed query. It's used to expose information for test purpose. +type QueryInfo struct { + TxnScope string `json:"txn_scope"` + StartTS uint64 `json:"start_ts"` + ForUpdateTS uint64 `json:"for_update_ts"` + ErrMsg string `json:"error,omitempty"` +} + +// LastDDLInfo represents the information of last DDL. It's used to expose information for test purpose. +type LastDDLInfo struct { + Query string `json:"query"` + SeqNum uint64 `json:"seq_num"` +} + // SessionStates contains all the states in the session that should be migrated when the session // is migrated to another server. It is shown by `show session_states` and recovered by `set session_states`. type SessionStates struct { - UserVars map[string]*types.Datum `json:"user-var-values,omitempty"` - UserVarTypes map[string]*ptypes.FieldType `json:"user-var-types,omitempty"` - SystemVars map[string]string `json:"sys-vars,omitempty"` + UserVars map[string]*types.Datum `json:"user-var-values,omitempty"` + UserVarTypes map[string]*ptypes.FieldType `json:"user-var-types,omitempty"` + SystemVars map[string]string `json:"sys-vars,omitempty"` + PreparedStmtID uint32 `json:"prepared-stmt-id,omitempty"` + Status uint16 `json:"status,omitempty"` + CurrentDB string `json:"current-db,omitempty"` + LastTxnInfo string `json:"txn-info,omitempty"` + LastQueryInfo *QueryInfo `json:"query-info,omitempty"` + LastDDLInfo *LastDDLInfo `json:"ddl-info,omitempty"` + LastFoundRows uint64 `json:"found-rows,omitempty"` + FoundInPlanCache bool `json:"in-plan-cache,omitempty"` + FoundInBinding bool `json:"in-binding,omitempty"` + SequenceLatestValues map[int64]int64 `json:"seq-values,omitempty"` + MPPStoreLastFailTime map[string]time.Time `json:"store-fail-time,omitempty"` } diff --git a/sessionctx/sessionstates/session_states_test.go b/sessionctx/sessionstates/session_states_test.go index 81e4cb6d52..847f50f4e9 100644 --- a/sessionctx/sessionstates/session_states_test.go +++ b/sessionctx/sessionstates/session_states_test.go @@ -19,8 +19,10 @@ import ( "strconv" "strings" "testing" + "time" "github.com/pingcap/tidb/errno" + "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util/sem" @@ -167,7 +169,10 @@ func TestSystemVars(t *testing.T) { }, } - sem.Enable() + if !sem.IsEnabled() { + sem.Enable() + defer sem.Disable() + } for _, tt := range tests { tk1 := testkit.NewTestKit(t, store) for _, stmt := range tt.stmts { @@ -206,6 +211,230 @@ func TestSystemVars(t *testing.T) { } } +func TestSessionCtx(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) + tk.MustExec("create table test.t1(id int)") + + tests := []struct { + setFunc func(tk *testkit.TestKit) any + checkFunc func(tk *testkit.TestKit, param any) + }{ + { + // check PreparedStmtID + checkFunc: func(tk *testkit.TestKit, param any) { + require.Equal(t, uint32(1), tk.Session().GetSessionVars().GetNextPreparedStmtID()) + }, + }, + { + // check PreparedStmtID + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("prepare stmt from 'select ?'") + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + require.Equal(t, uint32(2), tk.Session().GetSessionVars().GetNextPreparedStmtID()) + }, + }, + { + // check Status + checkFunc: func(tk *testkit.TestKit, param any) { + require.Equal(t, mysql.ServerStatusAutocommit, tk.Session().GetSessionVars().Status&mysql.ServerStatusAutocommit) + }, + }, + { + // check Status + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("set autocommit=0") + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + require.Equal(t, uint16(0), tk.Session().GetSessionVars().Status&mysql.ServerStatusAutocommit) + }, + }, + { + // check CurrentDB + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select database()").Check(testkit.Rows("")) + }, + }, + { + // check CurrentDB + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("use test") + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select database()").Check(testkit.Rows("test")) + }, + }, + { + // check LastTxnInfo + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@tidb_last_txn_info").Check(testkit.Rows("")) + }, + }, + { + // check LastTxnInfo + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("begin") + tk.MustExec("insert test.t1 value(1)") + tk.MustExec("commit") + rows := tk.MustQuery("select @@tidb_last_txn_info").Rows() + require.NotEqual(t, "", rows[0][0].(string)) + return rows + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@tidb_last_txn_info").Check(param.([][]interface{})) + }, + }, + { + // check LastQueryInfo + setFunc: func(tk *testkit.TestKit) any { + rows := tk.MustQuery("select @@tidb_last_query_info").Rows() + require.NotEqual(t, "", rows[0][0].(string)) + return rows + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@tidb_last_query_info").Check(param.([][]interface{})) + }, + }, + { + // check LastQueryInfo + setFunc: func(tk *testkit.TestKit) any { + tk.MustQuery("select * from test.t1") + startTS := tk.Session().GetSessionVars().LastQueryInfo.StartTS + require.NotEqual(t, uint64(0), startTS) + return startTS + }, + checkFunc: func(tk *testkit.TestKit, param any) { + startTS := tk.Session().GetSessionVars().LastQueryInfo.StartTS + require.Equal(t, param.(uint64), startTS) + }, + }, + { + // check LastDDLInfo + setFunc: func(tk *testkit.TestKit) any { + rows := tk.MustQuery("select @@tidb_last_ddl_info").Rows() + require.NotEqual(t, "", rows[0][0].(string)) + return rows + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@tidb_last_ddl_info").Check(param.([][]interface{})) + }, + }, + { + // check LastDDLInfo + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("truncate table test.t1") + rows := tk.MustQuery("select @@tidb_last_ddl_info").Rows() + require.NotEqual(t, "", rows[0][0].(string)) + return rows + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@tidb_last_ddl_info").Check(param.([][]interface{})) + }, + }, + { + // check LastFoundRows + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("insert test.t1 value(1), (2), (3), (4), (5)") + // SQL_CALC_FOUND_ROWS is not supported now, so we just test normal select. + rows := tk.MustQuery("select * from test.t1 limit 3").Rows() + require.Equal(t, 3, len(rows)) + return "3" + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select found_rows()").Check(testkit.Rows(param.(string))) + }, + }, + { + // check SequenceState + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("create sequence test.s") + tk.MustQuery("select nextval(test.s)").Check(testkit.Rows("1")) + tk.MustQuery("select lastval(test.s)").Check(testkit.Rows("1")) + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select lastval(test.s)").Check(testkit.Rows("1")) + tk.MustQuery("select nextval(test.s)").Check(testkit.Rows("2")) + }, + }, + { + // check MPPStoreLastFailTime + setFunc: func(tk *testkit.TestKit) any { + tk.Session().GetSessionVars().MPPStoreLastFailTime = map[string]time.Time{"store1": time.Now()} + return tk.Session().GetSessionVars().MPPStoreLastFailTime + }, + checkFunc: func(tk *testkit.TestKit, param any) { + failTime := tk.Session().GetSessionVars().MPPStoreLastFailTime + require.Equal(t, 1, len(failTime)) + tm, ok := failTime["store1"] + require.True(t, ok) + require.True(t, param.(map[string]time.Time)["store1"].Equal(tm)) + }, + }, + { + // check FoundInPlanCache + setFunc: func(tk *testkit.TestKit) any { + require.False(t, tk.Session().GetSessionVars().FoundInPlanCache) + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) + }, + }, + { + // check FoundInPlanCache + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("prepare stmt from 'select * from test.t1'") + tk.MustQuery("execute stmt") + tk.MustQuery("execute stmt") + require.True(t, tk.Session().GetSessionVars().FoundInPlanCache) + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + }, + }, + { + // check FoundInBinding + setFunc: func(tk *testkit.TestKit) any { + require.False(t, tk.Session().GetSessionVars().FoundInBinding) + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@last_plan_from_binding").Check(testkit.Rows("0")) + }, + }, + { + // check FoundInBinding + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("create session binding for select * from test.t1 using select * from test.t1") + tk.MustQuery("select * from test.t1") + require.True(t, tk.Session().GetSessionVars().FoundInBinding) + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@last_plan_from_binding").Check(testkit.Rows("1")) + }, + }, + } + + for _, tt := range tests { + tk1 := testkit.NewTestKit(t, store) + var param any + if tt.setFunc != nil { + param = tt.setFunc(tk1) + } + tk2 := testkit.NewTestKit(t, store) + showSessionStatesAndSet(t, tk1, tk2) + tt.checkFunc(tk2, param) + } +} + func showSessionStatesAndSet(t *testing.T, tk1, tk2 *testkit.TestKit) { rows := tk1.MustQuery("show session_states").Rows() require.Len(t, rows, 1) diff --git a/sessionctx/variable/sequence_state.go b/sessionctx/variable/sequence_state.go index bb8b468da2..38199b084f 100644 --- a/sessionctx/variable/sequence_state.go +++ b/sessionctx/variable/sequence_state.go @@ -50,3 +50,23 @@ func (ss *SequenceState) GetLastValue(sequenceID int64) (int64, bool, error) { } return 0, true, nil } + +// GetAllStates returns a copied latestValueMap. +func (ss *SequenceState) GetAllStates() map[int64]int64 { + ss.mu.Lock() + defer ss.mu.Unlock() + latestValueMap := make(map[int64]int64, len(ss.latestValueMap)) + for seqID, latestValue := range ss.latestValueMap { + latestValueMap[seqID] = latestValue + } + return latestValueMap +} + +// SetAllStates sets latestValueMap as a whole. +func (ss *SequenceState) SetAllStates(latestValueMap map[int64]int64) { + ss.mu.Lock() + defer ss.mu.Unlock() + for seqID, latestValue := range latestValueMap { + ss.latestValueMap[seqID] = latestValue + } +} diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 735105c57d..fe4f469e76 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -1041,10 +1041,10 @@ type SessionVars struct { LastTxnInfo string // LastQueryInfo keeps track the info of last query. - LastQueryInfo QueryInfo + LastQueryInfo sessionstates.QueryInfo // LastDDLInfo keeps track the info of last DDL. - LastDDLInfo LastDDLInfo + LastDDLInfo sessionstates.LastDDLInfo // PartitionPruneMode indicates how and when to prune partitions. PartitionPruneMode atomic2.String @@ -1850,6 +1850,23 @@ func (s *SessionVars) EncodeSessionStates(ctx context.Context, sessionStates *se sessionStates.UserVarTypes[name] = userVarType.Clone() } s.UsersLock.RUnlock() + + // Encode other session contexts. + sessionStates.PreparedStmtID = s.preparedStmtID + sessionStates.Status = s.Status + sessionStates.CurrentDB = s.CurrentDB + sessionStates.LastTxnInfo = s.LastTxnInfo + if s.LastQueryInfo.StartTS != 0 { + sessionStates.LastQueryInfo = &s.LastQueryInfo + } + if s.LastDDLInfo.SeqNum != 0 { + sessionStates.LastDDLInfo = &s.LastDDLInfo + } + sessionStates.LastFoundRows = s.LastFoundRows + sessionStates.SequenceLatestValues = s.SequenceState.GetAllStates() + sessionStates.MPPStoreLastFailTime = s.MPPStoreLastFailTime + sessionStates.FoundInPlanCache = s.PrevFoundInPlanCache + sessionStates.FoundInBinding = s.PrevFoundInBinding return } @@ -1866,6 +1883,25 @@ func (s *SessionVars) DecodeSessionStates(ctx context.Context, sessionStates *se s.UserVarTypes[name] = userVarType.Clone() } s.UsersLock.Unlock() + + // Decode other session contexts. + s.preparedStmtID = sessionStates.PreparedStmtID + s.Status = sessionStates.Status + s.CurrentDB = sessionStates.CurrentDB + s.LastTxnInfo = sessionStates.LastTxnInfo + if sessionStates.LastQueryInfo != nil { + s.LastQueryInfo = *sessionStates.LastQueryInfo + } + if sessionStates.LastDDLInfo != nil { + s.LastDDLInfo = *sessionStates.LastDDLInfo + } + s.LastFoundRows = sessionStates.LastFoundRows + s.SequenceState.SetAllStates(sessionStates.SequenceLatestValues) + if sessionStates.MPPStoreLastFailTime != nil { + s.MPPStoreLastFailTime = sessionStates.MPPStoreLastFailTime + } + s.FoundInPlanCache = sessionStates.FoundInPlanCache + s.FoundInBinding = sessionStates.FoundInBinding return } @@ -2458,20 +2494,6 @@ func writeSlowLogItem(buf *bytes.Buffer, key, value string) { buf.WriteString(SlowLogRowPrefixStr + key + SlowLogSpaceMarkStr + value + "\n") } -// QueryInfo represents the information of last executed query. It's used to expose information for test purpose. -type QueryInfo struct { - TxnScope string `json:"txn_scope"` - StartTS uint64 `json:"start_ts"` - ForUpdateTS uint64 `json:"for_update_ts"` - ErrMsg string `json:"error,omitempty"` -} - -// LastDDLInfo represents the information of last DDL. It's used to expose information for test purpose. -type LastDDLInfo struct { - Query string `json:"query"` - SeqNum uint64 `json:"seq_num"` -} - // TxnReadTS indicates the value and used situation for tx_read_ts type TxnReadTS struct { readTS uint64