334 lines
12 KiB
Go
334 lines
12 KiB
Go
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
|
|
|
package sessionexpr_test
|
|
|
|
import (
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pingcap/errors"
|
|
"github.com/pingcap/tidb/pkg/errctx"
|
|
"github.com/pingcap/tidb/pkg/expression/exprctx"
|
|
"github.com/pingcap/tidb/pkg/expression/expropt"
|
|
"github.com/pingcap/tidb/pkg/expression/sessionexpr"
|
|
"github.com/pingcap/tidb/pkg/parser/auth"
|
|
"github.com/pingcap/tidb/pkg/parser/mysql"
|
|
"github.com/pingcap/tidb/pkg/privilege"
|
|
"github.com/pingcap/tidb/pkg/types"
|
|
contextutil "github.com/pingcap/tidb/pkg/util/context"
|
|
"github.com/pingcap/tidb/pkg/util/mathutil"
|
|
"github.com/pingcap/tidb/pkg/util/mock"
|
|
tmock "github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/tikv/client-go/v2/oracle"
|
|
)
|
|
|
|
func TestSessionEvalContextBasic(t *testing.T) {
|
|
ctx := mock.NewContext()
|
|
vars := ctx.GetSessionVars()
|
|
sc := vars.StmtCtx
|
|
impl := sessionexpr.NewEvalContext(ctx)
|
|
require.True(t, impl.GetOptionalPropSet().IsFull())
|
|
|
|
// should contain all the optional properties
|
|
for i := range exprctx.OptPropsCnt {
|
|
provider, ok := impl.GetOptionalPropProvider(exprctx.OptionalEvalPropKey(i))
|
|
require.True(t, ok)
|
|
require.NotNil(t, provider)
|
|
require.Same(t, exprctx.OptionalEvalPropKey(i).Desc(), provider.Desc())
|
|
}
|
|
|
|
ctx.ResetSessionAndStmtTimeZone(time.FixedZone("UTC+11", 11*3600))
|
|
vars.SQLMode = mysql.ModeStrictTransTables | mysql.ModeNoZeroDate
|
|
sc.SetTypeFlags(types.FlagIgnoreInvalidDateErr | types.FlagSkipUTF8Check)
|
|
sc.SetErrLevels(errctx.LevelMap{
|
|
errctx.ErrGroupDupKey: errctx.LevelWarn,
|
|
errctx.ErrGroupBadNull: errctx.LevelIgnore,
|
|
errctx.ErrGroupNoDefault: errctx.LevelIgnore,
|
|
})
|
|
vars.CurrentDB = "db1"
|
|
vars.MaxAllowedPacket = 123456
|
|
|
|
// basic fields
|
|
tc, ec := impl.TypeCtx(), sc.ErrCtx()
|
|
require.Equal(t, tc, sc.TypeCtx())
|
|
require.Equal(t, ec, impl.ErrCtx())
|
|
require.Equal(t, vars.SQLMode, impl.SQLMode())
|
|
require.Same(t, vars.Location(), impl.Location())
|
|
require.Same(t, sc.TimeZone(), impl.Location())
|
|
require.Same(t, tc.Location(), impl.Location())
|
|
require.Equal(t, "db1", impl.CurrentDB())
|
|
require.Equal(t, uint64(123456), impl.GetMaxAllowedPacket())
|
|
require.Equal(t, "0", impl.GetDefaultWeekFormatMode())
|
|
require.NoError(t, ctx.GetSessionVars().SetSystemVar("default_week_format", "5"))
|
|
require.Equal(t, "5", impl.GetDefaultWeekFormatMode())
|
|
require.Same(t, vars.UserVars, impl.GetUserVarsReader())
|
|
|
|
// handle warnings
|
|
require.Equal(t, 0, impl.WarningCount())
|
|
impl.AppendWarning(errors.New("err1"))
|
|
require.Equal(t, 1, impl.WarningCount())
|
|
tc.AppendWarning(errors.New("err2"))
|
|
require.Equal(t, 2, impl.WarningCount())
|
|
ec.AppendWarning(errors.New("err3"))
|
|
require.Equal(t, 3, impl.WarningCount())
|
|
|
|
for _, dst := range [][]contextutil.SQLWarn{
|
|
nil,
|
|
make([]contextutil.SQLWarn, 1),
|
|
make([]contextutil.SQLWarn, 3),
|
|
make([]contextutil.SQLWarn, 0, 3),
|
|
} {
|
|
warnings := impl.CopyWarnings(dst)
|
|
require.Equal(t, 3, len(warnings))
|
|
require.Equal(t, contextutil.WarnLevelWarning, warnings[0].Level)
|
|
require.Equal(t, contextutil.WarnLevelWarning, warnings[1].Level)
|
|
require.Equal(t, contextutil.WarnLevelWarning, warnings[2].Level)
|
|
require.Equal(t, "err1", warnings[0].Err.Error())
|
|
require.Equal(t, "err2", warnings[1].Err.Error())
|
|
require.Equal(t, "err3", warnings[2].Err.Error())
|
|
}
|
|
|
|
warnings := impl.TruncateWarnings(1)
|
|
require.Equal(t, 2, len(warnings))
|
|
require.Equal(t, contextutil.WarnLevelWarning, warnings[0].Level)
|
|
require.Equal(t, contextutil.WarnLevelWarning, warnings[1].Level)
|
|
require.Equal(t, "err2", warnings[0].Err.Error())
|
|
require.Equal(t, "err3", warnings[1].Err.Error())
|
|
|
|
warnings = impl.TruncateWarnings(0)
|
|
require.Equal(t, 1, len(warnings))
|
|
require.Equal(t, contextutil.WarnLevelWarning, warnings[0].Level)
|
|
require.Equal(t, "err1", warnings[0].Err.Error())
|
|
}
|
|
|
|
func TestSessionEvalContextCurrentTime(t *testing.T) {
|
|
ctx := mock.NewContext()
|
|
vars := ctx.GetSessionVars()
|
|
sc := vars.StmtCtx
|
|
impl := sessionexpr.NewEvalContext(ctx)
|
|
|
|
var now atomic.Pointer[time.Time]
|
|
sc.SetStaleTSOProvider(func() (uint64, error) {
|
|
v := time.UnixMilli(123456789)
|
|
// should only be called once
|
|
require.True(t, now.CompareAndSwap(nil, &v))
|
|
return oracle.GoTimeToTS(v), nil
|
|
})
|
|
|
|
// now should return the stable TSO if set
|
|
tm, err := impl.CurrentTime()
|
|
require.NoError(t, err)
|
|
v := now.Load()
|
|
require.NotNil(t, v)
|
|
require.Equal(t, v.UnixNano(), tm.UnixNano())
|
|
|
|
// The second call should return the same value
|
|
tm, err = impl.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Equal(t, v.UnixNano(), tm.UnixNano())
|
|
|
|
// now should return the system variable if "timestamp" is set
|
|
sc.SetStaleTSOProvider(nil)
|
|
sc.Reset()
|
|
require.NoError(t, vars.SetSystemVar("timestamp", "7654321.875"))
|
|
tm, err = impl.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(7654321_875_000_000), tm.UnixNano())
|
|
|
|
// The second call should return the same value
|
|
tm, err = impl.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(7654321_875_000_000), tm.UnixNano())
|
|
|
|
// now should return the system current time if not stale TSO or "timestamp" is set
|
|
require.NoError(t, vars.SetSystemVar("timestamp", "0"))
|
|
sc.Reset()
|
|
tm, err = impl.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.InDelta(t, time.Now().Unix(), tm.Unix(), 5)
|
|
|
|
// The second call should return the same value
|
|
tm2, err := impl.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Equal(t, tm.UnixNano(), tm2.UnixNano())
|
|
}
|
|
|
|
type mockPrivManager struct {
|
|
tmock.Mock
|
|
privilege.Manager
|
|
}
|
|
|
|
func (m *mockPrivManager) RequestVerification(
|
|
activeRole []*auth.RoleIdentity, db, table, column string, priv mysql.PrivilegeType,
|
|
) bool {
|
|
return m.Called(activeRole, db, table, column, priv).Bool(0)
|
|
}
|
|
|
|
func (m *mockPrivManager) RequestDynamicVerification(
|
|
activeRoles []*auth.RoleIdentity, privName string, grantable bool,
|
|
) bool {
|
|
return m.Called(activeRoles, privName, grantable).Bool(0)
|
|
}
|
|
|
|
func TestSessionEvalContextPrivilegeCheck(t *testing.T) {
|
|
ctx := mock.NewContext()
|
|
impl := sessionexpr.NewEvalContext(ctx)
|
|
activeRoles := []*auth.RoleIdentity{
|
|
{Username: "role1", Hostname: "host1"},
|
|
{Username: "role2", Hostname: "host2"},
|
|
}
|
|
ctx.GetSessionVars().ActiveRoles = activeRoles
|
|
|
|
// no privilege manager should always return true for privilege check
|
|
privilege.BindPrivilegeManager(ctx, nil)
|
|
require.True(t, impl.RequestVerification("test", "tbl1", "col1", mysql.SuperPriv))
|
|
require.True(t, impl.RequestDynamicVerification("RESTRICTED_TABLES_ADMIN", true))
|
|
require.True(t, impl.RequestDynamicVerification("RESTRICTED_TABLES_ADMIN", false))
|
|
|
|
// if privilege manager bound, it should return the privilege manager value
|
|
mgr := &mockPrivManager{}
|
|
privilege.BindPrivilegeManager(ctx, mgr)
|
|
mgr.On("RequestVerification", activeRoles, "db1", "t1", "c1", mysql.CreatePriv).
|
|
Return(true).Once()
|
|
require.True(t, impl.RequestVerification("db1", "t1", "c1", mysql.CreatePriv))
|
|
mgr.AssertExpectations(t)
|
|
|
|
mgr.On("RequestVerification", activeRoles, "db2", "t2", "c2", mysql.SuperPriv).
|
|
Return(false).Once()
|
|
require.False(t, impl.RequestVerification("db2", "t2", "c2", mysql.SuperPriv))
|
|
mgr.AssertExpectations(t)
|
|
|
|
mgr.On("RequestDynamicVerification", activeRoles, "RESTRICTED_USER_ADMIN", false).
|
|
Return(true).Once()
|
|
require.True(t, impl.RequestDynamicVerification("RESTRICTED_USER_ADMIN", false))
|
|
|
|
mgr.On("RequestDynamicVerification", activeRoles, "RESTRICTED_CONNECTION_ADMIN", true).
|
|
Return(false).Once()
|
|
require.False(t, impl.RequestDynamicVerification("RESTRICTED_CONNECTION_ADMIN", true))
|
|
}
|
|
|
|
func getProvider[T exprctx.OptionalEvalPropProvider](
|
|
t *testing.T,
|
|
impl *sessionexpr.EvalContext,
|
|
key exprctx.OptionalEvalPropKey,
|
|
) T {
|
|
val, ok := impl.GetOptionalPropProvider(key)
|
|
require.True(t, ok)
|
|
p, ok := val.(T)
|
|
require.True(t, ok)
|
|
require.Equal(t, key, p.Desc().Key())
|
|
return p
|
|
}
|
|
|
|
func TestSessionEvalContextOptProps(t *testing.T) {
|
|
ctx := mock.NewContext()
|
|
impl := sessionexpr.NewEvalContext(ctx)
|
|
|
|
// test for OptPropCurrentUser
|
|
ctx.GetSessionVars().User = &auth.UserIdentity{Username: "user1", Hostname: "host1"}
|
|
ctx.GetSessionVars().ActiveRoles = []*auth.RoleIdentity{
|
|
{Username: "role1", Hostname: "host1"},
|
|
{Username: "role2", Hostname: "host2"},
|
|
}
|
|
user, roles := getProvider[expropt.CurrentUserPropProvider](t, impl, exprctx.OptPropCurrentUser)()
|
|
require.Equal(t, ctx.GetSessionVars().User, user)
|
|
require.Equal(t, ctx.GetSessionVars().ActiveRoles, roles)
|
|
|
|
// test for OptPropSessionVars
|
|
sessVarsProvider := getProvider[*expropt.SessionVarsPropProvider](t, impl, exprctx.OptPropSessionVars)
|
|
require.NotNil(t, sessVarsProvider)
|
|
gotVars, err := expropt.SessionVarsPropReader{}.GetSessionVars(impl)
|
|
require.NoError(t, err)
|
|
require.Same(t, ctx.GetSessionVars(), gotVars)
|
|
|
|
// test for OptPropAdvisoryLock
|
|
lockProvider := getProvider[*expropt.AdvisoryLockPropProvider](t, impl, exprctx.OptPropAdvisoryLock)
|
|
gotCtx, ok := lockProvider.AdvisoryLockContext.(*mock.Context)
|
|
require.True(t, ok)
|
|
require.Same(t, ctx, gotCtx)
|
|
|
|
// test for OptPropDDLOwnerInfo
|
|
ddlInfoProvider := getProvider[expropt.DDLOwnerInfoProvider](t, impl, exprctx.OptPropDDLOwnerInfo)
|
|
require.False(t, ddlInfoProvider())
|
|
ctx.SetIsDDLOwner(true)
|
|
require.True(t, ddlInfoProvider())
|
|
|
|
// test for OptPropPrivilegeChecker
|
|
privCheckerProvider := getProvider[expropt.PrivilegeCheckerProvider](t, impl, exprctx.OptPropPrivilegeChecker)
|
|
privChecker := privCheckerProvider()
|
|
require.NotNil(t, privChecker)
|
|
require.Same(t, impl, privChecker)
|
|
}
|
|
|
|
func TestSessionBuildContext(t *testing.T) {
|
|
ctx := mock.NewContext()
|
|
impl := sessionexpr.NewExprContext(ctx)
|
|
evalCtx, ok := impl.GetEvalCtx().(*sessionexpr.EvalContext)
|
|
require.True(t, ok)
|
|
require.Same(t, evalCtx, impl.EvalContext)
|
|
require.True(t, evalCtx.GetOptionalPropSet().IsFull())
|
|
require.Same(t, ctx, evalCtx.Sctx())
|
|
|
|
// charset and collation
|
|
vars := ctx.GetSessionVars()
|
|
err := vars.SetSystemVar("character_set_connection", "gbk")
|
|
require.NoError(t, err)
|
|
err = vars.SetSystemVar("collation_connection", "gbk_chinese_ci")
|
|
require.NoError(t, err)
|
|
vars.DefaultCollationForUTF8MB4 = "utf8mb4_0900_ai_ci"
|
|
|
|
charset, collate := impl.GetCharsetInfo()
|
|
require.Equal(t, "gbk", charset)
|
|
require.Equal(t, "gbk_chinese_ci", collate)
|
|
require.Equal(t, "utf8mb4_0900_ai_ci", impl.GetDefaultCollationForUTF8MB4())
|
|
|
|
// SysdateIsNow
|
|
vars.SysdateIsNow = true
|
|
require.True(t, impl.GetSysdateIsNow())
|
|
|
|
// NoopFuncsMode
|
|
vars.NoopFuncsMode = 2
|
|
require.Equal(t, 2, impl.GetNoopFuncsMode())
|
|
|
|
// Rng
|
|
vars.Rng = mathutil.NewWithSeed(123)
|
|
require.Same(t, vars.Rng, impl.Rng())
|
|
|
|
// PlanCache
|
|
vars.StmtCtx.EnablePlanCache()
|
|
require.True(t, impl.IsUseCache())
|
|
impl.SetSkipPlanCache("mockReason")
|
|
require.False(t, impl.IsUseCache())
|
|
|
|
// Alloc column id
|
|
prevID := vars.PlanColumnID.Load()
|
|
colID := impl.AllocPlanColumnID()
|
|
require.Equal(t, colID, prevID+1)
|
|
colID = impl.AllocPlanColumnID()
|
|
require.Equal(t, colID, prevID+2)
|
|
vars.AllocPlanColumnID()
|
|
colID = impl.AllocPlanColumnID()
|
|
require.Equal(t, colID, prevID+4)
|
|
|
|
// InNullRejectCheck
|
|
require.False(t, impl.IsInNullRejectCheck())
|
|
|
|
// ConnID
|
|
vars.ConnectionID = 123
|
|
require.Equal(t, uint64(123), impl.ConnectionID())
|
|
}
|