sessionctx: support encoding and decoding session contexts (#35648)

close pingcap/tidb#35573
This commit is contained in:
djshow832
2022-06-23 12:46:36 +08:00
committed by GitHub
parent df9b54bcb3
commit e44277d8fe
5 changed files with 332 additions and 31 deletions

View File

@ -2069,17 +2069,20 @@ func runStmt(ctx context.Context, se *session, s sqlexec.Statement) (rs sqlexec.
sessVars := se.sessionVars
// Record diagnostic information for DML statements
if _, ok := s.(*executor.ExecStmt).StmtNode.(ast.DMLNode); ok {
defer func() {
sessVars.LastQueryInfo = variable.QueryInfo{
TxnScope: sessVars.CheckAndGetTxnScope(),
StartTS: sessVars.TxnCtx.StartTS,
ForUpdateTS: sessVars.TxnCtx.GetForUpdateTS(),
}
if err != nil {
sessVars.LastQueryInfo.ErrMsg = err.Error()
}
}()
if stmt, ok := s.(*executor.ExecStmt).StmtNode.(ast.DMLNode); ok {
// Keep the previous queryInfo for `show session_states` because the statement needs to encode it.
if showStmt, ok := stmt.(*ast.ShowStmt); !ok || showStmt.Tp != ast.ShowSessionStates {
defer func() {
sessVars.LastQueryInfo = sessionstates.QueryInfo{
TxnScope: sessVars.CheckAndGetTxnScope(),
StartTS: sessVars.TxnCtx.StartTS,
ForUpdateTS: sessVars.TxnCtx.GetForUpdateTS(),
}
if err != nil {
sessVars.LastQueryInfo.ErrMsg = err.Error()
}
}()
}
}
// Save origTxnCtx here to avoid it reset in the transaction retry.

View File

@ -15,14 +15,41 @@
package sessionstates
import (
"time"
ptypes "github.com/pingcap/tidb/parser/types"
"github.com/pingcap/tidb/types"
)
// QueryInfo represents the information of last executed query. It's used to expose information for test purpose.
type QueryInfo struct {
TxnScope string `json:"txn_scope"`
StartTS uint64 `json:"start_ts"`
ForUpdateTS uint64 `json:"for_update_ts"`
ErrMsg string `json:"error,omitempty"`
}
// LastDDLInfo represents the information of last DDL. It's used to expose information for test purpose.
type LastDDLInfo struct {
Query string `json:"query"`
SeqNum uint64 `json:"seq_num"`
}
// SessionStates contains all the states in the session that should be migrated when the session
// is migrated to another server. It is shown by `show session_states` and recovered by `set session_states`.
type SessionStates struct {
UserVars map[string]*types.Datum `json:"user-var-values,omitempty"`
UserVarTypes map[string]*ptypes.FieldType `json:"user-var-types,omitempty"`
SystemVars map[string]string `json:"sys-vars,omitempty"`
UserVars map[string]*types.Datum `json:"user-var-values,omitempty"`
UserVarTypes map[string]*ptypes.FieldType `json:"user-var-types,omitempty"`
SystemVars map[string]string `json:"sys-vars,omitempty"`
PreparedStmtID uint32 `json:"prepared-stmt-id,omitempty"`
Status uint16 `json:"status,omitempty"`
CurrentDB string `json:"current-db,omitempty"`
LastTxnInfo string `json:"txn-info,omitempty"`
LastQueryInfo *QueryInfo `json:"query-info,omitempty"`
LastDDLInfo *LastDDLInfo `json:"ddl-info,omitempty"`
LastFoundRows uint64 `json:"found-rows,omitempty"`
FoundInPlanCache bool `json:"in-plan-cache,omitempty"`
FoundInBinding bool `json:"in-binding,omitempty"`
SequenceLatestValues map[int64]int64 `json:"seq-values,omitempty"`
MPPStoreLastFailTime map[string]time.Time `json:"store-fail-time,omitempty"`
}

View File

@ -19,8 +19,10 @@ import (
"strconv"
"strings"
"testing"
"time"
"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/util/sem"
@ -167,7 +169,10 @@ func TestSystemVars(t *testing.T) {
},
}
sem.Enable()
if !sem.IsEnabled() {
sem.Enable()
defer sem.Disable()
}
for _, tt := range tests {
tk1 := testkit.NewTestKit(t, store)
for _, stmt := range tt.stmts {
@ -206,6 +211,230 @@ func TestSystemVars(t *testing.T) {
}
}
func TestSessionCtx(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()
tk := testkit.NewTestKit(t, store)
tk.MustExec("create table test.t1(id int)")
tests := []struct {
setFunc func(tk *testkit.TestKit) any
checkFunc func(tk *testkit.TestKit, param any)
}{
{
// check PreparedStmtID
checkFunc: func(tk *testkit.TestKit, param any) {
require.Equal(t, uint32(1), tk.Session().GetSessionVars().GetNextPreparedStmtID())
},
},
{
// check PreparedStmtID
setFunc: func(tk *testkit.TestKit) any {
tk.MustExec("prepare stmt from 'select ?'")
return nil
},
checkFunc: func(tk *testkit.TestKit, param any) {
require.Equal(t, uint32(2), tk.Session().GetSessionVars().GetNextPreparedStmtID())
},
},
{
// check Status
checkFunc: func(tk *testkit.TestKit, param any) {
require.Equal(t, mysql.ServerStatusAutocommit, tk.Session().GetSessionVars().Status&mysql.ServerStatusAutocommit)
},
},
{
// check Status
setFunc: func(tk *testkit.TestKit) any {
tk.MustExec("set autocommit=0")
return nil
},
checkFunc: func(tk *testkit.TestKit, param any) {
require.Equal(t, uint16(0), tk.Session().GetSessionVars().Status&mysql.ServerStatusAutocommit)
},
},
{
// check CurrentDB
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select database()").Check(testkit.Rows("<nil>"))
},
},
{
// check CurrentDB
setFunc: func(tk *testkit.TestKit) any {
tk.MustExec("use test")
return nil
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select database()").Check(testkit.Rows("test"))
},
},
{
// check LastTxnInfo
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select @@tidb_last_txn_info").Check(testkit.Rows(""))
},
},
{
// check LastTxnInfo
setFunc: func(tk *testkit.TestKit) any {
tk.MustExec("begin")
tk.MustExec("insert test.t1 value(1)")
tk.MustExec("commit")
rows := tk.MustQuery("select @@tidb_last_txn_info").Rows()
require.NotEqual(t, "", rows[0][0].(string))
return rows
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select @@tidb_last_txn_info").Check(param.([][]interface{}))
},
},
{
// check LastQueryInfo
setFunc: func(tk *testkit.TestKit) any {
rows := tk.MustQuery("select @@tidb_last_query_info").Rows()
require.NotEqual(t, "", rows[0][0].(string))
return rows
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select @@tidb_last_query_info").Check(param.([][]interface{}))
},
},
{
// check LastQueryInfo
setFunc: func(tk *testkit.TestKit) any {
tk.MustQuery("select * from test.t1")
startTS := tk.Session().GetSessionVars().LastQueryInfo.StartTS
require.NotEqual(t, uint64(0), startTS)
return startTS
},
checkFunc: func(tk *testkit.TestKit, param any) {
startTS := tk.Session().GetSessionVars().LastQueryInfo.StartTS
require.Equal(t, param.(uint64), startTS)
},
},
{
// check LastDDLInfo
setFunc: func(tk *testkit.TestKit) any {
rows := tk.MustQuery("select @@tidb_last_ddl_info").Rows()
require.NotEqual(t, "", rows[0][0].(string))
return rows
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select @@tidb_last_ddl_info").Check(param.([][]interface{}))
},
},
{
// check LastDDLInfo
setFunc: func(tk *testkit.TestKit) any {
tk.MustExec("truncate table test.t1")
rows := tk.MustQuery("select @@tidb_last_ddl_info").Rows()
require.NotEqual(t, "", rows[0][0].(string))
return rows
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select @@tidb_last_ddl_info").Check(param.([][]interface{}))
},
},
{
// check LastFoundRows
setFunc: func(tk *testkit.TestKit) any {
tk.MustExec("insert test.t1 value(1), (2), (3), (4), (5)")
// SQL_CALC_FOUND_ROWS is not supported now, so we just test normal select.
rows := tk.MustQuery("select * from test.t1 limit 3").Rows()
require.Equal(t, 3, len(rows))
return "3"
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select found_rows()").Check(testkit.Rows(param.(string)))
},
},
{
// check SequenceState
setFunc: func(tk *testkit.TestKit) any {
tk.MustExec("create sequence test.s")
tk.MustQuery("select nextval(test.s)").Check(testkit.Rows("1"))
tk.MustQuery("select lastval(test.s)").Check(testkit.Rows("1"))
return nil
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select lastval(test.s)").Check(testkit.Rows("1"))
tk.MustQuery("select nextval(test.s)").Check(testkit.Rows("2"))
},
},
{
// check MPPStoreLastFailTime
setFunc: func(tk *testkit.TestKit) any {
tk.Session().GetSessionVars().MPPStoreLastFailTime = map[string]time.Time{"store1": time.Now()}
return tk.Session().GetSessionVars().MPPStoreLastFailTime
},
checkFunc: func(tk *testkit.TestKit, param any) {
failTime := tk.Session().GetSessionVars().MPPStoreLastFailTime
require.Equal(t, 1, len(failTime))
tm, ok := failTime["store1"]
require.True(t, ok)
require.True(t, param.(map[string]time.Time)["store1"].Equal(tm))
},
},
{
// check FoundInPlanCache
setFunc: func(tk *testkit.TestKit) any {
require.False(t, tk.Session().GetSessionVars().FoundInPlanCache)
return nil
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0"))
},
},
{
// check FoundInPlanCache
setFunc: func(tk *testkit.TestKit) any {
tk.MustExec("prepare stmt from 'select * from test.t1'")
tk.MustQuery("execute stmt")
tk.MustQuery("execute stmt")
require.True(t, tk.Session().GetSessionVars().FoundInPlanCache)
return nil
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1"))
},
},
{
// check FoundInBinding
setFunc: func(tk *testkit.TestKit) any {
require.False(t, tk.Session().GetSessionVars().FoundInBinding)
return nil
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select @@last_plan_from_binding").Check(testkit.Rows("0"))
},
},
{
// check FoundInBinding
setFunc: func(tk *testkit.TestKit) any {
tk.MustExec("create session binding for select * from test.t1 using select * from test.t1")
tk.MustQuery("select * from test.t1")
require.True(t, tk.Session().GetSessionVars().FoundInBinding)
return nil
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select @@last_plan_from_binding").Check(testkit.Rows("1"))
},
},
}
for _, tt := range tests {
tk1 := testkit.NewTestKit(t, store)
var param any
if tt.setFunc != nil {
param = tt.setFunc(tk1)
}
tk2 := testkit.NewTestKit(t, store)
showSessionStatesAndSet(t, tk1, tk2)
tt.checkFunc(tk2, param)
}
}
func showSessionStatesAndSet(t *testing.T, tk1, tk2 *testkit.TestKit) {
rows := tk1.MustQuery("show session_states").Rows()
require.Len(t, rows, 1)

View File

@ -50,3 +50,23 @@ func (ss *SequenceState) GetLastValue(sequenceID int64) (int64, bool, error) {
}
return 0, true, nil
}
// GetAllStates returns a copied latestValueMap.
func (ss *SequenceState) GetAllStates() map[int64]int64 {
ss.mu.Lock()
defer ss.mu.Unlock()
latestValueMap := make(map[int64]int64, len(ss.latestValueMap))
for seqID, latestValue := range ss.latestValueMap {
latestValueMap[seqID] = latestValue
}
return latestValueMap
}
// SetAllStates sets latestValueMap as a whole.
func (ss *SequenceState) SetAllStates(latestValueMap map[int64]int64) {
ss.mu.Lock()
defer ss.mu.Unlock()
for seqID, latestValue := range latestValueMap {
ss.latestValueMap[seqID] = latestValue
}
}

View File

@ -1041,10 +1041,10 @@ type SessionVars struct {
LastTxnInfo string
// LastQueryInfo keeps track the info of last query.
LastQueryInfo QueryInfo
LastQueryInfo sessionstates.QueryInfo
// LastDDLInfo keeps track the info of last DDL.
LastDDLInfo LastDDLInfo
LastDDLInfo sessionstates.LastDDLInfo
// PartitionPruneMode indicates how and when to prune partitions.
PartitionPruneMode atomic2.String
@ -1850,6 +1850,23 @@ func (s *SessionVars) EncodeSessionStates(ctx context.Context, sessionStates *se
sessionStates.UserVarTypes[name] = userVarType.Clone()
}
s.UsersLock.RUnlock()
// Encode other session contexts.
sessionStates.PreparedStmtID = s.preparedStmtID
sessionStates.Status = s.Status
sessionStates.CurrentDB = s.CurrentDB
sessionStates.LastTxnInfo = s.LastTxnInfo
if s.LastQueryInfo.StartTS != 0 {
sessionStates.LastQueryInfo = &s.LastQueryInfo
}
if s.LastDDLInfo.SeqNum != 0 {
sessionStates.LastDDLInfo = &s.LastDDLInfo
}
sessionStates.LastFoundRows = s.LastFoundRows
sessionStates.SequenceLatestValues = s.SequenceState.GetAllStates()
sessionStates.MPPStoreLastFailTime = s.MPPStoreLastFailTime
sessionStates.FoundInPlanCache = s.PrevFoundInPlanCache
sessionStates.FoundInBinding = s.PrevFoundInBinding
return
}
@ -1866,6 +1883,25 @@ func (s *SessionVars) DecodeSessionStates(ctx context.Context, sessionStates *se
s.UserVarTypes[name] = userVarType.Clone()
}
s.UsersLock.Unlock()
// Decode other session contexts.
s.preparedStmtID = sessionStates.PreparedStmtID
s.Status = sessionStates.Status
s.CurrentDB = sessionStates.CurrentDB
s.LastTxnInfo = sessionStates.LastTxnInfo
if sessionStates.LastQueryInfo != nil {
s.LastQueryInfo = *sessionStates.LastQueryInfo
}
if sessionStates.LastDDLInfo != nil {
s.LastDDLInfo = *sessionStates.LastDDLInfo
}
s.LastFoundRows = sessionStates.LastFoundRows
s.SequenceState.SetAllStates(sessionStates.SequenceLatestValues)
if sessionStates.MPPStoreLastFailTime != nil {
s.MPPStoreLastFailTime = sessionStates.MPPStoreLastFailTime
}
s.FoundInPlanCache = sessionStates.FoundInPlanCache
s.FoundInBinding = sessionStates.FoundInBinding
return
}
@ -2458,20 +2494,6 @@ func writeSlowLogItem(buf *bytes.Buffer, key, value string) {
buf.WriteString(SlowLogRowPrefixStr + key + SlowLogSpaceMarkStr + value + "\n")
}
// QueryInfo represents the information of last executed query. It's used to expose information for test purpose.
type QueryInfo struct {
TxnScope string `json:"txn_scope"`
StartTS uint64 `json:"start_ts"`
ForUpdateTS uint64 `json:"for_update_ts"`
ErrMsg string `json:"error,omitempty"`
}
// LastDDLInfo represents the information of last DDL. It's used to expose information for test purpose.
type LastDDLInfo struct {
Query string `json:"query"`
SeqNum uint64 `json:"seq_num"`
}
// TxnReadTS indicates the value and used situation for tx_read_ts
type TxnReadTS struct {
readTS uint64