Files
tidb/pkg/session/syssession/pool_test.go
2025-05-07 22:47:30 +00:00

434 lines
11 KiB
Go

// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package syssession
import (
"sync/atomic"
"testing"
"github.com/pingcap/errors"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
type mockSessionFactory struct {
mock.Mock
}
func (f *mockSessionFactory) create() (SessionContext, error) {
args := f.Called()
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(SessionContext), args.Error(1)
}
func TestNewSessionPool(t *testing.T) {
factory := func() (SessionContext, error) {
return &mockSessionContext{}, nil
}
p := NewAdvancedSessionPool(128, factory)
require.NotNil(t, p)
require.Equal(t, 128, cap(p.pool))
require.Equal(t, 0, len(p.pool))
require.False(t, p.IsClosed())
require.NotNil(t, p.ctx)
require.NoError(t, p.ctx.Err())
// pool with PoolMaxSize
p = NewAdvancedSessionPool(PoolMaxSize, factory)
require.Equal(t, PoolMaxSize, cap(p.pool))
require.False(t, p.IsClosed())
// pool with zero-size
WithSuppressAssert(func() {
p = NewAdvancedSessionPool(0, factory)
require.Equal(t, PoolMaxSize, cap(p.pool))
require.False(t, p.IsClosed())
})
// test pool size limit
WithSuppressAssert(func() {
p = NewAdvancedSessionPool(PoolMaxSize+1, factory)
require.Equal(t, PoolMaxSize, cap(p.pool))
require.False(t, p.IsClosed())
})
WithSuppressAssert(func() {
p = NewAdvancedSessionPool(-1, factory)
require.Equal(t, PoolMaxSize, cap(p.pool))
require.False(t, p.IsClosed())
})
}
func TestSessionPoolGet(t *testing.T) {
mockFactory := &mockSessionFactory{}
p := NewAdvancedSessionPool(128, mockFactory.create)
// get a new Session from pool
sctx := &mockSessionContext{}
mockFactory.On("create").Return(sctx, nil).Once()
se, err := p.Get()
require.NoError(t, err)
require.Same(t, se, se.internal.Owner())
require.False(t, se.internal.IsClosed())
require.Zero(t, se.internal.Inuse())
mockFactory.AssertExpectations(t)
sctx.AssertExpectations(t)
// reuse the session
sctx.MockNoPendingTxn()
sctx.MockResetState(p.ctx, "")
p.Put(se)
require.Equal(t, 1, len(p.pool))
sctx.AssertExpectations(t)
se2, err := p.Get()
require.NoError(t, err)
require.NotSame(t, se, se2)
require.Equal(t, 0, len(p.pool))
require.Same(t, se.internal, se2.internal)
require.Same(t, se2, se2.internal.Owner())
require.False(t, se2.internal.IsClosed())
require.Zero(t, se2.internal.Inuse())
mockFactory.AssertExpectations(t)
sctx.AssertExpectations(t)
// factory returns error
mockFactory.On("create").Return(nil, errors.New("mockErr")).Once()
se, err = p.Get()
require.EqualError(t, err, "mockErr")
require.Nil(t, se)
mockFactory.AssertExpectations(t)
// get session from a closed pool
p.Close()
se, err = p.Get()
require.EqualError(t, err, "session pool closed")
require.Nil(t, se)
}
func TestSessionPoolPut(t *testing.T) {
mockFactory := &mockSessionFactory{}
poolCap := 4
p := NewAdvancedSessionPool(poolCap, mockFactory.create)
require.Equal(t, 4, cap(p.pool))
// Put invalid Session
WithSuppressAssert(func() {
p.Put(nil)
p.Put(&Session{})
require.Equal(t, 0, len(p.pool))
})
getCachedSessionFromPool := func(sctx *mockSessionContext) *Session {
se, err := p.Get()
require.NoError(t, err)
require.Same(t, se, se.internal.Owner())
require.True(t, se.IsOwner())
mockFactory.AssertExpectations(t)
sctx.AssertExpectations(t)
return se
}
getNewSessionFromPool := func(sctx *mockSessionContext) *Session {
mockFactory.On("create").Return(sctx, nil).Once()
se, err := p.Get()
require.NoError(t, err)
require.Same(t, se, se.internal.Owner())
require.True(t, se.IsOwner())
mockFactory.AssertExpectations(t)
sctx.AssertExpectations(t)
return se
}
// Put a normal session
sctx := &mockSessionContext{}
se := getNewSessionFromPool(sctx)
sctx.MockResetState(p.ctx, "")
sctx.MockNoPendingTxn()
p.Put(se)
require.Same(t, p, se.internal.Owner())
require.False(t, se.IsOwner())
require.False(t, se.IsInternalClosed())
mockFactory.AssertExpectations(t)
sctx.AssertExpectations(t)
require.Equal(t, 1, len(p.pool))
// Get a cached Session and put the old one that is not the owner
se2 := getCachedSessionFromPool(sctx)
require.Equal(t, 0, len(p.pool))
p.Put(se)
require.Equal(t, 0, len(p.pool))
// Put a Session that is the owner
sctx.MockNoPendingTxn()
sctx.MockResetState(p.ctx, "")
p.Put(se2)
require.Same(t, p, se2.internal.Owner())
mockFactory.AssertExpectations(t)
sctx.AssertExpectations(t)
require.Equal(t, 1, len(p.pool))
// Put a Session again takes no effect
p.Put(se2)
require.Equal(t, 1, len(p.pool))
require.Same(t, p, se2.internal.Owner())
// Put a Session that is inUse
se = getCachedSessionFromPool(sctx)
require.Equal(t, 0, len(p.pool))
_, exit, err := se.internal.EnterOperation(se, false)
require.NoError(t, err)
WithSuppressAssert(func() {
p.Put(se)
})
require.True(t, se.internal.IsClosed())
require.False(t, se.IsOwner())
require.True(t, se.IsInternalClosed())
require.Equal(t, 0, len(p.pool))
sctx.On("Close").Once()
WithSuppressAssert(exit)
sctx.AssertExpectations(t)
// Put a Session that avoids reusing
se = getNewSessionFromPool(sctx)
require.Equal(t, 0, len(p.pool))
se.internal.avoidReuse = true
sctx.On("Close").Once()
p.Put(se)
require.True(t, se.internal.IsClosed())
require.Equal(t, 0, len(p.pool))
sctx.AssertExpectations(t)
// Put a Session that has pending txn
se = getNewSessionFromPool(sctx)
sctx.On("GetPreparedTxnFuture").Return(&mockPreparedFuture{}).Once()
sctx.On("Close").Once()
WithSuppressAssert(func() {
p.Put(se)
})
require.True(t, se.internal.IsClosed())
require.Equal(t, 0, len(p.pool))
sctx.AssertExpectations(t)
// Put a Session but `CheckPendingTxn` panics
se = getNewSessionFromPool(sctx)
sctx.On("GetPreparedTxnFuture").Panic("txnFuturePanic").Once()
sctx.On("Close").Once()
WithSuppressAssert(func() {
require.PanicsWithValue(t, "txnFuturePanic", func() {
p.Put(se)
})
})
require.True(t, se.internal.IsClosed())
require.Equal(t, 0, len(p.pool))
sctx.AssertExpectations(t)
// Put a Session but `OwnerResetState` panics
se = getNewSessionFromPool(sctx)
sctx.MockNoPendingTxn()
sctx.MockResetState(p.ctx, "resetStatePanic")
sctx.On("Close").Once()
WithSuppressAssert(func() {
require.PanicsWithValue(t, "resetStatePanic", func() {
p.Put(se)
})
})
require.True(t, se.internal.IsClosed())
require.Equal(t, 0, len(p.pool))
sctx.AssertExpectations(t)
// Put a closed session
se = getNewSessionFromPool(sctx)
require.Equal(t, 0, len(p.pool))
require.False(t, se.internal.IsClosed())
sctx.On("Close").Once()
se.Close()
require.True(t, se.internal.IsClosed())
p.Put(se)
require.Equal(t, 0, len(p.pool))
sctx.AssertExpectations(t)
// put a full pool
sessions := make([]*Session, poolCap+2)
for i := 0; i <= poolCap+1; i++ {
sctx = &mockSessionContext{}
se = getNewSessionFromPool(sctx)
sessions[i] = se
}
for i := range poolCap {
require.Equal(t, i, len(p.pool))
sctx = sessions[i].internal.sctx.(*mockSessionContext)
sctx.MockNoPendingTxn()
sctx.MockResetState(p.ctx, "")
p.Put(sessions[i])
require.Equal(t, i+1, len(p.pool))
require.Same(t, p, sessions[i].internal.Owner())
sctx.AssertExpectations(t)
}
se = sessions[poolCap]
sctx = se.internal.sctx.(*mockSessionContext)
sctx.MockNoPendingTxn()
sctx.MockResetState(p.ctx, "")
sctx.On("Close").Once()
p.Put(se)
require.Equal(t, poolCap, len(p.pool))
require.Nil(t, se.internal.Owner())
require.True(t, se.internal.IsClosed())
sctx.AssertExpectations(t)
// put a closed pool
for i := range poolCap {
sctx = sessions[i].internal.sctx.(*mockSessionContext)
sctx.On("Close").Once()
}
p.Close()
require.True(t, p.IsClosed())
require.Equal(t, 0, len(p.pool))
for i := range poolCap {
sctx = sessions[i].internal.sctx.(*mockSessionContext)
sctx.AssertExpectations(t)
}
se = sessions[poolCap+1]
sctx = se.internal.sctx.(*mockSessionContext)
sctx.MockNoPendingTxn()
sctx.MockResetState(p.ctx, "")
sctx.On("Close").Once()
p.Put(se)
require.Nil(t, se.internal.Owner())
require.True(t, se.internal.IsClosed())
sctx.AssertExpectations(t)
}
func TestSessionPoolWithSession(t *testing.T) {
factory := &mockSessionFactory{}
capacity := 8
sctx := &mockSessionContext{}
p := NewAdvancedSessionPool(capacity, factory.create)
var called atomic.Bool
fn := func(err error, panicS string) func(*Session) error {
return func(se *Session) error {
factory.AssertExpectations(t)
sctx.AssertExpectations(t)
require.Zero(t, len(p.pool))
require.True(t, called.CompareAndSwap(false, true))
if panicS != "" {
sctx.On("Close").Once()
panic(panicS)
}
if err != nil {
sctx.On("Close").Once()
return err
}
sctx.MockNoPendingTxn()
sctx.MockResetState(p.ctx, "")
return nil
}
}
// success case
require.Zero(t, len(p.pool))
factory.On("create").Return(sctx, nil).Once()
err := p.WithSession(fn(nil, ""))
require.Nil(t, err)
require.True(t, called.CompareAndSwap(true, false))
sctx.AssertExpectations(t)
require.Equal(t, 1, len(p.pool))
// error case
err = p.WithSession(fn(errors.New("mockErr1"), ""))
require.EqualError(t, err, "mockErr1")
require.True(t, called.CompareAndSwap(true, false))
sctx.AssertExpectations(t)
require.Zero(t, len(p.pool))
// panic case
factory.On("create").Return(sctx, nil).Once()
require.PanicsWithValue(t, "mockPanic1", func() {
_ = p.WithSession(fn(nil, "mockPanic1"))
})
require.True(t, called.CompareAndSwap(true, false))
sctx.AssertExpectations(t)
require.Zero(t, len(p.pool))
// p.Get returns error, the function should not be called
factory.On("create").Return(nil, errors.New("mockErr2")).Once()
err = p.WithSession(func(*Session) error {
require.FailNow(t, "should not be called")
return nil
})
require.EqualError(t, err, "mockErr2")
factory.AssertExpectations(t)
require.Zero(t, len(p.pool))
}
func TestSessionPoolClose(t *testing.T) {
factory := &mockSessionFactory{}
capacity := 8
p := NewAdvancedSessionPool(capacity, factory.create)
// make a pool with some sessions
sctxs := make([]*mockSessionContext, capacity)
ses := make([]*Session, capacity)
for i := range capacity {
sctx := &mockSessionContext{}
sctxs[i] = sctx
factory.On("create").Return(sctx, nil).Once()
se, err := p.Get()
require.NoError(t, err)
ses[i] = se
factory.AssertExpectations(t)
sctx.AssertExpectations(t)
}
for i := range capacity {
sctx := sctxs[i]
se := ses[i]
sctx.MockNoPendingTxn()
sctx.MockResetState(p.ctx, "")
p.Put(se)
sctx.AssertExpectations(t)
}
// close pool should close all sessions in it
for i := range capacity {
sctx := sctxs[i]
sctx.On("Close").Once()
}
p.Close()
require.True(t, p.IsClosed())
require.Error(t, p.ctx.Err())
require.Equal(t, 0, len(p.pool))
select {
case _, ok := <-p.pool:
require.False(t, ok)
default:
require.FailNow(t, "pool is still active")
}
for i := range capacity {
sctx := sctxs[i]
sctx.AssertExpectations(t)
}
// close a closed pool should take no effect
p.Close()
require.True(t, p.IsClosed())
}