// Copyright 2024 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package exprstatic import ( "context" "math" "strings" "sync" "sync/atomic" "time" "github.com/pingcap/tidb/pkg/errctx" "github.com/pingcap/tidb/pkg/expression/exprctx" "github.com/pingcap/tidb/pkg/expression/expropt" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/sessionctx/vardef" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/types" contextutil "github.com/pingcap/tidb/pkg/util/context" "github.com/pingcap/tidb/pkg/util/intest" ) // EvalContext should implement `EvalContext` var _ exprctx.EvalContext = &EvalContext{} type timeOnce struct { sync.Mutex time atomic.Pointer[time.Time] timeFn func() (time.Time, error) } func (t *timeOnce) getTime(loc *time.Location) (tm time.Time, err error) { if p := t.time.Load(); p != nil { return *p, nil } t.Lock() defer t.Unlock() if p := t.time.Load(); p != nil { return *p, nil } if fn := t.timeFn; fn != nil { tm, err = fn() } else { tm, err = time.Now(), nil } if err != nil { return } tm = tm.In(loc) t.time.Store(&tm) return } // evalCtxState is the internal state for `EvalContext`. // We make it as a standalone private struct here to make sure `EvalCtxOption` can only be called in constructor. type evalCtxState struct { warnHandler contextutil.WarnHandler sqlMode mysql.SQLMode typeCtx types.Context errCtx errctx.Context currentDB string currentTime *timeOnce maxAllowedPacket uint64 enableRedactLog string defaultWeekFormatMode string divPrecisionIncrement int paramList []types.Datum userVars variable.UserVarsReader props expropt.OptionalEvalPropProviders } // EvalCtxOption is the option to set `EvalContext`. type EvalCtxOption func(*evalCtxState) // WithWarnHandler sets the warn handler for the `EvalContext`. func WithWarnHandler(h contextutil.WarnHandler) EvalCtxOption { intest.AssertNotNil(h) if h == nil { // this should not happen, just to keep code safe h = contextutil.IgnoreWarn } return func(s *evalCtxState) { s.warnHandler = h } } // WithSQLMode sets the sql mode for the `EvalContext`. func WithSQLMode(sqlMode mysql.SQLMode) EvalCtxOption { return func(s *evalCtxState) { s.sqlMode = sqlMode } } // WithTypeFlags sets the type flags for the `EvalContext`. func WithTypeFlags(flags types.Flags) EvalCtxOption { return func(s *evalCtxState) { s.typeCtx = s.typeCtx.WithFlags(flags) } } // WithLocation sets the timezone info for the `EvalContext`. func WithLocation(loc *time.Location) EvalCtxOption { intest.AssertNotNil(loc) if loc == nil { // this should not happen, just to keep code safe loc = time.UTC } return func(s *evalCtxState) { s.typeCtx = s.typeCtx.WithLocation(loc) } } // WithErrLevelMap sets the error level map for the `EvalContext`. func WithErrLevelMap(level errctx.LevelMap) EvalCtxOption { return func(s *evalCtxState) { s.errCtx = s.errCtx.WithErrGroupLevels(level) } } // WithCurrentDB sets the current database name for the `EvalContext`. func WithCurrentDB(db string) EvalCtxOption { return func(s *evalCtxState) { s.currentDB = db } } // WithCurrentTime sets the current time for the `EvalContext`. func WithCurrentTime(fn func() (time.Time, error)) EvalCtxOption { intest.AssertNotNil(fn) return func(s *evalCtxState) { s.currentTime = &timeOnce{timeFn: fn} } } // WithMaxAllowedPacket sets the value of the 'max_allowed_packet' system variable. func WithMaxAllowedPacket(size uint64) EvalCtxOption { return func(s *evalCtxState) { s.maxAllowedPacket = size } } // WithDefaultWeekFormatMode sets the value of the 'default_week_format' system variable. func WithDefaultWeekFormatMode(mode string) EvalCtxOption { return func(s *evalCtxState) { s.defaultWeekFormatMode = mode } } // WithDivPrecisionIncrement sets the value of the 'div_precision_increment' system variable. func WithDivPrecisionIncrement(inc int) EvalCtxOption { return func(s *evalCtxState) { s.divPrecisionIncrement = inc } } // WithOptionalProperty sets the optional property providers func WithOptionalProperty(providers ...exprctx.OptionalEvalPropProvider) EvalCtxOption { return func(s *evalCtxState) { s.props = expropt.OptionalEvalPropProviders{} for _, p := range providers { s.props.Add(p) } } } // WithParamList sets the param list for the `EvalContext`. func WithParamList(params *variable.PlanCacheParamList) EvalCtxOption { return func(s *evalCtxState) { s.paramList = make([]types.Datum, len(params.AllParamValues())) for i, v := range params.AllParamValues() { s.paramList[i] = v } } } // WithEnableRedactLog sets the value of the 'tidb_redact_log' system variable. func WithEnableRedactLog(enableRedactLog string) EvalCtxOption { return func(s *evalCtxState) { s.enableRedactLog = enableRedactLog } } // WithUserVarsReader set the user variables reader for the `EvalContext`. func WithUserVarsReader(vars variable.UserVarsReader) EvalCtxOption { return func(s *evalCtxState) { s.userVars = vars } } var defaultSQLMode = func() mysql.SQLMode { mode, err := mysql.GetSQLMode(mysql.DefaultSQLMode) if err != nil { panic(err) } return mode }() // EvalContext implements `EvalContext` to provide a static context for expression evaluation. // The "static" means comparing with `EvalContext`, its internal state does not relay on the session or other // complex contexts that keeps immutable for most fields. type EvalContext struct { id uint64 evalCtxState } // NewEvalContext creates a new `EvalContext` with the given options. func NewEvalContext(opt ...EvalCtxOption) *EvalContext { ctx := &EvalContext{ id: contextutil.GenContextID(), evalCtxState: evalCtxState{ currentTime: &timeOnce{}, sqlMode: defaultSQLMode, maxAllowedPacket: vardef.DefMaxAllowedPacket, enableRedactLog: vardef.DefTiDBRedactLog, defaultWeekFormatMode: vardef.DefDefaultWeekFormat, divPrecisionIncrement: vardef.DefDivPrecisionIncrement, }, } ctx.typeCtx = types.NewContext(types.StrictFlags, time.UTC, ctx) ctx.errCtx = errctx.NewContext(ctx) for _, o := range opt { o(&ctx.evalCtxState) } if ctx.warnHandler == nil { ctx.warnHandler = contextutil.NewStaticWarnHandler(0) } if ctx.userVars == nil { ctx.userVars = variable.NewUserVars() } return ctx } // CtxID returns the context ID. func (ctx *EvalContext) CtxID() uint64 { return ctx.id } // SQLMode returns the sql mode func (ctx *EvalContext) SQLMode() mysql.SQLMode { return ctx.sqlMode } // TypeCtx returns the types.Context func (ctx *EvalContext) TypeCtx() types.Context { return ctx.typeCtx } // ErrCtx returns the errctx.Context func (ctx *EvalContext) ErrCtx() errctx.Context { return ctx.errCtx } // Location returns the timezone info func (ctx *EvalContext) Location() *time.Location { return ctx.typeCtx.Location() } // AppendWarning append warnings to the context. func (ctx *EvalContext) AppendWarning(err error) { if h := ctx.warnHandler; h != nil { h.AppendWarning(err) } } // AppendNote appends notes to the context. func (ctx *EvalContext) AppendNote(err error) { if h := ctx.warnHandler; h != nil { h.AppendNote(err) } } // WarningCount gets warning count. func (ctx *EvalContext) WarningCount() int { if h := ctx.warnHandler; h != nil { return h.WarningCount() } return 0 } // TruncateWarnings truncates warnings begin from start and returns the truncated warnings. func (ctx *EvalContext) TruncateWarnings(start int) []contextutil.SQLWarn { if h := ctx.warnHandler; h != nil { return h.TruncateWarnings(start) } return nil } // CopyWarnings copies warnings to dst slice. func (ctx *EvalContext) CopyWarnings(dst []contextutil.SQLWarn) []contextutil.SQLWarn { if h := ctx.warnHandler; h != nil { return h.CopyWarnings(dst) } return nil } // CurrentDB return the current database name func (ctx *EvalContext) CurrentDB() string { return ctx.currentDB } // CurrentTime return the current time func (ctx *EvalContext) CurrentTime() (tm time.Time, err error) { return ctx.currentTime.getTime(ctx.Location()) } // GetMaxAllowedPacket returns the value of the 'max_allowed_packet' system variable. func (ctx *EvalContext) GetMaxAllowedPacket() uint64 { return ctx.maxAllowedPacket } // GetTiDBRedactLog returns the value of the 'tidb_redact_log' system variable. func (ctx *EvalContext) GetTiDBRedactLog() string { return ctx.enableRedactLog } // GetDefaultWeekFormatMode returns the value of the 'default_week_format' system variable. func (ctx *EvalContext) GetDefaultWeekFormatMode() string { return ctx.defaultWeekFormatMode } // GetDivPrecisionIncrement returns the specified value of DivPrecisionIncrement. func (ctx *EvalContext) GetDivPrecisionIncrement() int { return ctx.divPrecisionIncrement } // GetUserVarsReader returns the user variables. func (ctx *EvalContext) GetUserVarsReader() variable.UserVarsReader { return ctx.userVars } // GetOptionalPropSet gets the optional property set from context func (ctx *EvalContext) GetOptionalPropSet() exprctx.OptionalEvalPropKeySet { return ctx.props.PropKeySet() } // GetOptionalPropProvider gets the optional property provider by key func (ctx *EvalContext) GetOptionalPropProvider(key exprctx.OptionalEvalPropKey) (exprctx.OptionalEvalPropProvider, bool) { return ctx.props.Get(key) } // Apply returns a new `EvalContext` with the fields updated according to the given options. func (ctx *EvalContext) Apply(opt ...EvalCtxOption) *EvalContext { newCtx := &EvalContext{ id: contextutil.GenContextID(), evalCtxState: ctx.evalCtxState, } // current time should use the previous one by default newCtx.currentTime = &timeOnce{timeFn: ctx.CurrentTime} // typeCtx and errCtx should be reset because warn handler changed newCtx.typeCtx = types.NewContext(ctx.typeCtx.Flags(), ctx.typeCtx.Location(), newCtx) newCtx.errCtx = errctx.NewContextWithLevels(ctx.errCtx.LevelMap(), newCtx) // Apply options for _, o := range opt { o(&newCtx.evalCtxState) } return newCtx } // GetParamValue returns the value of the parameter by index. func (ctx *EvalContext) GetParamValue(idx int) (types.Datum, error) { if idx < 0 || idx >= len(ctx.paramList) { return types.Datum{}, exprctx.ErrParamIndexExceedParamCounts } return ctx.paramList[idx], nil } var _ exprctx.StaticConvertibleEvalContext = &EvalContext{} // AllParamValues implements context.StaticConvertibleEvalContext. func (ctx *EvalContext) AllParamValues() []types.Datum { return ctx.paramList } // GetWarnHandler implements context.StaticConvertibleEvalContext. func (ctx *EvalContext) GetWarnHandler() contextutil.WarnHandler { return ctx.warnHandler } // LoadSystemVars loads system variables and returns a new `EvalContext` with system variables loaded. func (ctx *EvalContext) LoadSystemVars(sysVars map[string]string) (*EvalContext, error) { sessionVars, err := newSessionVarsWithSystemVariables(sysVars) if err != nil { return nil, err } return ctx.loadSessionVarsInternal(sessionVars, sysVars), nil } func (ctx *EvalContext) loadSessionVarsInternal( sessionVars *variable.SessionVars, sysVars map[string]string, ) *EvalContext { opts := make([]EvalCtxOption, 0, 8) for name, val := range sysVars { name = strings.ToLower(name) switch name { case vardef.TimeZone: opts = append(opts, WithLocation(sessionVars.Location())) case vardef.SQLModeVar: opts = append(opts, WithSQLMode(sessionVars.SQLMode)) case vardef.Timestamp: opts = append(opts, WithCurrentTime(ctx.currentTimeFuncFromStringVal(val))) case vardef.MaxAllowedPacket: opts = append(opts, WithMaxAllowedPacket(sessionVars.MaxAllowedPacket)) case vardef.TiDBRedactLog: opts = append(opts, WithEnableRedactLog(sessionVars.EnableRedactLog)) case vardef.DefaultWeekFormat: opts = append(opts, WithDefaultWeekFormatMode(val)) case vardef.DivPrecisionIncrement: opts = append(opts, WithDivPrecisionIncrement(sessionVars.DivPrecisionIncrement)) } } return ctx.Apply(opts...) } func (ctx *EvalContext) currentTimeFuncFromStringVal(val string) func() (time.Time, error) { return func() (time.Time, error) { if val == vardef.DefTimestamp { return time.Now(), nil } ts, err := types.StrToFloat(types.StrictContext, val, false) if err != nil { return time.Time{}, err } seconds, fractionalSeconds := math.Modf(ts) return time.Unix(int64(seconds), int64(fractionalSeconds*float64(time.Second))), nil } } func newSessionVarsWithSystemVariables(vars map[string]string) (*variable.SessionVars, error) { sessionVars := variable.NewSessionVars(nil) var cs, col []string for name, val := range vars { switch strings.ToLower(name) { // `charset_connection` and `collation_connection` will overwrite each other. // To make the result more determinate, just set them at last step in order: // `charset_connection` first, then `collation_connection`. case vardef.CharacterSetConnection: cs = []string{name, val} case vardef.CollationConnection: col = []string{name, val} case vardef.TiDBRedactLog: sv := variable.GetSysVar(name) if err := sv.SetGlobalFromHook(context.TODO(), sessionVars, val, false); err != nil { return nil, err } default: if err := sessionVars.SetSystemVar(name, val); err != nil { return nil, err } } } if cs != nil { if err := sessionVars.SetSystemVar(cs[0], cs[1]); err != nil { return nil, err } } if col != nil { if err := sessionVars.SetSystemVar(col[0], col[1]); err != nil { return nil, err } } return sessionVars, nil } // MakeEvalContextStatic converts the `exprctx.StaticConvertibleEvalContext` to `EvalContext`. func MakeEvalContextStatic(ctx exprctx.StaticConvertibleEvalContext) *EvalContext { typeCtx := ctx.TypeCtx() errCtx := ctx.ErrCtx() // TODO: at least provide some optional eval prop provider which is suitable to be used in the static context. props := make([]exprctx.OptionalEvalPropProvider, 0, exprctx.OptPropsCnt) params := variable.NewPlanCacheParamList() for _, param := range ctx.AllParamValues() { params.Append(param) } // TODO: use a more structural way to replace the closure. // These closure makes sure the fields which may be changed in the execution of the next statement will not be embedded into them, to make // sure it's safe to call them after the session continues to execute other statements. staticCtx := NewEvalContext( WithWarnHandler(ctx.GetWarnHandler()), WithSQLMode(ctx.SQLMode()), WithTypeFlags(typeCtx.Flags()), WithLocation(typeCtx.Location()), WithErrLevelMap(errCtx.LevelMap()), WithCurrentDB(ctx.CurrentDB()), WithCurrentTime(func() func() (time.Time, error) { currentTime, currentTimeErr := ctx.CurrentTime() return func() (time.Time, error) { return currentTime, currentTimeErr } }()), WithMaxAllowedPacket(ctx.GetMaxAllowedPacket()), WithDefaultWeekFormatMode(ctx.GetDefaultWeekFormatMode()), WithDivPrecisionIncrement(ctx.GetDivPrecisionIncrement()), WithParamList(params), WithUserVarsReader(ctx.GetUserVarsReader().Clone()), WithOptionalProperty(props...), WithEnableRedactLog(ctx.GetTiDBRedactLog()), ) return staticCtx }