// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
|
|
|
package exprstatic
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pingcap/errors"
|
|
"github.com/pingcap/tidb/pkg/errctx"
|
|
"github.com/pingcap/tidb/pkg/expression/exprctx"
|
|
"github.com/pingcap/tidb/pkg/expression/expropt"
|
|
infoschema "github.com/pingcap/tidb/pkg/infoschema/context"
|
|
"github.com/pingcap/tidb/pkg/parser/auth"
|
|
"github.com/pingcap/tidb/pkg/parser/mysql"
|
|
"github.com/pingcap/tidb/pkg/sessionctx/vardef"
|
|
"github.com/pingcap/tidb/pkg/sessionctx/variable"
|
|
"github.com/pingcap/tidb/pkg/types"
|
|
contextutil "github.com/pingcap/tidb/pkg/util/context"
|
|
"github.com/pingcap/tidb/pkg/util/deeptest"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestNewStaticEvalCtx(t *testing.T) {
|
|
// default context
|
|
prevID := contextutil.GenContextID()
|
|
ctx := NewEvalContext()
|
|
require.Equal(t, prevID+1, ctx.CtxID())
|
|
checkDefaultStaticEvalCtx(t, ctx)
|
|
|
|
// with options
|
|
prevID = ctx.CtxID()
|
|
options, stateForTest := getEvalCtxOptionsForTest(t)
|
|
ctx = NewEvalContext(options...)
|
|
require.Equal(t, prevID+1, ctx.CtxID())
|
|
checkOptionsStaticEvalCtx(t, ctx, stateForTest)
|
|
}
|
|
|
|
func checkDefaultStaticEvalCtx(t *testing.T, ctx *EvalContext) {
|
|
mode, err := mysql.GetSQLMode(mysql.DefaultSQLMode)
|
|
require.NoError(t, err)
|
|
require.Equal(t, mode, ctx.SQLMode())
|
|
require.Same(t, time.UTC, ctx.Location())
|
|
require.Equal(t, types.NewContext(types.StrictFlags, time.UTC, ctx), ctx.TypeCtx())
|
|
require.Equal(t, errctx.NewContextWithLevels(errctx.LevelMap{}, ctx), ctx.ErrCtx())
|
|
require.Equal(t, "", ctx.CurrentDB())
|
|
require.Equal(t, vardef.DefMaxAllowedPacket, ctx.GetMaxAllowedPacket())
|
|
require.Equal(t, vardef.DefDefaultWeekFormat, ctx.GetDefaultWeekFormatMode())
|
|
require.Equal(t, vardef.DefDivPrecisionIncrement, ctx.GetDivPrecisionIncrement())
|
|
require.Empty(t, ctx.AllParamValues())
|
|
require.Equal(t, variable.NewUserVars(), ctx.GetUserVarsReader())
|
|
require.True(t, ctx.GetOptionalPropSet().IsEmpty())
|
|
p, ok := ctx.GetOptionalPropProvider(exprctx.OptPropAdvisoryLock)
|
|
require.Nil(t, p)
|
|
require.False(t, ok)
|
|
|
|
tm, err := ctx.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Same(t, time.UTC, tm.Location())
|
|
require.InDelta(t, time.Now().Unix(), tm.Unix(), 5)
|
|
|
|
warnHandler, ok := ctx.warnHandler.(*contextutil.StaticWarnHandler)
|
|
require.True(t, ok)
|
|
require.Equal(t, 0, warnHandler.WarningCount())
|
|
}
|
|
|
|
type evalCtxOptionsTestState struct {
|
|
now time.Time
|
|
loc *time.Location
|
|
warnHandler *contextutil.StaticWarnHandler
|
|
userVars *variable.UserVars
|
|
ddlOwner bool
|
|
}
|
|
|
|
func getEvalCtxOptionsForTest(t *testing.T) ([]EvalCtxOption, *evalCtxOptionsTestState) {
|
|
loc, err := time.LoadLocation("US/Eastern")
|
|
require.NoError(t, err)
|
|
s := &evalCtxOptionsTestState{
|
|
now: time.Now(),
|
|
loc: loc,
|
|
warnHandler: contextutil.NewStaticWarnHandler(8),
|
|
userVars: variable.NewUserVars(),
|
|
}
|
|
|
|
provider1 := expropt.CurrentUserPropProvider(func() (*auth.UserIdentity, []*auth.RoleIdentity) {
|
|
return &auth.UserIdentity{Username: "user1", Hostname: "host1"},
|
|
[]*auth.RoleIdentity{{Username: "role1", Hostname: "host2"}}
|
|
})
|
|
|
|
provider2 := expropt.DDLOwnerInfoProvider(func() bool {
|
|
return s.ddlOwner
|
|
})
|
|
|
|
return []EvalCtxOption{
|
|
WithWarnHandler(s.warnHandler),
|
|
WithSQLMode(mysql.ModeNoZeroDate | mysql.ModeStrictTransTables),
|
|
WithTypeFlags(types.FlagAllowNegativeToUnsigned | types.FlagSkipASCIICheck),
|
|
WithErrLevelMap(errctx.LevelMap{
|
|
errctx.ErrGroupBadNull: errctx.LevelError,
|
|
errctx.ErrGroupNoDefault: errctx.LevelError,
|
|
errctx.ErrGroupDividedByZero: errctx.LevelWarn,
|
|
}),
|
|
WithLocation(loc),
|
|
WithCurrentDB("db1"),
|
|
WithCurrentTime(func() (time.Time, error) {
|
|
return s.now, nil
|
|
}),
|
|
WithMaxAllowedPacket(12345),
|
|
WithDefaultWeekFormatMode("3"),
|
|
WithDivPrecisionIncrement(5),
|
|
WithUserVarsReader(s.userVars),
|
|
WithOptionalProperty(provider1, provider2),
|
|
}, s
|
|
}
|
|
|
|
func checkOptionsStaticEvalCtx(t *testing.T, ctx *EvalContext, s *evalCtxOptionsTestState) {
|
|
require.Same(t, ctx.warnHandler, s.warnHandler)
|
|
require.Equal(t, mysql.ModeNoZeroDate|mysql.ModeStrictTransTables, ctx.SQLMode())
|
|
require.Equal(t,
|
|
types.NewContext(types.FlagAllowNegativeToUnsigned|types.FlagSkipASCIICheck, s.loc, ctx),
|
|
ctx.TypeCtx(),
|
|
)
|
|
require.Equal(t, errctx.NewContextWithLevels(errctx.LevelMap{
|
|
errctx.ErrGroupBadNull: errctx.LevelError,
|
|
errctx.ErrGroupNoDefault: errctx.LevelError,
|
|
errctx.ErrGroupDividedByZero: errctx.LevelWarn,
|
|
}, ctx), ctx.ErrCtx())
|
|
require.Same(t, s.loc, ctx.Location())
|
|
require.Equal(t, "db1", ctx.CurrentDB())
|
|
current, err := ctx.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Equal(t, current.UnixNano(), s.now.UnixNano())
|
|
require.Same(t, s.loc, current.Location())
|
|
require.Equal(t, uint64(12345), ctx.GetMaxAllowedPacket())
|
|
require.Equal(t, "3", ctx.GetDefaultWeekFormatMode())
|
|
require.Equal(t, 5, ctx.GetDivPrecisionIncrement())
|
|
require.Same(t, s.userVars, ctx.GetUserVarsReader())
|
|
|
|
var optSet exprctx.OptionalEvalPropKeySet
|
|
optSet = optSet.Add(exprctx.OptPropCurrentUser).Add(exprctx.OptPropDDLOwnerInfo)
|
|
require.Equal(t, optSet, ctx.GetOptionalPropSet())
|
|
p, ok := ctx.GetOptionalPropProvider(exprctx.OptPropCurrentUser)
|
|
require.True(t, ok)
|
|
user, roles := p.(expropt.CurrentUserPropProvider)()
|
|
require.Equal(t, &auth.UserIdentity{Username: "user1", Hostname: "host1"}, user)
|
|
require.Equal(t, []*auth.RoleIdentity{{Username: "role1", Hostname: "host2"}}, roles)
|
|
p, ok = ctx.GetOptionalPropProvider(exprctx.OptPropDDLOwnerInfo)
|
|
s.ddlOwner = true
|
|
require.True(t, ok)
|
|
require.True(t, p.(expropt.DDLOwnerInfoProvider)())
|
|
s.ddlOwner = false
|
|
require.False(t, p.(expropt.DDLOwnerInfoProvider)())
|
|
p, ok = ctx.GetOptionalPropProvider(exprctx.OptPropInfoSchema)
|
|
require.False(t, ok)
|
|
require.Nil(t, p)
|
|
}
|
|
|
|
func TestStaticEvalCtxCurrentTime(t *testing.T) {
|
|
loc1, err := time.LoadLocation("US/Eastern")
|
|
require.NoError(t, err)
|
|
|
|
tm := time.UnixMicro(123456789).In(loc1)
|
|
calls := 0
|
|
getTime := func() (time.Time, error) {
|
|
defer func() {
|
|
calls++
|
|
}()
|
|
|
|
if calls < 2 {
|
|
return time.Time{}, errors.NewNoStackError(fmt.Sprintf("err%d", calls))
|
|
}
|
|
|
|
if calls == 2 {
|
|
return tm, nil
|
|
}
|
|
|
|
return time.Time{}, errors.NewNoStackError("should not reach here")
|
|
}
|
|
|
|
ctx := NewEvalContext(WithCurrentTime(getTime))
|
|
|
|
// get time for the first two times should fail
|
|
got, err := ctx.CurrentTime()
|
|
require.EqualError(t, err, "err0")
|
|
require.Equal(t, time.Time{}, got)
|
|
|
|
got, err = ctx.CurrentTime()
|
|
require.EqualError(t, err, "err1")
|
|
require.Equal(t, time.Time{}, got)
|
|
|
|
// the third time will success
|
|
got, err = ctx.CurrentTime()
|
|
require.Nil(t, err)
|
|
require.Equal(t, tm.UnixNano(), got.UnixNano())
|
|
require.Same(t, time.UTC, got.Location())
|
|
require.Equal(t, 3, calls)
|
|
|
|
// next ctx should cache the time without calling inner function
|
|
got, err = ctx.CurrentTime()
|
|
require.Nil(t, err)
|
|
require.Equal(t, tm.UnixNano(), got.UnixNano())
|
|
require.Same(t, time.UTC, got.Location())
|
|
require.Equal(t, 3, calls)
|
|
|
|
// CurrentTime should have the same location with `ctx.Location()`
|
|
loc2, err := time.LoadLocation("Australia/Sydney")
|
|
require.NoError(t, err)
|
|
ctx = NewEvalContext(
|
|
WithLocation(loc2),
|
|
WithCurrentTime(func() (time.Time, error) {
|
|
return tm, nil
|
|
}),
|
|
)
|
|
got, err = ctx.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Equal(t, tm.UnixNano(), got.UnixNano())
|
|
require.Same(t, loc2, got.Location())
|
|
|
|
// Apply should copy the current time
|
|
ctx2 := ctx.Apply()
|
|
got, err = ctx2.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Equal(t, tm.UnixNano(), got.UnixNano())
|
|
require.Same(t, loc2, got.Location())
|
|
|
|
// Apply with location should change current time's location
|
|
ctx2 = ctx.Apply(WithLocation(loc1))
|
|
got, err = ctx2.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Equal(t, tm.UnixNano(), got.UnixNano())
|
|
require.Same(t, loc1, got.Location())
|
|
|
|
// Apply will not affect previous current time
|
|
got, err = ctx.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Equal(t, tm.UnixNano(), got.UnixNano())
|
|
require.Same(t, loc2, got.Location())
|
|
|
|
// Apply with a different current time func
|
|
ctx2 = ctx.Apply(WithCurrentTime(func() (time.Time, error) {
|
|
return time.UnixMicro(987654321), nil
|
|
}))
|
|
got, err = ctx2.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(987654321), got.UnixMicro())
|
|
require.Same(t, loc2, got.Location())
|
|
|
|
// Apply will not affect previous current time
|
|
got, err = ctx.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Equal(t, tm.UnixNano(), got.UnixNano())
|
|
require.Same(t, loc2, got.Location())
|
|
}
|
|
|
|
func TestStaticEvalCtxWarnings(t *testing.T) {
|
|
// default context should have a empty StaticWarningsHandler
|
|
ctx := NewEvalContext()
|
|
h, ok := ctx.warnHandler.(*contextutil.StaticWarnHandler)
|
|
require.True(t, ok)
|
|
require.Equal(t, 0, h.WarningCount())
|
|
|
|
// WithWarnHandler should work
|
|
ignoreHandler := contextutil.IgnoreWarn
|
|
ctx = NewEvalContext(WithWarnHandler(ignoreHandler))
|
|
require.True(t, ctx.warnHandler == ignoreHandler)
|
|
|
|
// All contexts should use the same warning handler
|
|
h = contextutil.NewStaticWarnHandler(8)
|
|
ctx = NewEvalContext(WithWarnHandler(h))
|
|
tc, ec := ctx.TypeCtx(), ctx.ErrCtx()
|
|
h.AppendWarning(errors.NewNoStackError("warn0"))
|
|
ctx.AppendWarning(errors.NewNoStackError("warn1"))
|
|
ctx.AppendNote(errors.NewNoStackError("note1"))
|
|
tc.AppendWarning(errors.NewNoStackError("warn2"))
|
|
ec.AppendWarning(errors.NewNoStackError("warn3"))
|
|
require.Equal(t, 5, h.WarningCount())
|
|
require.Equal(t, h.WarningCount(), ctx.WarningCount())
|
|
|
|
// ctx.CopyWarnings
|
|
warnings := ctx.CopyWarnings(nil)
|
|
require.Equal(t, []contextutil.SQLWarn{
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn0")},
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn1")},
|
|
{Level: contextutil.WarnLevelNote, Err: errors.NewNoStackError("note1")},
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn2")},
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn3")},
|
|
}, warnings)
|
|
require.Equal(t, 5, h.WarningCount())
|
|
require.Equal(t, h.WarningCount(), ctx.WarningCount())
|
|
|
|
// ctx.TruncateWarnings
|
|
warnings = ctx.TruncateWarnings(2)
|
|
require.Equal(t, []contextutil.SQLWarn{
|
|
{Level: contextutil.WarnLevelNote, Err: errors.NewNoStackError("note1")},
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn2")},
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn3")},
|
|
}, warnings)
|
|
require.Equal(t, 2, h.WarningCount())
|
|
require.Equal(t, h.WarningCount(), ctx.WarningCount())
|
|
warnings = ctx.CopyWarnings(nil)
|
|
require.Equal(t, []contextutil.SQLWarn{
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn0")},
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn1")},
|
|
}, warnings)
|
|
|
|
// Apply should use the old warning handler by default
|
|
ctx2 := ctx.Apply()
|
|
require.NotSame(t, ctx, ctx2)
|
|
require.True(t, ctx.warnHandler == ctx2.warnHandler)
|
|
require.True(t, ctx.warnHandler == h)
|
|
|
|
// Apply with `WithWarnHandler`
|
|
h2 := contextutil.NewStaticWarnHandler(16)
|
|
ctx2 = ctx.Apply(WithWarnHandler(h2))
|
|
require.True(t, ctx2.warnHandler == h2)
|
|
require.True(t, ctx.warnHandler == h)
|
|
|
|
// The type context and error context should use the new handler.
|
|
ctx.TruncateWarnings(0)
|
|
tc, ec = ctx.TypeCtx(), ctx.ErrCtx()
|
|
tc2, ec2 := ctx2.TypeCtx(), ctx2.ErrCtx()
|
|
tc2.AppendWarning(errors.NewNoStackError("warn4"))
|
|
ec2.AppendWarning(errors.NewNoStackError("warn5"))
|
|
tc.AppendWarning(errors.NewNoStackError("warn6"))
|
|
ec.AppendWarning(errors.NewNoStackError("warn7"))
|
|
require.Equal(t, []contextutil.SQLWarn{
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn4")},
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn5")},
|
|
}, ctx2.CopyWarnings(nil))
|
|
require.Equal(t, []contextutil.SQLWarn{
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn6")},
|
|
{Level: contextutil.WarnLevelWarning, Err: errors.NewNoStackError("warn7")},
|
|
}, ctx.CopyWarnings(nil))
|
|
}
|
|
|
|
func TestStaticEvalContextOptionalProps(t *testing.T) {
|
|
ctx := NewEvalContext()
|
|
require.True(t, ctx.GetOptionalPropSet().IsEmpty())
|
|
|
|
ctx2 := ctx.Apply(WithOptionalProperty(
|
|
expropt.CurrentUserPropProvider(func() (u *auth.UserIdentity, r []*auth.RoleIdentity) { return }),
|
|
))
|
|
var emptySet exprctx.OptionalEvalPropKeySet
|
|
require.Equal(t, emptySet, ctx.GetOptionalPropSet())
|
|
require.Equal(t, emptySet.Add(exprctx.OptPropCurrentUser), ctx2.GetOptionalPropSet())
|
|
|
|
// Apply should override all optional properties
|
|
ctx3 := ctx2.Apply(WithOptionalProperty(
|
|
expropt.DDLOwnerInfoProvider(func() bool { return true }),
|
|
expropt.InfoSchemaPropProvider(func(isDomain bool) infoschema.MetaOnlyInfoSchema { return nil }),
|
|
))
|
|
require.Equal(t,
|
|
emptySet.Add(exprctx.OptPropDDLOwnerInfo).Add(exprctx.OptPropInfoSchema),
|
|
ctx3.GetOptionalPropSet(),
|
|
)
|
|
require.Equal(t, emptySet, ctx.GetOptionalPropSet())
|
|
require.Equal(t, emptySet.Add(exprctx.OptPropCurrentUser), ctx2.GetOptionalPropSet())
|
|
}
|
|
|
|
func TestUpdateStaticEvalContext(t *testing.T) {
|
|
oldCtx := NewEvalContext()
|
|
ctx := oldCtx.Apply()
|
|
|
|
// Should return a different instance
|
|
require.NotSame(t, oldCtx, ctx)
|
|
|
|
// CtxID should be different
|
|
require.Greater(t, ctx.CtxID(), oldCtx.CtxID())
|
|
|
|
// inner state should not be the same address
|
|
require.NotSame(t, &oldCtx.evalCtxState, &ctx.evalCtxState)
|
|
|
|
// compare a state object by excluding some changed fields
|
|
excludeChangedFields := func(s *evalCtxState) evalCtxState {
|
|
state := *s
|
|
state.typeCtx = types.DefaultStmtNoWarningContext
|
|
state.errCtx = errctx.StrictNoWarningContext
|
|
state.currentTime = nil
|
|
return state
|
|
}
|
|
require.Equal(t, excludeChangedFields(&oldCtx.evalCtxState), excludeChangedFields(&ctx.evalCtxState))
|
|
|
|
// check fields
|
|
checkDefaultStaticEvalCtx(t, ctx)
|
|
|
|
// apply options
|
|
opts, optState := getEvalCtxOptionsForTest(t)
|
|
ctx2 := oldCtx.Apply(opts...)
|
|
require.Greater(t, ctx2.CtxID(), ctx.CtxID())
|
|
checkOptionsStaticEvalCtx(t, ctx2, optState)
|
|
|
|
// old ctx aren't affected
|
|
checkDefaultStaticEvalCtx(t, oldCtx)
|
|
|
|
// create with options
|
|
opts, optState = getEvalCtxOptionsForTest(t)
|
|
ctx3 := NewEvalContext(opts...)
|
|
require.Greater(t, ctx3.CtxID(), ctx2.CtxID())
|
|
checkOptionsStaticEvalCtx(t, ctx3, optState)
|
|
}
|
|
|
|
func TestParamList(t *testing.T) {
|
|
paramList := variable.NewPlanCacheParamList()
|
|
paramList.Append(types.NewDatum(1))
|
|
paramList.Append(types.NewDatum(2))
|
|
paramList.Append(types.NewDatum(3))
|
|
ctx := NewEvalContext(
|
|
WithParamList(paramList),
|
|
)
|
|
for i := range 3 {
|
|
val, err := ctx.GetParamValue(i)
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(i+1), val.GetInt64())
|
|
}
|
|
|
|
// after reset the paramList and append new one, the value is still persisted
|
|
paramList.Reset()
|
|
paramList.Append(types.NewDatum(4))
|
|
for i := range 3 {
|
|
val, err := ctx.GetParamValue(i)
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(i+1), val.GetInt64())
|
|
}
|
|
}
|
|
|
|
func TestMakeEvalContextStatic(t *testing.T) {
|
|
// This test is to ensure that the `MakeEvalContextStatic` function works as expected.
|
|
// It requires the developers to create a special `EvalContext`, whose every fields
|
|
// are non-empty. Then, the `MakeEvalContextStatic` function is called to create a new
|
|
// clone of it. Finally, the new clone is compared with the original one to ensure that
|
|
// the fields are correctly copied.
|
|
paramList := variable.NewPlanCacheParamList()
|
|
paramList.Append(types.NewDatum(1))
|
|
|
|
userVars := variable.NewUserVars()
|
|
userVars.SetUserVarVal("a", types.NewStringDatum("v1"))
|
|
userVars.SetUserVarVal("b", types.NewIntDatum(2))
|
|
|
|
provider := expropt.DDLOwnerInfoProvider(func() bool {
|
|
return true
|
|
})
|
|
|
|
obj := NewEvalContext(
|
|
WithWarnHandler(contextutil.NewStaticWarnHandler(16)),
|
|
WithSQLMode(mysql.ModeNoZeroDate|mysql.ModeStrictTransTables),
|
|
WithTypeFlags(types.FlagAllowNegativeToUnsigned|types.FlagSkipASCIICheck),
|
|
WithErrLevelMap(errctx.LevelMap{}),
|
|
WithLocation(time.UTC),
|
|
WithCurrentDB("db1"),
|
|
WithCurrentTime(func() (time.Time, error) {
|
|
return time.Now(), nil
|
|
}),
|
|
WithMaxAllowedPacket(12345),
|
|
WithDefaultWeekFormatMode("3"),
|
|
WithDivPrecisionIncrement(5),
|
|
WithParamList(paramList),
|
|
WithUserVarsReader(userVars),
|
|
WithOptionalProperty(provider),
|
|
WithEnableRedactLog("test"),
|
|
)
|
|
obj.AppendWarning(errors.New("test warning"))
|
|
|
|
ignorePath := []string{
|
|
"$.evalCtxState.warnHandler.**",
|
|
"$.evalCtxState.typeCtx.**",
|
|
"$.evalCtxState.errCtx.**",
|
|
"$.evalCtxState.currentTime.**",
|
|
"$.evalCtxState.userVars.lock",
|
|
"$.evalCtxState.props",
|
|
"$.id",
|
|
}
|
|
deeptest.AssertRecursivelyNotEqual(t, obj, NewEvalContext(),
|
|
deeptest.WithIgnorePath(ignorePath),
|
|
)
|
|
|
|
staticObj := MakeEvalContextStatic(obj)
|
|
|
|
deeptest.AssertDeepClonedEqual(t, obj, staticObj,
|
|
deeptest.WithIgnorePath(ignorePath),
|
|
deeptest.WithPointerComparePath([]string{
|
|
"$.evalCtxState.warnHandler",
|
|
"$.evalCtxState.paramList*.b",
|
|
}),
|
|
)
|
|
|
|
require.Equal(t, obj.GetWarnHandler(), staticObj.GetWarnHandler())
|
|
require.Equal(t, obj.typeCtx.Flags(), staticObj.typeCtx.Flags())
|
|
require.Equal(t, obj.errCtx.LevelMap(), staticObj.errCtx.LevelMap())
|
|
|
|
oldT, err := obj.CurrentTime()
|
|
require.NoError(t, err)
|
|
newT, err := staticObj.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Equal(t, oldT.Unix(), newT.Unix())
|
|
|
|
require.NotEqual(t, obj.GetOptionalPropSet(), staticObj.GetOptionalPropSet())
|
|
// Now, it didn't copy any optional properties.
|
|
require.Equal(t, exprctx.OptionalEvalPropKeySet(0), staticObj.GetOptionalPropSet())
|
|
}
|
|
|
|
func TestEvalCtxLoadSystemVars(t *testing.T) {
|
|
vars := []struct {
|
|
name string
|
|
val string
|
|
field string
|
|
assert func(ctx *EvalContext, vars *variable.SessionVars)
|
|
}{
|
|
{
|
|
name: "time_zone",
|
|
val: "Europe/Berlin",
|
|
field: "$.typeCtx.loc",
|
|
assert: func(ctx *EvalContext, vars *variable.SessionVars) {
|
|
require.Equal(t, "Europe/Berlin", ctx.Location().String())
|
|
require.Equal(t, vars.Location().String(), ctx.Location().String())
|
|
},
|
|
},
|
|
{
|
|
name: "sql_mode",
|
|
val: "ALLOW_INVALID_DATES,ONLY_FULL_GROUP_BY",
|
|
field: "$.sqlMode",
|
|
assert: func(ctx *EvalContext, vars *variable.SessionVars) {
|
|
require.Equal(t, mysql.ModeAllowInvalidDates|mysql.ModeOnlyFullGroupBy, ctx.SQLMode())
|
|
require.Equal(t, vars.SQLMode, ctx.SQLMode())
|
|
},
|
|
},
|
|
{
|
|
name: "timestamp",
|
|
val: "1234567890.123456",
|
|
field: "$.currentTime",
|
|
assert: func(ctx *EvalContext, vars *variable.SessionVars) {
|
|
currentTime, err := ctx.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(1234567890123456), currentTime.UnixMicro())
|
|
require.Equal(t, vars.Location().String(), currentTime.Location().String())
|
|
},
|
|
},
|
|
{
|
|
name: strings.ToUpper("max_allowed_packet"), // test for settings an upper case variable
|
|
val: "524288",
|
|
field: "$.maxAllowedPacket",
|
|
assert: func(ctx *EvalContext, vars *variable.SessionVars) {
|
|
require.Equal(t, uint64(524288), ctx.GetMaxAllowedPacket())
|
|
require.Equal(t, vars.MaxAllowedPacket, ctx.GetMaxAllowedPacket())
|
|
},
|
|
},
|
|
{
|
|
name: strings.ToUpper("tidb_redact_log"), // test for settings an upper case variable
|
|
val: "ON",
|
|
field: "$.enableRedactLog",
|
|
assert: func(ctx *EvalContext, vars *variable.SessionVars) {
|
|
require.Equal(t, "ON", ctx.GetTiDBRedactLog())
|
|
require.Equal(t, vars.EnableRedactLog, ctx.GetTiDBRedactLog())
|
|
},
|
|
},
|
|
{
|
|
name: "default_week_format",
|
|
val: "5",
|
|
field: "$.defaultWeekFormatMode",
|
|
assert: func(ctx *EvalContext, vars *variable.SessionVars) {
|
|
require.Equal(t, "5", ctx.GetDefaultWeekFormatMode())
|
|
mode, ok := vars.GetSystemVar(vardef.DefaultWeekFormat)
|
|
require.True(t, ok)
|
|
require.Equal(t, mode, ctx.GetDefaultWeekFormatMode())
|
|
},
|
|
},
|
|
{
|
|
name: "div_precision_increment",
|
|
val: "12",
|
|
field: "$.divPrecisionIncrement",
|
|
assert: func(ctx *EvalContext, vars *variable.SessionVars) {
|
|
require.Equal(t, 12, ctx.GetDivPrecisionIncrement())
|
|
require.Equal(t, vars.DivPrecisionIncrement, ctx.GetDivPrecisionIncrement())
|
|
},
|
|
},
|
|
}
|
|
|
|
// nonVarRelatedFields means the fields not related to any system variables.
|
|
// To make sure that all the variables which affect the context state are covered in the above test list,
|
|
// we need to test all inner fields except those in `nonVarRelatedFields` are changed after `LoadSystemVars`.
|
|
nonVarRelatedFields := []string{
|
|
"$.warnHandler",
|
|
"$.typeCtx.flags",
|
|
"$.typeCtx.warnHandler",
|
|
"$.errCtx",
|
|
"$.currentDB",
|
|
"$.paramList",
|
|
"$.userVars",
|
|
"$.props",
|
|
}
|
|
|
|
// varsRelatedFields means the fields related to
|
|
varsRelatedFields := make([]string, 0, len(vars))
|
|
varsMap := make(map[string]string)
|
|
sessionVars := variable.NewSessionVars(nil)
|
|
for _, sysVar := range vars {
|
|
varsMap[sysVar.name] = sysVar.val
|
|
if sysVar.field != "" {
|
|
varsRelatedFields = append(varsRelatedFields, sysVar.field)
|
|
}
|
|
sv := variable.GetSysVar(sysVar.name)
|
|
require.NotNil(t, sv)
|
|
if sv.HasSessionScope() && !sv.InternalSessionVariable {
|
|
require.NoError(t, sessionVars.SetSystemVar(sysVar.name, sysVar.val))
|
|
} else if sv.HasGlobalScope() {
|
|
require.NoError(t, sv.SetGlobalFromHook(context.TODO(), sessionVars, sysVar.val, false))
|
|
}
|
|
}
|
|
|
|
defaultEvalCtx := NewEvalContext()
|
|
ctx, err := defaultEvalCtx.LoadSystemVars(varsMap)
|
|
require.NoError(t, err)
|
|
require.Greater(t, ctx.CtxID(), defaultEvalCtx.CtxID())
|
|
|
|
// Check all fields except these in `nonVarRelatedFields` are changed after `LoadSystemVars` to make sure
|
|
// all system variables related fields are covered in the test list.
|
|
deeptest.AssertRecursivelyNotEqual(
|
|
t,
|
|
defaultEvalCtx.evalCtxState,
|
|
ctx.evalCtxState,
|
|
deeptest.WithIgnorePath(nonVarRelatedFields),
|
|
deeptest.WithPointerComparePath([]string{"$.currentTime"}),
|
|
)
|
|
|
|
// We need to compare the new context again with an empty one to make sure those values are set from sys vars,
|
|
// not inherited from the empty go value.
|
|
deeptest.AssertRecursivelyNotEqual(
|
|
t,
|
|
evalCtxState{},
|
|
ctx.evalCtxState,
|
|
deeptest.WithIgnorePath(nonVarRelatedFields),
|
|
deeptest.WithPointerComparePath([]string{"$.currentTime"}),
|
|
)
|
|
|
|
// Check all system vars unrelated fields are not changed after `LoadSystemVars`.
|
|
deeptest.AssertDeepClonedEqual(
|
|
t,
|
|
defaultEvalCtx.evalCtxState,
|
|
ctx.evalCtxState,
|
|
deeptest.WithIgnorePath(append(
|
|
varsRelatedFields,
|
|
// Do not check warnHandler in `typeCtx` and `errCtx` because they should be changed to even if
|
|
// they are not related to any system variable.
|
|
"$.typeCtx.warnHandler",
|
|
"$.errCtx.warnHandler",
|
|
)),
|
|
// LoadSystemVars only does shallow copy for `EvalContext` so we just need to compare the pointers.
|
|
deeptest.WithPointerComparePath(nonVarRelatedFields),
|
|
)
|
|
|
|
for _, sysVar := range vars {
|
|
sysVar.assert(ctx, sessionVars)
|
|
}
|
|
|
|
// additional check about @@timestamp
|
|
// setting to `variable.DefTimestamp` should return the current timestamp
|
|
ctx, err = defaultEvalCtx.LoadSystemVars(map[string]string{
|
|
"timestamp": vardef.DefTimestamp,
|
|
})
|
|
require.NoError(t, err)
|
|
tm, err := ctx.CurrentTime()
|
|
require.NoError(t, err)
|
|
require.InDelta(t, time.Now().Unix(), tm.Unix(), 5)
|
|
}
|