Files
tidb/pkg/session/syssession/session_test.go
2025-05-07 22:47:30 +00:00

1357 lines
38 KiB
Go

// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package syssession
import (
"context"
"fmt"
"testing"
"time"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/planner/core/resolve"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/util"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/sqlexec"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// WithSuppressAssert suppress asserts in test
func WithSuppressAssert(fn func()) {
defer func() {
suppressAssertInTest = false
}()
suppressAssertInTest = true
fn()
}
type mockOwner struct {
sessionOwner
mock.Mock
}
func (m *mockOwner) onBecameOwner(sctx sessionctx.Context) error {
return m.Called(sctx).Error(0)
}
func (m *mockOwner) onResignOwner(sctx sessionctx.Context) error {
return m.Called(sctx).Error(0)
}
type mockTxn struct {
mock.Mock
kv.Transaction
valid bool
}
func (txn *mockTxn) Valid() bool {
return txn.valid
}
func (txn *mockTxn) String() string {
return txn.Transaction.String()
}
type mockPreparedFuture struct {
sessionctx.TxnFuture
}
type mockSessionContext struct {
sessionctx.Context
util.SessionManager
mock.Mock
}
func (m *mockSessionContext) Close() {
m.Called()
}
func (m *mockSessionContext) RollbackTxn(ctx context.Context) {
m.Called(ctx)
}
func (m *mockSessionContext) GetSQLExecutor() sqlexec.SQLExecutor {
return m
}
func (m *mockSessionContext) Execute(ctx context.Context, sql string) (rs []sqlexec.RecordSet, err error) {
args := m.Called(ctx, sql)
if arg := args.Get(0); arg != nil {
rs = arg.([]sqlexec.RecordSet)
}
err = args.Error(1)
return
}
func (m *mockSessionContext) ExecuteInternal(ctx context.Context, sql string, sqlArgs ...any) (rs sqlexec.RecordSet, err error) {
args := m.Called(ctx, sql, sqlArgs)
if arg := args.Get(0); arg != nil {
rs = arg.(sqlexec.RecordSet)
}
err = args.Error(1)
return
}
func (m *mockSessionContext) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (rs sqlexec.RecordSet, err error) {
args := m.Called(ctx, stmtNode)
if arg := args.Get(0); arg != nil {
rs = arg.(sqlexec.RecordSet)
}
err = args.Error(1)
return
}
func (m *mockSessionContext) ParseWithParams(ctx context.Context, sql string, sqlArgs ...any) (n ast.StmtNode, err error) {
args := m.Called(ctx, sql, sqlArgs)
if arg := args.Get(0); arg != nil {
n = arg.(ast.StmtNode)
}
err = args.Error(1)
return
}
func (m *mockSessionContext) ExecRestrictedStmt(ctx context.Context, stmt ast.StmtNode, opts ...sqlexec.OptionFuncAlias) (rows []chunk.Row, fields []*resolve.ResultField, err error) {
args := m.Called(ctx, stmt, opts)
if arg := args.Get(0); arg != nil {
rows = arg.([]chunk.Row)
}
if arg := args.Get(1); arg != nil {
fields = arg.([]*resolve.ResultField)
}
err = args.Error(2)
return
}
func (m *mockSessionContext) ExecRestrictedSQL(ctx context.Context, opts []sqlexec.OptionFuncAlias, sql string, sqlArgs ...any) (rows []chunk.Row, fields []*resolve.ResultField, err error) {
args := m.Called(ctx, opts, sql, sqlArgs)
if arg := args.Get(0); arg != nil {
rows = arg.([]chunk.Row)
}
if arg := args.Get(1); arg != nil {
fields = arg.([]*resolve.ResultField)
}
err = args.Error(2)
return
}
func (m *mockSessionContext) GetRestrictedSQLExecutor() sqlexec.RestrictedSQLExecutor {
return m
}
func (m *mockSessionContext) GetPreparedTxnFuture() sessionctx.TxnFuture {
if arg := m.Called().Get(0); arg != nil {
return arg.(sessionctx.TxnFuture)
}
return nil
}
func (m *mockSessionContext) Txn(active bool) (txn kv.Transaction, err error) {
args := m.Called(active)
err = args.Error(1)
if args.Get(0) != nil {
txn = args.Get(0).(kv.Transaction)
}
return
}
func (m *mockSessionContext) MockNoPendingTxn() {
m.On("Txn", false).Return(&mockTxn{valid: false}, nil).Once()
m.On("GetPreparedTxnFuture").Return(nil).Once()
}
func (m *mockSessionContext) MockResetState(ctx context.Context, panicStr string) {
m.On("RollbackTxn", ctx).Run(func(args mock.Arguments) {
if panicStr != "" {
panic(panicStr)
}
}).Once()
}
func TestNewInternalSession(t *testing.T) {
sctx := &mockSessionContext{}
owner := &mockOwner{}
// newInternalSession success case
owner.On("onBecameOwner", sctx).Return(nil).Once()
se, err := newInternalSession(sctx, owner)
require.NoError(t, err)
require.NotNil(t, se)
require.Same(t, owner, se.owner)
require.Same(t, sctx, se.sctx)
require.Zero(t, se.inUse)
require.False(t, se.avoidReuse)
owner.AssertExpectations(t)
// onBecameOwner returns an error
mockErr := errors.New("mockOnBecameOwnerErr")
owner.On("onBecameOwner", sctx).Return(mockErr).Once()
se, err = newInternalSession(sctx, owner)
require.EqualError(t, err, mockErr.Error())
require.Nil(t, se)
owner.AssertExpectations(t)
}
func mockInternalSession(t *testing.T, sctx *mockSessionContext, owner *mockOwner) *session {
owner.On("onBecameOwner", sctx).Return(nil).Once()
se, err := newInternalSession(sctx, owner)
require.NoError(t, err)
require.Same(t, owner, se.Owner())
require.False(t, se.IsClosed())
owner.AssertExpectations(t)
return se
}
func mockSession(t *testing.T, sctx *mockSessionContext) *Session {
se, err := NewSessionForTest(sctx)
require.NoError(t, err)
require.Same(t, se, se.internal.Owner())
require.False(t, se.internal.IsClosed())
sctx.AssertExpectations(t)
return se
}
func TestResignOwnerAndCloseSctx(t *testing.T) {
sctx := &mockSessionContext{}
owner := &mockOwner{}
// success case
owner.On("onResignOwner", sctx).Return(nil).Once()
sctx.On("Close").Once()
require.NoError(t, resignOwnerAndCloseSctx(owner, sctx))
owner.AssertExpectations(t)
sctx.AssertExpectations(t)
// onResignOwner returns an error, should also close the session
owner.On("onResignOwner", sctx).Return(errors.New("mockErr")).Once()
sctx.On("Close").Once()
require.EqualError(t, resignOwnerAndCloseSctx(owner, sctx), "mockErr")
owner.AssertExpectations(t)
sctx.AssertExpectations(t)
// onResignOwner panics, should also close the session
owner.On("onResignOwner", sctx).Panic("mockPanic").Once()
sctx.On("Close").Once()
require.PanicsWithValue(t, "mockPanic", func() {
_ = resignOwnerAndCloseSctx(owner, sctx)
})
owner.AssertExpectations(t)
sctx.AssertExpectations(t)
// Close panics
owner.On("onResignOwner", sctx).Return(nil).Once()
sctx.On("Close").Panic("mockClosePanic")
require.PanicsWithValue(t, "mockClosePanic", func() {
_ = resignOwnerAndCloseSctx(owner, sctx)
})
owner.AssertExpectations(t)
sctx.AssertExpectations(t)
}
func TestInternalSessionTransferOwner(t *testing.T) {
sctx := &mockSessionContext{}
// TransferOwner success
owner1 := &mockOwner{}
se := mockInternalSession(t, sctx, owner1)
owner2 := &mockOwner{}
owner1.On("onResignOwner", sctx).Return(nil).Once()
owner2.On("onBecameOwner", sctx).Return(nil).Once()
require.NoError(t, se.TransferOwner(owner1, owner2))
require.Same(t, owner2, se.Owner())
owner1.AssertExpectations(t)
owner2.AssertExpectations(t)
owner3 := &mockOwner{}
// transfer from an invalid owner should fail
require.Error(t, se.TransferOwner(owner1, owner3))
require.Same(t, owner2, se.Owner())
require.Error(t, se.TransferOwner(owner1, owner2))
require.Same(t, owner2, se.Owner())
// transfer to the same owner should take no effect
require.NoError(t, se.TransferOwner(owner2, owner2))
require.Same(t, owner2, se.Owner())
// TransferOwner from=nil should fail
require.Error(t, se.TransferOwner(nil, owner3))
require.Same(t, owner2, se.Owner())
// TransferOwner to=nil should fail
require.Error(t, se.TransferOwner(owner2, nil))
require.Same(t, owner2, se.Owner())
// TransferOwner from=to=nil should fail
require.Error(t, se.TransferOwner(nil, nil))
require.Same(t, owner2, se.Owner())
require.False(t, se.IsClosed())
// onResignOwner returns an error and should close the session
mockErr := errors.New("mockOnResignOwnerErr")
owner2.On("onResignOwner", sctx).Return(mockErr).Once()
sctx.On("Close").Once()
require.EqualError(t, se.TransferOwner(owner2, owner1), mockErr.Error())
require.Nil(t, se.Owner())
require.True(t, se.IsClosed())
owner2.AssertExpectations(t)
sctx.AssertExpectations(t)
// onResignOwner panics should close the session
se = mockInternalSession(t, sctx, owner2)
owner2.On("onResignOwner", sctx).Panic("mock panic1").Once()
sctx.On("Close").Once()
require.PanicsWithValue(t, "mock panic1", func() {
_ = se.TransferOwner(owner2, owner3)
})
require.Nil(t, se.Owner())
require.True(t, se.IsClosed())
owner2.AssertExpectations(t)
sctx.AssertExpectations(t)
// onBecameOwner returns an error and should close the session
se = mockInternalSession(t, sctx, owner2)
mockErr = errors.New("mockOnBecameOwner")
owner2.On("onResignOwner", sctx).Return(nil).Once()
owner3.On("onBecameOwner", sctx).Return(mockErr).Once()
sctx.On("Close").Once()
require.EqualError(t, se.TransferOwner(owner2, owner3), mockErr.Error())
require.Nil(t, se.Owner())
require.True(t, se.IsClosed())
owner2.AssertExpectations(t)
owner3.AssertExpectations(t)
sctx.AssertExpectations(t)
// onBecameOwner panics should close the session
se = mockInternalSession(t, sctx, owner2)
mockErr = errors.New("mockOnBecameOwner")
owner2.On("onResignOwner", sctx).Return(nil).Once()
owner3.On("onBecameOwner", sctx).Panic("mock panic2").Once()
sctx.On("Close").Once()
require.PanicsWithValue(t, "mock panic2", func() {
_ = se.TransferOwner(owner2, owner3)
})
require.Nil(t, se.Owner())
require.True(t, se.IsClosed())
owner2.AssertExpectations(t)
owner3.AssertExpectations(t)
sctx.AssertExpectations(t)
// close panics
se = mockInternalSession(t, sctx, owner2)
owner2.On("onResignOwner", sctx).Return(nil).Once()
owner3.On("onBecameOwner", sctx).Return(mockErr).Once()
sctx.On("Close").Panic("close panic").Once()
require.PanicsWithValue(t, "close panic", func() {
_ = se.TransferOwner(owner2, owner3)
})
require.Nil(t, se.Owner())
require.True(t, se.IsClosed())
owner2.AssertExpectations(t)
owner3.AssertExpectations(t)
sctx.AssertExpectations(t)
// A closed session should not transfer the owner again
require.Error(t, se.TransferOwner(nil, owner1))
require.Nil(t, se.Owner())
require.True(t, se.IsClosed())
require.Error(t, se.TransferOwner(owner2, owner1))
require.Nil(t, se.Owner())
require.True(t, se.IsClosed())
require.Error(t, se.TransferOwner(nil, nil))
require.Nil(t, se.Owner())
require.True(t, se.IsClosed())
// inUse session should not transfer the owner
se = mockInternalSession(t, sctx, owner1)
_, exit, err := se.EnterOperation(owner1, false)
require.NoError(t, err)
require.Equal(t, uint64(1), se.Inuse())
err = se.TransferOwner(owner1, owner2)
require.Error(t, err)
require.Contains(t, err.Error(), "session is still inUse: 1")
require.Same(t, owner1, se.Owner())
require.False(t, se.IsClosed())
// after exit, the session can be transferred again
exit()
owner1.On("onResignOwner", sctx).Return(nil).Once()
owner2.On("onBecameOwner", sctx).Return(nil).Once()
require.NoError(t, se.TransferOwner(owner1, owner2))
require.Same(t, owner2, se.Owner())
require.False(t, se.IsClosed())
owner1.AssertExpectations(t)
owner2.AssertExpectations(t)
}
func TestInternalSessionClose(t *testing.T) {
sctx := &mockSessionContext{}
owner := &mockOwner{}
se := mockInternalSession(t, sctx, owner)
require.False(t, se.IsClosed())
require.Same(t, owner, se.Owner())
// Close with a nil owner
se.OwnerClose(nil)
require.False(t, se.IsClosed())
require.Same(t, owner, se.Owner())
// Close with an invalid owner
owner2 := &mockOwner{}
se.OwnerClose(owner2)
require.False(t, se.IsClosed())
require.Same(t, owner, se.Owner())
// Close with the current owner
owner.On("onResignOwner", sctx).Return(nil).Once()
sctx.On("Close").Once()
se.OwnerClose(owner)
require.True(t, se.IsClosed())
require.Nil(t, se.Owner())
owner.AssertExpectations(t)
sctx.AssertExpectations(t)
// Close a closed session
se.OwnerClose(owner)
se.OwnerClose(nil)
se.OwnerClose(owner2)
require.True(t, se.IsClosed())
require.Nil(t, se.Owner())
// Close without an owner
se = mockInternalSession(t, sctx, owner)
owner.On("onResignOwner", sctx).Return(nil).Once()
sctx.On("Close").Once()
se.Close()
require.Nil(t, se.Owner())
require.True(t, se.IsClosed())
// Close after close
se.Close()
require.Nil(t, se.Owner())
require.True(t, se.IsClosed())
// test with error when close
testWithErrorMock := func(closeFn func(se *session, owner *mockOwner)) {
// onResignOwner failed should also close the context
se = mockInternalSession(t, sctx, owner)
owner.On("onResignOwner", sctx).Return(errors.New("mockErr1")).Once()
sctx.On("Close").Once()
WithSuppressAssert(func() {
closeFn(se, owner)
})
require.True(t, se.IsClosed())
require.Nil(t, se.Owner())
owner.AssertExpectations(t)
sctx.AssertExpectations(t)
// onResignOwner panics should also close the context
se = mockInternalSession(t, sctx, owner)
owner.On("onResignOwner", sctx).Panic("panic1").Once()
sctx.On("Close").Once()
require.PanicsWithValue(t, "panic1", func() {
closeFn(se, owner)
})
require.True(t, se.IsClosed())
require.Nil(t, se.Owner())
owner.AssertExpectations(t)
sctx.AssertExpectations(t)
// context.Close() panics
se = mockInternalSession(t, sctx, owner)
owner.On("onResignOwner", sctx).Return(nil).Once()
sctx.On("Close").Panic("panic2").Once()
require.PanicsWithValue(t, "panic2", func() {
closeFn(se, owner)
})
require.True(t, se.IsClosed())
require.Nil(t, se.Owner())
owner.AssertExpectations(t)
sctx.AssertExpectations(t)
// inUse > 0 when close
se = mockInternalSession(t, sctx, owner)
_, exit1, err := se.EnterOperation(owner, false)
require.NoError(t, err)
require.Equal(t, uint64(1), se.Inuse())
_, exit2, err := se.EnterOperation(owner, true)
require.NoError(t, err)
require.Equal(t, uint64(2), se.Inuse())
_, exit3, err := se.EnterOperation(owner, true)
require.NoError(t, err)
require.Equal(t, uint64(3), se.Inuse())
// should not call `sctx.Close()` when inuse > 0 to avoid some data race
WithSuppressAssert(func() {
closeFn(se, owner)
})
require.True(t, se.IsClosed())
require.Nil(t, se.Owner())
// sctx.Close should not be called when inuse > 0 after exit
require.Equal(t, uint64(3), se.Inuse())
WithSuppressAssert(exit1)
require.Equal(t, uint64(2), se.Inuse())
WithSuppressAssert(exit3)
require.Equal(t, uint64(1), se.Inuse())
// sctx.Close should be called after inuse decreased to 0
owner.On("onResignOwner", sctx).Return(nil).Once()
sctx.On("Close").Once()
WithSuppressAssert(exit2)
require.Zero(t, se.Inuse())
sctx.AssertExpectations(t)
owner.AssertExpectations(t)
}
testWithErrorMock(func(se *session, owner *mockOwner) {
se.OwnerClose(owner)
})
testWithErrorMock(func(_ *session, _ *mockOwner) {
se.Close()
})
}
func TestInternalSessionEnterOperation(t *testing.T) {
sctx := &mockSessionContext{}
owner := &mockOwner{}
se := mockInternalSession(t, sctx, owner)
// test EnterOperation will add inUse
gotSctx, exit1, err := se.EnterOperation(owner, true)
require.NoError(t, err)
require.Same(t, sctx, gotSctx)
require.Equal(t, uint64(1), se.Inuse())
gotSctx, exit2, err := se.EnterOperation(owner, false)
require.NoError(t, err)
require.Same(t, sctx, gotSctx)
require.Equal(t, uint64(2), se.Inuse())
gotSctx, exit3, err := se.EnterOperation(owner, true)
require.NoError(t, err)
require.Same(t, sctx, gotSctx)
require.Equal(t, uint64(3), se.Inuse())
// test exit will decrease inUse
exit1()
require.Equal(t, uint64(2), se.Inuse())
exit3()
require.Equal(t, uint64(1), se.Inuse())
WithSuppressAssert(func() {
// multiple exit should take no effect
exit1()
require.Equal(t, uint64(1), se.Inuse())
})
require.Equal(t, uint64(1), se.Inuse())
exit2()
require.Equal(t, uint64(0), se.Inuse())
// call with an invalid owner should report an error
WithSuppressAssert(func() {
gotSctx, exit, err := se.EnterOperation(&mockOwner{}, false)
require.Error(t, err)
require.Contains(t, err.Error(), "caller is not the owner")
require.Nil(t, gotSctx)
require.Nil(t, exit)
require.Equal(t, uint64(0), se.Inuse())
})
// call with a nil owner should report an error
WithSuppressAssert(func() {
gotSctx, exit, err := se.EnterOperation(nil, false)
require.Error(t, err)
require.Contains(t, err.Error(), "caller is not the owner")
require.Nil(t, gotSctx)
require.Nil(t, exit)
require.Equal(t, uint64(0), se.Inuse())
})
// call in a closed session should report an error
owner.On("onResignOwner", sctx).Return(nil).Once()
sctx.On("Close").Once()
se.Close()
WithSuppressAssert(func() {
gotSctx, exit, err := se.EnterOperation(owner, false)
require.Error(t, err)
require.Contains(t, err.Error(), "session is closed")
require.Nil(t, gotSctx)
require.Nil(t, exit)
require.Equal(t, uint64(0), se.Inuse())
})
owner.AssertExpectations(t)
sctx.AssertExpectations(t)
// close the session after enter operation
se = mockInternalSession(t, sctx, owner)
gotSctx, exit4, err := se.EnterOperation(owner, true)
require.NoError(t, err)
require.Same(t, sctx, gotSctx)
require.Equal(t, uint64(1), se.Inuse())
gotSctx, exit5, err := se.EnterOperation(owner, true)
require.NoError(t, err)
require.Same(t, sctx, gotSctx)
require.Equal(t, uint64(2), se.Inuse())
// Close the session before exit
WithSuppressAssert(se.Close)
require.True(t, se.IsClosed())
require.Equal(t, uint64(2), se.Inuse())
require.True(t, se.IsClosed())
// new EnterOperation should be rejected after close but inuse > 0
WithSuppressAssert(func() {
_, _, err = se.EnterOperation(owner, false)
require.Error(t, err)
require.Contains(t, err.Error(), "session is closed")
})
// The exit should still work to decrease inUse
WithSuppressAssert(exit4)
require.Equal(t, uint64(1), se.Inuse())
// when inuse > 0, the session should not call sctx.Close() even if onResignOwner fails
owner.On("onResignOwner", sctx).Return(errors.New("mockErr")).Once()
sctx.On("Close").Once()
WithSuppressAssert(exit5)
require.Equal(t, uint64(0), se.Inuse())
owner.AssertExpectations(t)
sctx.AssertExpectations(t)
}
func TestInternalSessionOwnerWithSctx(t *testing.T) {
sctx := &mockSessionContext{}
owner := &mockOwner{}
se := mockInternalSession(t, sctx, owner)
mockCb := &mock.Mock{}
cb := func(sctx SessionContext) error {
require.Equal(t, uint64(1), se.Inuse())
return mockCb.MethodCalled("cb", sctx).Error(0)
}
// normal case
mockCb.On("cb", sctx).Return(nil).Once()
require.NoError(t, se.OwnerWithSctx(owner, cb))
require.Zero(t, se.Inuse())
mockCb.AssertExpectations(t)
// invalid owner
WithSuppressAssert(func() {
err := se.OwnerWithSctx(&mockOwner{}, cb)
require.Error(t, err)
require.Contains(t, err.Error(), "caller is not the owner")
require.Zero(t, se.Inuse())
})
// error in cb
mockCb.On("cb", sctx).Return(errors.New("mockErr")).Once()
err := se.OwnerWithSctx(owner, cb)
require.Error(t, err)
require.Contains(t, err.Error(), "mockErr")
require.Zero(t, se.Inuse())
mockCb.AssertExpectations(t)
// panic in cb
mockCb.On("cb", sctx).Panic("mockPanic").Once()
require.PanicsWithValue(t, "mockPanic", func() {
_ = se.OwnerWithSctx(owner, cb)
})
require.Zero(t, se.Inuse())
mockCb.AssertExpectations(t)
// session closed
sctx.On("Close").Once()
owner.On("onResignOwner", sctx).Return(nil).Once()
se.Close()
require.True(t, se.IsClosed())
sctx.AssertExpectations(t)
owner.AssertExpectations(t)
WithSuppressAssert(func() {
err := se.OwnerWithSctx(owner, cb)
require.Error(t, err)
require.Contains(t, err.Error(), "session is closed")
require.Zero(t, se.Inuse())
})
}
func TestInternalSessionAvoidReuse(t *testing.T) {
sctx := &mockSessionContext{}
owner := &mockOwner{}
se := mockInternalSession(t, sctx, owner)
require.False(t, se.IsAvoidReuse())
execute := func(do func()) {
gotSctx, exit, err := se.EnterOperation(owner, false)
require.NoError(t, err)
require.Same(t, sctx, gotSctx)
defer exit()
do()
}
// normal use
execute(func() {})
require.False(t, se.IsAvoidReuse())
require.False(t, se.IsClosed())
require.Same(t, owner, se.Owner())
// panic use
require.PanicsWithValue(t, "panic1", func() {
execute(func() {
panic("panic1")
})
})
require.True(t, se.IsAvoidReuse())
require.False(t, se.IsClosed())
require.Same(t, owner, se.Owner())
// allow to close
owner.On("onResignOwner", sctx).Return(nil).Once()
sctx.On("Close").Once()
se.OwnerClose(owner)
require.True(t, se.IsAvoidReuse())
require.True(t, se.IsClosed())
require.Nil(t, se.Owner())
owner.AssertExpectations(t)
sctx.AssertExpectations(t)
}
func TestInternalSessionCheckNoPendingTxn(t *testing.T) {
sctx := &mockSessionContext{}
se := mockInternalSession(t, sctx, &mockOwner{})
sctx.MockNoPendingTxn()
require.NoError(t, se.CheckNoPendingTxn())
sctx.AssertExpectations(t)
sctx.On("GetPreparedTxnFuture").Return(&mockPreparedFuture{}).Once()
require.EqualError(t, se.CheckNoPendingTxn(), "txn is pending for TSO")
sctx.AssertExpectations(t)
sctx.On("GetPreparedTxnFuture").Return(nil).Once()
sctx.On("Txn", false).Return(nil, errors.New("mockErr")).Once()
require.EqualError(t, se.CheckNoPendingTxn(), "mockErr")
sctx.AssertExpectations(t)
sctx.On("GetPreparedTxnFuture").Return(nil).Once()
sctx.On("Txn", false).Return(&mockTxn{valid: true}, nil).Once()
require.EqualError(t, se.CheckNoPendingTxn(), "txn is still valid")
sctx.AssertExpectations(t)
}
func TestInternalSessionResetState(t *testing.T) {
sctx := &mockSessionContext{}
owner := &mockOwner{}
se := mockInternalSession(t, sctx, owner)
ctx := context.WithValue(context.Background(), "a", "b")
checkInuse := func(mock.Arguments) { require.Equal(t, uint64(1), se.Inuse()) }
// normal case
sctx.On("RollbackTxn", ctx).Run(checkInuse).Once()
require.NoError(t, se.OwnerResetState(ctx, owner))
require.Zero(t, se.Inuse())
sctx.AssertExpectations(t)
// RollbackTxn panic
sctx.On("RollbackTxn", ctx).Run(checkInuse).Panic("mockPanic1").Once()
require.PanicsWithValue(t, "mockPanic1", func() {
_ = se.OwnerResetState(ctx, owner)
})
require.Zero(t, se.Inuse())
sctx.AssertExpectations(t)
// not owner
WithSuppressAssert(func() {
err := se.OwnerResetState(ctx, &mockOwner{})
require.Error(t, err)
require.Contains(t, err.Error(), "caller is not the owner")
require.Zero(t, se.Inuse())
})
// closed session
owner.On("onResignOwner", sctx).Return(nil).Once()
sctx.On("Close").Once()
se.Close()
sctx.AssertExpectations(t)
owner.AssertExpectations(t)
WithSuppressAssert(func() {
err := se.OwnerResetState(ctx, owner)
require.Error(t, err)
require.Contains(t, err.Error(), "session is closed")
require.Zero(t, se.Inuse())
})
}
func testCallProxyMethod(
t *testing.T,
sctx *mockSessionContext,
se *Session,
name string,
args []any,
returns []any,
callMethod func([]any) []any,
) {
// normal call
inuse := se.internal.Inuse()
require.GreaterOrEqual(t, inuse, uint64(0))
inCallInuse := uint64(0)
sctx.On(name, args...).Return(returns...).Run(func(_ mock.Arguments) {
inCallInuse = se.internal.Inuse()
}).Once()
actualRets := callMethod(args)
require.Equal(t, returns, actualRets)
require.Equal(t, inuse+1, inCallInuse)
require.Equal(t, inuse, se.internal.Inuse())
sctx.AssertExpectations(t)
// transfer the owner to make call fail
owner2 := noopOwnerHook{}
require.NoError(t, se.internal.TransferOwner(se, owner2))
sctx.AssertExpectations(t)
// error call
WithSuppressAssert(func() {
actualRets = callMethod(args)
})
require.Equal(t, inuse, se.internal.Inuse())
for i := range actualRets {
if i < len(actualRets)-1 {
require.Nil(t, actualRets[i], fmt.Sprintf("%d: %v", i, actualRets[i]))
} else {
err := actualRets[i].(error)
require.Contains(t, err.Error(), "caller is not the owner")
}
}
// transfer the owner back
require.NoError(t, se.internal.TransferOwner(owner2, se))
sctx.AssertExpectations(t)
// test panics
require.False(t, se.internal.IsAvoidReuse())
sctx.On(name, args...).Run(func(_ mock.Arguments) {
inCallInuse = se.internal.Inuse()
panic("panicTest")
}).Once()
require.PanicsWithValue(t, "panicTest", func() {
callMethod(args)
})
require.Equal(t, inuse+1, inCallInuse)
require.Equal(t, inuse, se.internal.Inuse())
require.True(t, se.internal.IsAvoidReuse())
se.internal.avoidReuse = false
sctx.AssertExpectations(t)
}
func testCallProxyMethod22[A1 any, A2 any, R1 any, R2 any](
t *testing.T,
sctx *mockSessionContext,
se *Session,
name string,
arg1 A1, arg2 A2,
r1 R1, r2 R2,
callMethod func(arg1 A1, arg2 A2) (R1, R2),
) {
testCallProxyMethod(t, sctx, se, name, []any{arg1, arg2}, []any{r1, r2}, func(args []any) []any {
ret1, ret2 := callMethod(args[0].(A1), args[1].(A2))
return []any{ret1, ret2}
})
}
func testCallProxyMethod32[A1 any, A2 any, A3 any, R1 any, R2 any](
t *testing.T,
sctx *mockSessionContext,
se *Session,
name string,
arg1 A1, arg2 A2, arg3 A3,
r1 R1, r2 R2,
callMethod func(arg1 A1, arg2 A2, arg3 A3) (R1, R2),
) {
testCallProxyMethod(t, sctx, se, name, []any{arg1, arg2, arg3}, []any{r1, r2}, func(args []any) []any {
ret1, ret2 := callMethod(args[0].(A1), args[1].(A2), args[2].(A3))
return []any{ret1, ret2}
})
}
func testCallProxyMethod33[A1 any, A2 any, A3 any, R1 any, R2 any, R3 any](
t *testing.T,
sctx *mockSessionContext,
se *Session,
name string,
arg1 A1, arg2 A2, arg3 A3,
r1 R1, r2 R2, r3 R3,
callMethod func(arg1 A1, arg2 A2, arg3 A3) (R1, R2, R3),
) {
testCallProxyMethod(t, sctx, se, name, []any{arg1, arg2, arg3}, []any{r1, r2, r3}, func(args []any) []any {
ret1, ret2, ret3 := callMethod(args[0].(A1), args[1].(A2), args[2].(A3))
return []any{ret1, ret2, ret3}
})
}
func testCallProxyMethod43[A1 any, A2 any, A3 any, A4 any, R1 any, R2 any, R3 any](
t *testing.T,
sctx *mockSessionContext,
se *Session,
name string,
arg1 A1, arg2 A2, arg3 A3, arg4 A4,
r1 R1, r2 R2, r3 R3,
callMethod func(arg1 A1, arg2 A2, arg3 A3, arg4 A4) (R1, R2, R3),
) {
testCallProxyMethod(t, sctx, se, name, []any{arg1, arg2, arg3, arg4}, []any{r1, r2, r3}, func(args []any) []any {
ret1, ret2, ret3 := callMethod(args[0].(A1), args[1].(A2), args[2].(A3), args[3].(A4))
return []any{ret1, ret2, ret3}
})
}
type mockRecordSet struct {
v int
sqlexec.RecordSet
}
func TestSessionProxyMethods(t *testing.T) {
sctx := &mockSessionContext{}
se := mockSession(t, sctx)
require.Same(t, se, se.GetSQLExecutor())
require.Same(t, se, se.GetRestrictedSQLExecutor())
ctx := context.WithValue(context.Background(), "a", "b")
testCallProxyMethod22(
t, sctx, se, "Execute",
ctx, "select 1, 2, 3",
[]sqlexec.RecordSet{&mockRecordSet{v: 1}, &mockRecordSet{v: 2}}, errors.New("err"),
se.Execute,
)
testCallProxyMethod32(
t, sctx, se, "ExecuteInternal",
ctx, "select 1, 2, 3", []any{"a", "c", 2},
sqlexec.RecordSet(&mockRecordSet{v: 3}), errors.New("err"),
func(arg1 context.Context, arg2 string, arg3 []any) (sqlexec.RecordSet, error) {
return se.ExecuteInternal(arg1, arg2, arg3...)
},
)
testCallProxyMethod22(
t, sctx, se, "ExecuteStmt",
ctx, ast.StmtNode(&ast.SelectStmt{Distinct: true}),
sqlexec.RecordSet(&mockRecordSet{v: 3}), errors.New("err"),
se.ExecuteStmt,
)
testCallProxyMethod32(
t, sctx, se, "ParseWithParams",
ctx, "select a+?, b+?, c+? from t", []any{"a", 10, "5"},
ast.StmtNode(&ast.SelectStmt{}), errors.New("mockErr"),
func(arg1 context.Context, arg2 string, arg3 []any) (ast.StmtNode, error) {
return se.ParseWithParams(arg1, arg2, arg3...)
},
)
testCallProxyMethod33(
t, sctx, se, "ExecRestrictedStmt",
ctx,
ast.StmtNode(&ast.SelectStmt{}),
[]sqlexec.OptionFuncAlias{sqlexec.ExecOptionIgnoreWarning, sqlexec.ExecOptionAnalyzeVer1},
[]chunk.Row{{}, {}},
[]*resolve.ResultField{{DBName: ast.NewCIStr("v1")}, {DBName: ast.NewCIStr("v2")}},
errors.New("mockErr"),
func(arg1 context.Context, arg2 ast.StmtNode, arg3 []sqlexec.OptionFuncAlias) (
[]chunk.Row, []*resolve.ResultField, error,
) {
return se.ExecRestrictedStmt(arg1, arg2, arg3...)
},
)
testCallProxyMethod43(
t, sctx, se, "ExecRestrictedSQL",
ctx,
[]sqlexec.OptionFuncAlias{sqlexec.ExecOptionIgnoreWarning, sqlexec.ExecOptionAnalyzeVer1},
"select ?, ?, ?",
[]any{1, 2, "3"},
[]chunk.Row{{}, {}},
[]*resolve.ResultField{{DBName: ast.NewCIStr("v1")}, {DBName: ast.NewCIStr("v2")}},
errors.New("mockErr"),
func(arg1 context.Context, arg2 []sqlexec.OptionFuncAlias, arg3 string, arg4 []any) (
[]chunk.Row, []*resolve.ResultField, error,
) {
return se.ExecRestrictedSQL(arg1, arg2, arg3, arg4...)
},
)
testCallProxyMethod22(
t, sctx, se, "Execute",
ctx, "select 1, 2, 3",
[]sqlexec.RecordSet{&mockRecordSet{v: 1}, &mockRecordSet{v: 2}}, errors.New("err"),
func(arg1 context.Context, arg2 string) (rs []sqlexec.RecordSet, err error) {
var execErr error
err = se.WithSessionContext(func(sctx sessionctx.Context) error {
rs, execErr = sctx.GetSQLExecutor().Execute(arg1, arg2)
return execErr
})
if execErr != nil {
require.Same(t, execErr, err)
}
return
},
)
}
func TestSessionClose(t *testing.T) {
sctx := &mockSessionContext{}
se := mockSession(t, sctx)
// normal close
sctx.On("Close").Once()
se.Close()
require.True(t, se.internal.IsClosed())
require.True(t, se.IsInternalClosed())
sctx.AssertExpectations(t)
// cannot use it again after closed
WithSuppressAssert(func() {
_, err := se.Execute(context.TODO(), "select 1")
require.Error(t, err)
require.Contains(t, err.Error(), "session is closed")
})
// close a closed session should take no effect
se.Close()
// close a not own session should take no effect
se2 := mockSession(t, sctx)
se.internal = se2.internal
require.False(t, se.IsInternalClosed())
se.Close()
require.False(t, se.IsInternalClosed())
require.False(t, se2.IsInternalClosed())
// owner should close
sctx.On("Close").Once()
se2.Close()
require.True(t, se.IsInternalClosed())
require.True(t, se2.IsInternalClosed())
sctx.AssertExpectations(t)
}
func TestSessionAvoidReuse(t *testing.T) {
// owner
sctx := &mockSessionContext{}
s, err := NewSessionForTest(sctx)
require.NoError(t, err)
require.False(t, s.internal.avoidReuse)
s.AvoidReuse()
require.True(t, s.internal.avoidReuse)
// multiple calls of AvoidReuse
s.AvoidReuse()
require.True(t, s.internal.avoidReuse)
// still can use the session after AvoidReuse
ctx := context.WithValue(context.Background(), "a", "b")
sctx.On("ExecuteInternal", ctx, "select 1", []any(nil)).Return(nil, nil).Once()
rs, err := s.ExecuteInternal(ctx, "select 1")
require.Nil(t, rs)
require.NoError(t, err)
sctx.AssertExpectations(t)
// not owner
s, err = NewSessionForTest(sctx)
require.NoError(t, err)
s2 := &Session{internal: s.internal}
require.False(t, s.internal.avoidReuse)
s2.AvoidReuse()
require.False(t, s.internal.avoidReuse)
}
func TestInternalSessionUnThreadSafeOperations(t *testing.T) {
sctx := &mockSessionContext{}
owner := &mockOwner{}
se := mockInternalSession(t, sctx, owner)
enterThreadSafeOperation := func() func() {
inUse, unsafe := se.inUse, se.unsafe
gotSctx, exit, err := se.EnterOperation(owner, true)
require.NoError(t, err)
require.Same(t, sctx, gotSctx)
require.NotNil(t, exit)
require.Equal(t, inUse+1, se.inUse)
require.Equal(t, unsafe, se.unsafe)
return exit
}
enterFirstThreadUnsafeOperation := func() func() {
inUse, unsafe := se.inUse, se.unsafe
gotSctx, exit, err := se.EnterOperation(owner, false)
require.NoError(t, err)
require.Same(t, sctx, gotSctx)
require.NotNil(t, exit)
require.Equal(t, inUse+1, se.inUse)
require.Equal(t, unsafe+1, se.unsafe)
return exit
}
enterThreadUnsafeOperationExpectError := func() {
inUse, unsafe := se.inUse, se.unsafe
WithSuppressAssert(func() {
gotSctx, exit, err := se.EnterOperation(owner, false)
require.EqualError(t, err, "EnterOperation error: race detected for concurrent thread-unsafe operations")
require.Nil(t, gotSctx)
require.Nil(t, exit)
require.Equal(t, inUse, se.inUse)
require.Equal(t, unsafe+1, se.unsafe)
})
}
// multiple thread-safe operations should be allowed
exit1 := enterThreadSafeOperation()
exit2 := enterThreadSafeOperation()
// the first thread-unsafe operation should be allowed
exitUnsafe := enterFirstThreadUnsafeOperation()
// next thread-unsafe operation should return error
enterThreadUnsafeOperationExpectError()
// net thread safe operation should still success
exit3 := enterThreadSafeOperation()
// next thread-unsafe operation should return error
enterThreadUnsafeOperationExpectError()
// exit
require.Equal(t, uint64(4), se.inUse)
require.Equal(t, uint64(3), se.unsafe)
exit1()
require.Equal(t, uint64(3), se.inUse)
require.Equal(t, uint64(3), se.unsafe)
exit3()
require.Equal(t, uint64(2), se.inUse)
require.Equal(t, uint64(3), se.unsafe)
WithSuppressAssert(exitUnsafe)
require.Equal(t, uint64(1), se.inUse)
require.Equal(t, uint64(0), se.unsafe)
exit2()
require.Equal(t, uint64(0), se.inUse)
require.Equal(t, uint64(0), se.unsafe)
// new thread-unsafe operation should be allowed when after previous exits
exitUnsafe = enterThreadSafeOperation()
exitUnsafe()
require.Equal(t, uint64(0), se.inUse)
require.Equal(t, uint64(0), se.unsafe)
// when panic happens, the session should also decrease `inUse` and `unsafe`
func() {
defer func() {
require.Equal(t, uint64(0), se.inUse)
require.Equal(t, uint64(0), se.unsafe)
r := recover()
require.Equal(t, "mockPanic", r)
}()
exitUnsafe = enterFirstThreadUnsafeOperation()
defer exitUnsafe()
require.Equal(t, uint64(1), se.inUse)
require.Equal(t, uint64(1), se.unsafe)
enterThreadUnsafeOperationExpectError()
require.Equal(t, uint64(1), se.inUse)
require.Equal(t, uint64(2), se.unsafe)
panic("mockPanic")
}()
}
func TestSessionThreadUnsafeOperations(t *testing.T) {
sctx := &mockSessionContext{}
se := mockSession(t, sctx)
ctx := context.WithValue(context.Background(), "a", "b")
ch := make(chan bool)
waitCh := func(inOp bool) {
select {
case fromOp, ok := <-ch:
require.True(t, ok)
require.True(t, inOp != fromOp)
return
case <-time.After(10 * time.Second):
require.FailNow(t, "timed out")
}
}
sendCh := func(fromOp bool) {
select {
case ch <- fromOp:
case <-time.After(10 * time.Second):
require.FailNow(t, "timed out")
}
}
called := false
onCalled := func(_ mock.Arguments) {
require.False(t, called)
called = true
require.Equal(t, uint64(1), se.internal.Inuse())
sendCh(true)
waitCh(true)
}
operations := []struct {
name string
call func(first bool) error
}{
{
name: "WithSessionContext",
call: func(first bool) error {
err := se.WithSessionContext(func(got sessionctx.Context) error {
require.Same(t, sctx, got)
onCalled(nil)
return nil
})
if first {
require.True(t, called)
}
return err
},
},
{
name: "Execute",
call: func(first bool) error {
if first {
sctx.On("Execute", ctx, "select 1").
Run(onCalled).
Return(nil, nil).
Once()
}
_, err := se.Execute(ctx, "select 1")
sctx.AssertExpectations(t)
return err
},
},
{
name: "ExecuteInternal",
call: func(first bool) error {
if first {
sctx.On("ExecuteInternal", ctx, "select 1", []any(nil)).
Run(onCalled).
Return(nil, nil).
Once()
}
_, err := se.ExecuteInternal(ctx, "select 1")
sctx.AssertExpectations(t)
return err
},
},
{
name: "ExecuteStmt",
call: func(first bool) error {
n := &ast.SelectStmt{}
if first {
sctx.On("ExecuteStmt", ctx, n).
Run(onCalled).
Return(nil, nil).
Once()
}
_, err := se.ExecuteStmt(ctx, n)
sctx.AssertExpectations(t)
return err
},
},
{
name: "ParseWithParams",
call: func(first bool) error {
if first {
sctx.On("ParseWithParams", ctx, "select 1", []any(nil)).
Run(onCalled).
Return(nil, nil).
Once()
}
_, err := se.ParseWithParams(ctx, "select 1")
sctx.AssertExpectations(t)
return err
},
},
{
name: "ExecRestrictedStmt",
call: func(first bool) error {
n := &ast.SelectStmt{}
if first {
sctx.On("ExecRestrictedStmt", ctx, n, []sqlexec.OptionFuncAlias(nil)).
Run(onCalled).
Return(nil, nil, nil).
Once()
}
_, _, err := se.ExecRestrictedStmt(ctx, n)
sctx.AssertExpectations(t)
return err
},
},
{
name: "ExecRestrictedSQL",
call: func(first bool) error {
if first {
sctx.On("ExecRestrictedSQL", ctx, []sqlexec.OptionFuncAlias(nil), "select 1", []any(nil)).
Run(onCalled).
Return(nil, nil, nil).
Once()
}
_, _, err := se.ExecRestrictedSQL(ctx, nil, "select 1")
sctx.AssertExpectations(t)
return err
},
},
}
for i, op := range operations {
t.Run(op.name, func(t *testing.T) {
WithSuppressAssert(func() {
called = false
require.Zero(t, se.internal.inUse)
require.Zero(t, se.internal.unsafe)
// the first operation should be allowed
go func() {
require.NoError(t, op.call(true))
require.True(t, called)
sendCh(true)
}()
waitCh(false)
require.Equal(t, uint64(1), se.internal.inUse)
require.Equal(t, uint64(1), se.internal.unsafe)
// other operations should return errors
for j := range 3 {
next := i + j
if next >= len(operations) {
next = len(operations) - 1
}
require.EqualError(
t, operations[next].call(false),
"EnterOperation error: race detected for concurrent thread-unsafe operations",
)
require.Equal(t, uint64(1), se.internal.inUse)
require.Equal(t, uint64(j+2), se.internal.unsafe)
}
// notify first op to continue
sendCh(false)
// wait first op exit
waitCh(false)
require.Zero(t, se.internal.inUse)
require.Zero(t, se.internal.unsafe)
})
})
}
}