// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
|
|
|
package contextstatic
|
|
|
|
import (
|
|
"fmt"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pingcap/errors"
|
|
"github.com/pingcap/tidb/pkg/errctx"
|
|
"github.com/pingcap/tidb/pkg/expression/context"
|
|
"github.com/pingcap/tidb/pkg/expression/contextopt"
|
|
infoschema "github.com/pingcap/tidb/pkg/infoschema/context"
|
|
"github.com/pingcap/tidb/pkg/parser/auth"
|
|
"github.com/pingcap/tidb/pkg/parser/mysql"
|
|
"github.com/pingcap/tidb/pkg/sessionctx/variable"
|
|
"github.com/pingcap/tidb/pkg/types"
|
|
contextutil "github.com/pingcap/tidb/pkg/util/context"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestNewStaticEvalCtx(t *testing.T) {
|
|
// default context
|
|
prevID := contextutil.GenContextID()
|
|
ctx := NewStaticEvalContext()
|
|
require.Equal(t, prevID+1, ctx.CtxID())
|
|
checkDefaultStaticEvalCtx(t, ctx)
|
|
|
|
// with options
|
|
prevID = ctx.CtxID()
|
|
options, stateForTest := getEvalCtxOptionsForTest(t)
|
|
ctx = NewStaticEvalContext(options...)
|
|
require.Equal(t, prevID+1, ctx.CtxID())
|
|
checkOptionsStaticEvalCtx(t, ctx, stateForTest)
|
|
}
|
|
|
|
func checkDefaultStaticEvalCtx(t *testing.T, ctx *StaticEvalContext) {
|
|
mode, err := mysql.GetSQLMode(mysql.DefaultSQLMode)
|
|
require.NoError(t, err)
|
|
require.Equal(t, mode, ctx.SQLMode())
|
|
require.Same(t, time.UTC, ctx.Location())
|
|
require.Equal(t, types.NewContext(types.StrictFlags, time.UTC, ctx), ctx.TypeCtx())
|
|
require.Equal(t, errctx.NewContextWithLevels(errctx.LevelMap{}, ctx), ctx.ErrCtx())
|
|
require.Equal(t, "", ctx.CurrentDB())
|
|
require.Equal(t, variable.DefMaxAllowedPacket, ctx.GetMaxAllowedPacket())
|
|
require.Equal(t, variable.DefDefaultWeekFormat, ctx.GetDefaultWeekFormatMode())
|
|
require.Equal(t, variable.DefDivPrecisionIncrement, ctx.GetDivPrecisionIncrement())
|
|
require.Nil(t, ctx.requestVerificationFn)
|
|
require.Nil(t, ctx.requestDynamicVerificationFn)
|
|
require.True(t, ctx.RequestVerification("test", "t1", "", mysql.CreatePriv))
|
|
require.True(t, ctx.RequestDynamicVerification("RESTRICTED_USER_ADMIN", true))
|
|
require.True(t, ctx.GetOptionalPropSet().IsEmpty())
|
|
p, ok := ctx.GetOptionalPropProvider(context.OptPropAdvisoryLock)
|
|
require.Nil(t, p)
|
|
require.False(t, ok)
|
|
|
|
tm, err := ctx.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Same(t, time.UTC, tm.Location())
|
|
require.InDelta(t, time.Now().Unix(), tm.Unix(), 5)
|
|
|
|
warnHandler, ok := ctx.warnHandler.(*contextutil.StaticWarnHandler)
|
|
require.True(t, ok)
|
|
require.Equal(t, 0, warnHandler.WarningCount())
|
|
}
|
|
|
|
type evalCtxOptionsTestState struct {
|
|
now time.Time
|
|
loc *time.Location
|
|
warnHandler *contextutil.StaticWarnHandler
|
|
ddlOwner bool
|
|
privCheckArgs []any
|
|
privRet bool
|
|
}
|
|
|
|
func getEvalCtxOptionsForTest(t *testing.T) ([]StaticEvalCtxOption, *evalCtxOptionsTestState) {
|
|
loc, err := time.LoadLocation("US/Eastern")
|
|
require.NoError(t, err)
|
|
s := &evalCtxOptionsTestState{
|
|
now: time.Now(),
|
|
loc: loc,
|
|
warnHandler: contextutil.NewStaticWarnHandler(8),
|
|
}
|
|
|
|
provider1 := contextopt.CurrentUserPropProvider(func() (*auth.UserIdentity, []*auth.RoleIdentity) {
|
|
return &auth.UserIdentity{Username: "user1", Hostname: "host1"},
|
|
[]*auth.RoleIdentity{{Username: "role1", Hostname: "host2"}}
|
|
})
|
|
|
|
provider2 := contextopt.DDLOwnerInfoProvider(func() bool {
|
|
return s.ddlOwner
|
|
})
|
|
|
|
return []StaticEvalCtxOption{
|
|
WithWarnHandler(s.warnHandler),
|
|
WithSQLMode(mysql.ModeNoZeroDate | mysql.ModeStrictTransTables),
|
|
WithTypeFlags(types.FlagAllowNegativeToUnsigned | types.FlagSkipASCIICheck),
|
|
WithErrLevelMap(errctx.LevelMap{
|
|
errctx.ErrGroupBadNull: errctx.LevelError,
|
|
errctx.ErrGroupDividedByZero: errctx.LevelWarn,
|
|
}),
|
|
WithLocation(loc),
|
|
WithCurrentDB("db1"),
|
|
WithCurrentTime(func() (time.Time, error) {
|
|
return s.now, nil
|
|
}),
|
|
WithMaxAllowedPacket(12345),
|
|
WithDefaultWeekFormatMode("3"),
|
|
WithDivPrecisionIncrement(5),
|
|
WithPrivCheck(func(db, table, column string, priv mysql.PrivilegeType) bool {
|
|
require.Nil(t, s.privCheckArgs)
|
|
s.privCheckArgs = []any{db, table, column, priv}
|
|
return s.privRet
|
|
}),
|
|
WithDynamicPrivCheck(func(privName string, grantable bool) bool {
|
|
require.Nil(t, s.privCheckArgs)
|
|
s.privCheckArgs = []any{privName, grantable}
|
|
return s.privRet
|
|
}),
|
|
WithOptionalProperty(provider1, provider2),
|
|
}, s
|
|
}
|
|
|
|
func checkOptionsStaticEvalCtx(t *testing.T, ctx *StaticEvalContext, s *evalCtxOptionsTestState) {
|
|
require.Same(t, ctx.warnHandler, s.warnHandler)
|
|
require.Equal(t, mysql.ModeNoZeroDate|mysql.ModeStrictTransTables, ctx.SQLMode())
|
|
require.Equal(t,
|
|
types.NewContext(types.FlagAllowNegativeToUnsigned|types.FlagSkipASCIICheck, s.loc, ctx),
|
|
ctx.TypeCtx(),
|
|
)
|
|
require.Equal(t, errctx.NewContextWithLevels(errctx.LevelMap{
|
|
errctx.ErrGroupBadNull: errctx.LevelError,
|
|
errctx.ErrGroupDividedByZero: errctx.LevelWarn,
|
|
}, ctx), ctx.ErrCtx())
|
|
require.Same(t, s.loc, ctx.Location())
|
|
require.Equal(t, "db1", ctx.CurrentDB())
|
|
current, err := ctx.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Equal(t, current.UnixNano(), s.now.UnixNano())
|
|
require.Same(t, s.loc, current.Location())
|
|
require.Equal(t, uint64(12345), ctx.GetMaxAllowedPacket())
|
|
require.Equal(t, "3", ctx.GetDefaultWeekFormatMode())
|
|
require.Equal(t, 5, ctx.GetDivPrecisionIncrement())
|
|
|
|
s.privCheckArgs, s.privRet = nil, false
|
|
require.False(t, ctx.RequestVerification("db", "table", "column", mysql.CreatePriv))
|
|
require.Equal(t, []any{"db", "table", "column", mysql.CreatePriv}, s.privCheckArgs)
|
|
s.privCheckArgs, s.privRet = nil, true
|
|
require.True(t, ctx.RequestVerification("db2", "table2", "column2", mysql.UpdatePriv))
|
|
require.Equal(t, []any{"db2", "table2", "column2", mysql.UpdatePriv}, s.privCheckArgs)
|
|
s.privCheckArgs, s.privRet = nil, false
|
|
require.False(t, ctx.RequestDynamicVerification("RESTRICTED_USER_ADMIN", true))
|
|
require.Equal(t, []any{"RESTRICTED_USER_ADMIN", true}, s.privCheckArgs)
|
|
s.privCheckArgs, s.privRet = nil, true
|
|
require.True(t, ctx.RequestDynamicVerification("RESTRICTED_TABLES_ADMIN", false))
|
|
require.Equal(t, []any{"RESTRICTED_TABLES_ADMIN", false}, s.privCheckArgs)
|
|
|
|
var optSet context.OptionalEvalPropKeySet
|
|
optSet = optSet.Add(context.OptPropCurrentUser).Add(context.OptPropDDLOwnerInfo)
|
|
require.Equal(t, optSet, ctx.GetOptionalPropSet())
|
|
p, ok := ctx.GetOptionalPropProvider(context.OptPropCurrentUser)
|
|
require.True(t, ok)
|
|
user, roles := p.(contextopt.CurrentUserPropProvider)()
|
|
require.Equal(t, &auth.UserIdentity{Username: "user1", Hostname: "host1"}, user)
|
|
require.Equal(t, []*auth.RoleIdentity{{Username: "role1", Hostname: "host2"}}, roles)
|
|
p, ok = ctx.GetOptionalPropProvider(context.OptPropDDLOwnerInfo)
|
|
s.ddlOwner = true
|
|
require.True(t, ok)
|
|
require.True(t, p.(contextopt.DDLOwnerInfoProvider)())
|
|
s.ddlOwner = false
|
|
require.False(t, p.(contextopt.DDLOwnerInfoProvider)())
|
|
p, ok = ctx.GetOptionalPropProvider(context.OptPropInfoSchema)
|
|
require.False(t, ok)
|
|
require.Nil(t, p)
|
|
}
|
|
|
|
func TestStaticEvalCtxCurrentTime(t *testing.T) {
|
|
loc1, err := time.LoadLocation("US/Eastern")
|
|
require.NoError(t, err)
|
|
|
|
tm := time.UnixMicro(123456789).In(loc1)
|
|
calls := 0
|
|
getTime := func() (time.Time, error) {
|
|
defer func() {
|
|
calls++
|
|
}()
|
|
|
|
if calls < 2 {
|
|
return time.Time{}, errors.NewNoStackError(fmt.Sprintf("err%d", calls))
|
|
}
|
|
|
|
if calls == 2 {
|
|
return tm, nil
|
|
}
|
|
|
|
return time.Time{}, errors.NewNoStackError("should not reach here")
|
|
}
|
|
|
|
ctx := NewStaticEvalContext(WithCurrentTime(getTime))
|
|
|
|
// get time for the first two times should fail
|
|
got, err := ctx.CurrentTime()
|
|
require.EqualError(t, err, "err0")
|
|
require.Equal(t, time.Time{}, got)
|
|
|
|
got, err = ctx.CurrentTime()
|
|
require.EqualError(t, err, "err1")
|
|
require.Equal(t, time.Time{}, got)
|
|
|
|
// the third time will success
|
|
got, err = ctx.CurrentTime()
|
|
require.Nil(t, err)
|
|
require.Equal(t, tm.UnixNano(), got.UnixNano())
|
|
require.Same(t, time.UTC, got.Location())
|
|
require.Equal(t, 3, calls)
|
|
|
|
// next ctx should cache the time without calling inner function
|
|
got, err = ctx.CurrentTime()
|
|
require.Nil(t, err)
|
|
require.Equal(t, tm.UnixNano(), got.UnixNano())
|
|
require.Same(t, time.UTC, got.Location())
|
|
require.Equal(t, 3, calls)
|
|
|
|
// CurrentTime should have the same location with `ctx.Location()`
|
|
loc2, err := time.LoadLocation("Australia/Sydney")
|
|
require.NoError(t, err)
|
|
ctx = NewStaticEvalContext(
|
|
WithLocation(loc2),
|
|
WithCurrentTime(func() (time.Time, error) {
|
|
return tm, nil
|
|
}),
|
|
)
|
|
got, err = ctx.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Equal(t, tm.UnixNano(), got.UnixNano())
|
|
require.Same(t, loc2, got.Location())
|
|
|
|
// Apply should copy the current time
|
|
ctx2 := ctx.Apply()
|
|
got, err = ctx2.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Equal(t, tm.UnixNano(), got.UnixNano())
|
|
require.Same(t, loc2, got.Location())
|
|
|
|
// Apply with location should change current time's location
|
|
ctx2 = ctx.Apply(WithLocation(loc1))
|
|
got, err = ctx2.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Equal(t, tm.UnixNano(), got.UnixNano())
|
|
require.Same(t, loc1, got.Location())
|
|
|
|
// Apply will not affect previous current time
|
|
got, err = ctx.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Equal(t, tm.UnixNano(), got.UnixNano())
|
|
require.Same(t, loc2, got.Location())
|
|
|
|
// Apply with a different current time func
|
|
ctx2 = ctx.Apply(WithCurrentTime(func() (time.Time, error) {
|
|
return time.UnixMicro(987654321), nil
|
|
}))
|
|
got, err = ctx2.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(987654321), got.UnixMicro())
|
|
require.Same(t, loc2, got.Location())
|
|
|
|
// Apply will not affect previous current time
|
|
got, err = ctx.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Equal(t, tm.UnixNano(), got.UnixNano())
|
|
require.Same(t, loc2, got.Location())
|
|
}
|
|
|
|
func TestStaticEvalCtxWarnings(t *testing.T) {
|
|
// default context should have a empty StaticWarningsHandler
|
|
ctx := NewStaticEvalContext()
|
|
h, ok := ctx.warnHandler.(*contextutil.StaticWarnHandler)
|
|
require.True(t, ok)
|
|
require.Equal(t, 0, h.WarningCount())
|
|
|
|
// WithWarnHandler should work
|
|
ignoreHandler := contextutil.IgnoreWarn
|
|
ctx = NewStaticEvalContext(WithWarnHandler(ignoreHandler))
|
|
require.True(t, ctx.warnHandler == ignoreHandler)
|
|
|
|
// All contexts should use the same warning handler
|
|
h = contextutil.NewStaticWarnHandler(8)
|
|
ctx = NewStaticEvalContext(WithWarnHandler(h))
|
|
tc, ec := ctx.TypeCtx(), ctx.ErrCtx()
|
|
h.AppendWarning(errors.NewNoStackError("warn0"))
|
|
ctx.AppendWarning(errors.NewNoStackError("warn1"))
|
|
tc.AppendWarning(errors.NewNoStackError("warn2"))
|
|
ec.AppendWarning(errors.NewNoStackError("warn3"))
|
|
require.Equal(t, 4, h.WarningCount())
|
|
require.Equal(t, h.WarningCount(), ctx.WarningCount())
|
|
|
|
// ctx.CopyWarnings
|
|
warnings := ctx.CopyWarnings(nil)
|
|
require.Equal(t, []contextutil.SQLWarn{
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn0")},
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn1")},
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn2")},
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn3")},
|
|
}, warnings)
|
|
require.Equal(t, 4, h.WarningCount())
|
|
require.Equal(t, h.WarningCount(), ctx.WarningCount())
|
|
|
|
// ctx.TruncateWarnings
|
|
warnings = ctx.TruncateWarnings(2)
|
|
require.Equal(t, []contextutil.SQLWarn{
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn2")},
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn3")},
|
|
}, warnings)
|
|
require.Equal(t, 2, h.WarningCount())
|
|
require.Equal(t, h.WarningCount(), ctx.WarningCount())
|
|
warnings = ctx.CopyWarnings(nil)
|
|
require.Equal(t, []contextutil.SQLWarn{
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn0")},
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn1")},
|
|
}, warnings)
|
|
|
|
// Apply should use the old warning handler by default
|
|
ctx2 := ctx.Apply()
|
|
require.NotSame(t, ctx, ctx2)
|
|
require.True(t, ctx.warnHandler == ctx2.warnHandler)
|
|
require.True(t, ctx.warnHandler == h)
|
|
|
|
// Apply with `WithWarnHandler`
|
|
h2 := contextutil.NewStaticWarnHandler(16)
|
|
ctx2 = ctx.Apply(WithWarnHandler(h2))
|
|
require.True(t, ctx2.warnHandler == h2)
|
|
require.True(t, ctx.warnHandler == h)
|
|
|
|
// The type context and error context should use the new handler.
|
|
ctx.TruncateWarnings(0)
|
|
tc, ec = ctx.TypeCtx(), ctx.ErrCtx()
|
|
tc2, ec2 := ctx2.TypeCtx(), ctx2.ErrCtx()
|
|
tc2.AppendWarning(errors.NewNoStackError("warn4"))
|
|
ec2.AppendWarning(errors.NewNoStackError("warn5"))
|
|
tc.AppendWarning(errors.NewNoStackError("warn6"))
|
|
ec.AppendWarning(errors.NewNoStackError("warn7"))
|
|
require.Equal(t, []contextutil.SQLWarn{
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn4")},
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn5")},
|
|
}, ctx2.CopyWarnings(nil))
|
|
require.Equal(t, []contextutil.SQLWarn{
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn6")},
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn7")},
|
|
}, ctx.CopyWarnings(nil))
|
|
}
|
|
|
|
func TestStaticEvalContextOptionalProps(t *testing.T) {
|
|
ctx := NewStaticEvalContext()
|
|
require.True(t, ctx.GetOptionalPropSet().IsEmpty())
|
|
|
|
ctx2 := ctx.Apply(WithOptionalProperty(
|
|
contextopt.CurrentUserPropProvider(func() (u *auth.UserIdentity, r []*auth.RoleIdentity) { return }),
|
|
))
|
|
var emptySet context.OptionalEvalPropKeySet
|
|
require.Equal(t, emptySet, ctx.GetOptionalPropSet())
|
|
require.Equal(t, emptySet.Add(context.OptPropCurrentUser), ctx2.GetOptionalPropSet())
|
|
|
|
// Apply should override all optional properties
|
|
ctx3 := ctx2.Apply(WithOptionalProperty(
|
|
contextopt.DDLOwnerInfoProvider(func() bool { return true }),
|
|
contextopt.InfoSchemaPropProvider(func(isDomain bool) infoschema.MetaOnlyInfoSchema { return nil }),
|
|
))
|
|
require.Equal(t,
|
|
emptySet.Add(context.OptPropDDLOwnerInfo).Add(context.OptPropInfoSchema),
|
|
ctx3.GetOptionalPropSet(),
|
|
)
|
|
require.Equal(t, emptySet, ctx.GetOptionalPropSet())
|
|
require.Equal(t, emptySet.Add(context.OptPropCurrentUser), ctx2.GetOptionalPropSet())
|
|
}
|
|
|
|
func TestUpdateStaticEvalContext(t *testing.T) {
|
|
oldCtx := NewStaticEvalContext()
|
|
ctx := oldCtx.Apply()
|
|
|
|
// Should return a different instance
|
|
require.NotSame(t, oldCtx, ctx)
|
|
|
|
// CtxID should be different
|
|
require.Greater(t, ctx.CtxID(), oldCtx.CtxID())
|
|
|
|
// inner state should not be the same address
|
|
require.NotSame(t, &oldCtx.staticEvalCtxState, &ctx.staticEvalCtxState)
|
|
|
|
// compare a state object by excluding some changed fields
|
|
excludeChangedFields := func(s *staticEvalCtxState) staticEvalCtxState {
|
|
state := *s
|
|
state.typeCtx = types.DefaultStmtNoWarningContext
|
|
state.errCtx = errctx.StrictNoWarningContext
|
|
state.currentTime = nil
|
|
return state
|
|
}
|
|
require.Equal(t, excludeChangedFields(&oldCtx.staticEvalCtxState), excludeChangedFields(&ctx.staticEvalCtxState))
|
|
|
|
// check fields
|
|
checkDefaultStaticEvalCtx(t, ctx)
|
|
|
|
// apply options
|
|
opts, optState := getEvalCtxOptionsForTest(t)
|
|
ctx2 := oldCtx.Apply(opts...)
|
|
require.Greater(t, ctx2.CtxID(), ctx.CtxID())
|
|
checkOptionsStaticEvalCtx(t, ctx2, optState)
|
|
|
|
// old ctx aren't affected
|
|
checkDefaultStaticEvalCtx(t, oldCtx)
|
|
|
|
// create with options
|
|
opts, optState = getEvalCtxOptionsForTest(t)
|
|
ctx3 := NewStaticEvalContext(opts...)
|
|
require.Greater(t, ctx3.CtxID(), ctx2.CtxID())
|
|
checkOptionsStaticEvalCtx(t, ctx3, optState)
|
|
}
|
|
|
|
func TestParamList(t *testing.T) {
|
|
paramList := variable.NewPlanCacheParamList()
|
|
paramList.Append(types.NewDatum(1))
|
|
paramList.Append(types.NewDatum(2))
|
|
paramList.Append(types.NewDatum(3))
|
|
ctx := NewStaticEvalContext(
|
|
WithParamList(paramList),
|
|
)
|
|
for i := 0; i < 3; i++ {
|
|
val := ctx.GetParamValue(i)
|
|
require.Equal(t, int64(i+1), val.GetInt64())
|
|
}
|
|
|
|
// after reset the paramList and append new one, the value is still persisted
|
|
paramList.Reset()
|
|
paramList.Append(types.NewDatum(4))
|
|
for i := 0; i < 3; i++ {
|
|
val := ctx.GetParamValue(i)
|
|
require.Equal(t, int64(i+1), val.GetInt64())
|
|
}
|
|
}
|