Files
tidb/pkg/expression/exprstatic/evalctx.go

531 lines
15 KiB
Go

// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package exprstatic
import (
"math"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/pingcap/tidb/pkg/errctx"
"github.com/pingcap/tidb/pkg/expression/exprctx"
"github.com/pingcap/tidb/pkg/expression/expropt"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx/vardef"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/types"
contextutil "github.com/pingcap/tidb/pkg/util/context"
"github.com/pingcap/tidb/pkg/util/intest"
)
// Compile-time assertion that `*EvalContext` implements `exprctx.EvalContext`.
var _ exprctx.EvalContext = &EvalContext{}
// timeOnce lazily computes a timestamp and remembers the first value that was
// obtained without an error.
type timeOnce struct {
	// The embedded mutex serializes the computation of the value.
	sync.Mutex
	// time holds the remembered value; it stays nil until a computation succeeds.
	time atomic.Pointer[time.Time]
	// timeFn is the source of the time; when it is nil, `time.Now` is used.
	timeFn func() (time.Time, error)
}

// getTime returns the remembered time if there is one. Otherwise it asks `timeFn`
// (or `time.Now` when `timeFn` is nil) for the time, converts the result to `loc`,
// remembers it and returns it.
// A failure of `timeFn` is returned to the caller without being remembered, so the
// next call asks `timeFn` again.
func (t *timeOnce) getTime(loc *time.Location) (tm time.Time, err error) {
	// Fast path: a value has already been stored, no lock is needed.
	if cached := t.time.Load(); cached != nil {
		return *cached, nil
	}

	t.Lock()
	defer t.Unlock()

	// Another caller may have stored the value while we were waiting for the lock.
	if cached := t.time.Load(); cached != nil {
		return *cached, nil
	}

	if t.timeFn == nil {
		tm = time.Now()
	} else if tm, err = t.timeFn(); err != nil {
		return tm, err
	}

	tm = tm.In(loc)
	t.time.Store(&tm)
	return tm, nil
}
// evalCtxState is the internal state for `EvalContext`.
// It is a standalone private struct so that an `EvalCtxOption`, which operates on `*evalCtxState`,
// can only be applied inside this package (in `NewEvalContext` and `Apply`).
type evalCtxState struct {
// warnHandler receives the warnings and notes appended to the context.
warnHandler contextutil.WarnHandler
sqlMode mysql.SQLMode
typeCtx types.Context
errCtx errctx.Context
currentDB string
// currentTime lazily provides the value returned by `CurrentTime`.
currentTime *timeOnce
// maxAllowedPacket is the value of the 'max_allowed_packet' system variable.
maxAllowedPacket uint64
// enableRedactLog is the value of the 'tidb_redact_log' system variable.
enableRedactLog string
// defaultWeekFormatMode is the value of the 'default_week_format' system variable.
defaultWeekFormatMode string
// divPrecisionIncrement is the value of the 'div_precision_increment' system variable.
divPrecisionIncrement int
// paramList holds the parameter values returned by `GetParamValue`.
paramList []types.Datum
userVars variable.UserVarsReader
// props holds the optional property providers.
props expropt.OptionalEvalPropProviders
}
// EvalCtxOption is an option that sets some fields of an `EvalContext`.
// It is passed to `NewEvalContext` or `(*EvalContext).Apply`, which call it with the internal state to update.
type EvalCtxOption func(*evalCtxState)
// WithWarnHandler returns an option that installs `h` as the warn handler of the `EvalContext`.
// `h` is expected to be non-nil; a nil `h` is replaced by `contextutil.IgnoreWarn`.
func WithWarnHandler(h contextutil.WarnHandler) EvalCtxOption {
	intest.AssertNotNil(h)
	handler := h
	if handler == nil {
		// Not expected to happen; fall back to a handler that is safe to call.
		handler = contextutil.IgnoreWarn
	}
	return func(state *evalCtxState) {
		state.warnHandler = handler
	}
}
// WithSQLMode returns an option that sets the SQL mode of the `EvalContext`.
func WithSQLMode(sqlMode mysql.SQLMode) EvalCtxOption {
	return func(state *evalCtxState) {
		state.sqlMode = sqlMode
	}
}
// WithTypeFlags returns an option that replaces the flags of the type context of the `EvalContext`.
func WithTypeFlags(flags types.Flags) EvalCtxOption {
	return func(state *evalCtxState) {
		updated := state.typeCtx.WithFlags(flags)
		state.typeCtx = updated
	}
}
// WithLocation returns an option that sets the timezone of the `EvalContext`.
// `loc` is expected to be non-nil; a nil `loc` is replaced by `time.UTC`.
func WithLocation(loc *time.Location) EvalCtxOption {
	intest.AssertNotNil(loc)
	location := loc
	if location == nil {
		// Not expected to happen; fall back to UTC to keep the code safe.
		location = time.UTC
	}
	return func(state *evalCtxState) {
		updated := state.typeCtx.WithLocation(location)
		state.typeCtx = updated
	}
}
// WithErrLevelMap returns an option that sets the error level map of the `EvalContext`.
func WithErrLevelMap(level errctx.LevelMap) EvalCtxOption {
	return func(state *evalCtxState) {
		updated := state.errCtx.WithErrGroupLevels(level)
		state.errCtx = updated
	}
}
// WithCurrentDB returns an option that sets the current database name of the `EvalContext`.
func WithCurrentDB(db string) EvalCtxOption {
	return func(state *evalCtxState) {
		state.currentDB = db
	}
}
// WithCurrentTime returns an option that makes the `EvalContext` obtain its current time from `fn`.
// The option installs a fresh `timeOnce` whose time source is `fn`.
func WithCurrentTime(fn func() (time.Time, error)) EvalCtxOption {
	intest.AssertNotNil(fn)
	return func(state *evalCtxState) {
		once := &timeOnce{timeFn: fn}
		state.currentTime = once
	}
}
// WithMaxAllowedPacket returns an option that sets the value of the 'max_allowed_packet' system variable.
func WithMaxAllowedPacket(size uint64) EvalCtxOption {
	return func(state *evalCtxState) {
		state.maxAllowedPacket = size
	}
}
// WithDefaultWeekFormatMode returns an option that sets the value of the 'default_week_format' system variable.
func WithDefaultWeekFormatMode(mode string) EvalCtxOption {
	return func(state *evalCtxState) {
		state.defaultWeekFormatMode = mode
	}
}
// WithDivPrecisionIncrement returns an option that sets the value of the 'div_precision_increment' system variable.
func WithDivPrecisionIncrement(inc int) EvalCtxOption {
	return func(state *evalCtxState) {
		state.divPrecisionIncrement = inc
	}
}
// WithOptionalProperty returns an option that replaces the optional property providers of the
// `EvalContext` with the given ones. Providers set before the option is applied are discarded.
func WithOptionalProperty(providers ...exprctx.OptionalEvalPropProvider) EvalCtxOption {
	return func(state *evalCtxState) {
		// Start from an empty set, then register the providers in the given order.
		state.props = expropt.OptionalEvalPropProviders{}
		for i := range providers {
			state.props.Add(providers[i])
		}
	}
}
// WithParamList returns an option that sets the parameter list of the `EvalContext`.
// The values of `params` are copied into a newly allocated slice at the time the option is applied.
func WithParamList(params *variable.PlanCacheParamList) EvalCtxOption {
	return func(state *evalCtxState) {
		values := params.AllParamValues()
		state.paramList = make([]types.Datum, len(values))
		copy(state.paramList, values)
	}
}
// WithEnableRedactLog returns an option that sets the value of the 'tidb_redact_log' system variable.
func WithEnableRedactLog(enableRedactLog string) EvalCtxOption {
	return func(state *evalCtxState) {
		state.enableRedactLog = enableRedactLog
	}
}
// WithUserVarsReader returns an option that sets the user variables reader of the `EvalContext`.
func WithUserVarsReader(vars variable.UserVarsReader) EvalCtxOption {
	return func(state *evalCtxState) {
		state.userVars = vars
	}
}
// defaultSQLMode is the SQL mode parsed from `mysql.DefaultSQLMode`. It is used as the initial
// SQL mode of a new `EvalContext`. Package initialization panics if the default cannot be parsed.
var defaultSQLMode = func() (mode mysql.SQLMode) {
	var err error
	if mode, err = mysql.GetSQLMode(mysql.DefaultSQLMode); err != nil {
		panic(err)
	}
	return mode
}()
// EvalContext implements `exprctx.EvalContext` to provide a static context for expression evaluation.
// "Static" means that, in contrast to an evaluation context backed by a session, its internal state does
// not rely on the session or other complex contexts, and most of its fields are kept immutable.
type EvalContext struct {
// id is the context ID returned by `CtxID`.
id uint64
evalCtxState
}
// NewEvalContext creates a new `EvalContext` and applies the given options to it in order.
// Before the options run, the context gets a generated ID, the default SQL mode, the default values
// of the related system variables, a type context with strict flags in UTC, and a default error context.
// If no option provides them, a static warn handler and an empty set of user variables are installed.
func NewEvalContext(opt ...EvalCtxOption) *EvalContext {
	ctx := new(EvalContext)
	ctx.id = contextutil.GenContextID()

	// Defaults that may be overridden by the options below.
	ctx.currentTime = &timeOnce{}
	ctx.sqlMode = defaultSQLMode
	ctx.maxAllowedPacket = vardef.DefMaxAllowedPacket
	ctx.enableRedactLog = vardef.DefTiDBRedactLog
	ctx.defaultWeekFormatMode = vardef.DefDefaultWeekFormat
	ctx.divPrecisionIncrement = vardef.DefDivPrecisionIncrement

	// The context itself is passed as the warn handler of typeCtx and errCtx.
	ctx.typeCtx = types.NewContext(types.StrictFlags, time.UTC, ctx)
	ctx.errCtx = errctx.NewContext(ctx)

	for i := range opt {
		opt[i](&ctx.evalCtxState)
	}

	if ctx.warnHandler == nil {
		ctx.warnHandler = contextutil.NewStaticWarnHandler(0)
	}
	if ctx.userVars == nil {
		ctx.userVars = variable.NewUserVars()
	}
	return ctx
}
// CtxID returns the ID assigned to the context when it was created.
func (ctx *EvalContext) CtxID() uint64 {
	id := ctx.id
	return id
}
// SQLMode returns the SQL mode of the context.
func (ctx *EvalContext) SQLMode() mysql.SQLMode {
	return ctx.evalCtxState.sqlMode
}
// TypeCtx returns the `types.Context` of the context.
func (ctx *EvalContext) TypeCtx() types.Context {
	return ctx.evalCtxState.typeCtx
}
// ErrCtx returns the `errctx.Context` of the context.
func (ctx *EvalContext) ErrCtx() errctx.Context {
	return ctx.evalCtxState.errCtx
}
// Location returns the timezone of the context, as recorded in its type context.
func (ctx *EvalContext) Location() *time.Location {
	return ctx.evalCtxState.typeCtx.Location()
}
// AppendWarning forwards the warning to the warn handler. It is a no-op when no handler is set.
func (ctx *EvalContext) AppendWarning(err error) {
	handler := ctx.warnHandler
	if handler == nil {
		return
	}
	handler.AppendWarning(err)
}
// AppendNote forwards the note to the warn handler. It is a no-op when no handler is set.
func (ctx *EvalContext) AppendNote(err error) {
	handler := ctx.warnHandler
	if handler == nil {
		return
	}
	handler.AppendNote(err)
}
// WarningCount returns the warning count reported by the warn handler, or 0 when no handler is set.
func (ctx *EvalContext) WarningCount() int {
	handler := ctx.warnHandler
	if handler == nil {
		return 0
	}
	return handler.WarningCount()
}
// TruncateWarnings truncates the warnings beginning from `start` and returns the truncated warnings.
// The work is delegated to the warn handler; nil is returned when no handler is set.
func (ctx *EvalContext) TruncateWarnings(start int) []contextutil.SQLWarn {
	handler := ctx.warnHandler
	if handler == nil {
		return nil
	}
	return handler.TruncateWarnings(start)
}
// CopyWarnings copies the warnings to the `dst` slice.
// The work is delegated to the warn handler; nil is returned when no handler is set.
func (ctx *EvalContext) CopyWarnings(dst []contextutil.SQLWarn) []contextutil.SQLWarn {
	handler := ctx.warnHandler
	if handler == nil {
		return nil
	}
	return handler.CopyWarnings(dst)
}
// CurrentDB returns the current database name.
func (ctx *EvalContext) CurrentDB() string {
	return ctx.evalCtxState.currentDB
}
// CurrentTime returns the current time of the context, expressed in the timezone of the context.
// The value is obtained from the context's `timeOnce`.
func (ctx *EvalContext) CurrentTime() (tm time.Time, err error) {
	once := ctx.currentTime
	return once.getTime(ctx.Location())
}
// GetMaxAllowedPacket returns the value of the 'max_allowed_packet' system variable held by the context.
func (ctx *EvalContext) GetMaxAllowedPacket() uint64 {
	return ctx.evalCtxState.maxAllowedPacket
}
// GetTiDBRedactLog returns the value of the 'tidb_redact_log' system variable held by the context.
func (ctx *EvalContext) GetTiDBRedactLog() string {
	return ctx.evalCtxState.enableRedactLog
}
// GetDefaultWeekFormatMode returns the value of the 'default_week_format' system variable held by the context.
func (ctx *EvalContext) GetDefaultWeekFormatMode() string {
	return ctx.evalCtxState.defaultWeekFormatMode
}
// GetDivPrecisionIncrement returns the value of the 'div_precision_increment' system variable held by the context.
func (ctx *EvalContext) GetDivPrecisionIncrement() int {
	return ctx.evalCtxState.divPrecisionIncrement
}
// GetUserVarsReader returns the reader of the user variables held by the context.
func (ctx *EvalContext) GetUserVarsReader() variable.UserVarsReader {
	return ctx.evalCtxState.userVars
}
// GetOptionalPropSet returns the key set of the optional properties registered in the context.
func (ctx *EvalContext) GetOptionalPropSet() exprctx.OptionalEvalPropKeySet {
	return ctx.evalCtxState.props.PropKeySet()
}
// GetOptionalPropProvider looks up the optional property provider registered for `key`.
// The result is whatever the provider set of the context returns for that key.
func (ctx *EvalContext) GetOptionalPropProvider(key exprctx.OptionalEvalPropKey) (exprctx.OptionalEvalPropProvider, bool) {
	return ctx.evalCtxState.props.Get(key)
}
// Apply returns a new `EvalContext` that starts as a copy of the state of `ctx` and is then
// updated by the given options in order. The new context gets its own generated ID.
func (ctx *EvalContext) Apply(opt ...EvalCtxOption) *EvalContext {
	derived := new(EvalContext)
	derived.id = contextutil.GenContextID()
	derived.evalCtxState = ctx.evalCtxState

	// By default the new context takes its current time from the previous context.
	derived.currentTime = &timeOnce{timeFn: ctx.CurrentTime}

	// typeCtx and errCtx are rebuilt with the same flags, location and levels, because the warn
	// handler passed to them must be the new context.
	flags, loc := ctx.typeCtx.Flags(), ctx.typeCtx.Location()
	derived.typeCtx = types.NewContext(flags, loc, derived)
	derived.errCtx = errctx.NewContextWithLevels(ctx.errCtx.LevelMap(), derived)

	for i := range opt {
		opt[i](&derived.evalCtxState)
	}
	return derived
}
// GetParamValue returns the value of the parameter at index `idx`.
// It returns `exprctx.ErrParamIndexExceedParamCounts` when `idx` is negative or not less than the
// number of parameters.
func (ctx *EvalContext) GetParamValue(idx int) (types.Datum, error) {
	if inRange := idx >= 0 && idx < len(ctx.paramList); inRange {
		return ctx.paramList[idx], nil
	}
	return types.Datum{}, exprctx.ErrParamIndexExceedParamCounts
}
// Compile-time assertion that `*EvalContext` implements `exprctx.StaticConvertibleEvalContext`.
var _ exprctx.StaticConvertibleEvalContext = &EvalContext{}
// AllParamValues implements context.StaticConvertibleEvalContext.
// It returns the parameter list held by the context without copying it.
func (ctx *EvalContext) AllParamValues() []types.Datum {
	return ctx.evalCtxState.paramList
}
// GetWarnHandler implements context.StaticConvertibleEvalContext.
// It returns the warn handler held by the context.
func (ctx *EvalContext) GetWarnHandler() contextutil.WarnHandler {
	return ctx.evalCtxState.warnHandler
}
// LoadSystemVars returns a new `EvalContext` derived from `ctx` with the given system variables loaded.
// `sysVars` maps system variable names to their string values. An error is returned, together with a
// nil context, when the variables cannot be set on a fresh set of session variables.
func (ctx *EvalContext) LoadSystemVars(sysVars map[string]string) (*EvalContext, error) {
	vars, err := newSessionVarsWithSystemVariables(sysVars)
	if err != nil {
		return nil, err
	}
	loaded := ctx.loadSessionVarsInternal(vars, sysVars)
	return loaded, nil
}
// loadSessionVarsInternal builds one option for every entry of `sysVars` whose lower-cased name is a
// system variable this context knows about, and returns the context produced by applying them.
// Most values are read from `sessionVars`; 'timestamp' and 'default_week_format' use the raw string
// value from `sysVars`. Entries with other names are ignored.
func (ctx *EvalContext) loadSessionVarsInternal(
	sessionVars *variable.SessionVars, sysVars map[string]string,
) *EvalContext {
	opts := make([]EvalCtxOption, 0, 8)
	for rawName, val := range sysVars {
		var opt EvalCtxOption
		switch strings.ToLower(rawName) {
		case vardef.TimeZone:
			opt = WithLocation(sessionVars.Location())
		case vardef.SQLModeVar:
			opt = WithSQLMode(sessionVars.SQLMode)
		case vardef.Timestamp:
			opt = WithCurrentTime(ctx.currentTimeFuncFromStringVal(val))
		case vardef.MaxAllowedPacket:
			opt = WithMaxAllowedPacket(sessionVars.MaxAllowedPacket)
		case vardef.TiDBRedactLog:
			opt = WithEnableRedactLog(sessionVars.EnableRedactLog)
		case vardef.DefaultWeekFormat:
			opt = WithDefaultWeekFormatMode(val)
		case vardef.DivPrecisionIncrement:
			opt = WithDivPrecisionIncrement(sessionVars.DivPrecisionIncrement)
		default:
			// Not a variable that affects the evaluation context.
			continue
		}
		opts = append(opts, opt)
	}
	return ctx.Apply(opts...)
}
// currentTimeFuncFromStringVal returns a time source for the string value of the 'timestamp' variable.
// When `val` equals `vardef.DefTimestamp` the source returns `time.Now()`. Otherwise `val` is parsed
// as a float number of seconds since the Unix epoch, whose fractional part becomes the nanoseconds.
// The parsing happens each time the returned function is called; a parse error is returned by it.
func (ctx *EvalContext) currentTimeFuncFromStringVal(val string) func() (time.Time, error) {
	return func() (time.Time, error) {
		if val == vardef.DefTimestamp {
			return time.Now(), nil
		}

		ts, err := types.StrToFloat(types.StrictContext, val, false)
		if err != nil {
			return time.Time{}, err
		}

		// Split into whole seconds and the fraction of a second.
		sec, frac := math.Modf(ts)
		nsec := int64(frac * float64(time.Second))
		return time.Unix(int64(sec), nsec), nil
	}
}
// newSessionVarsWithSystemVariables creates new session variables and sets every system variable
// in `vars` on them. It stops and returns the error of the first variable that fails to be set.
func newSessionVarsWithSystemVariables(vars map[string]string) (*variable.SessionVars, error) {
	sessionVars := variable.NewSessionVars(nil)

	// `charset_connection` and `collation_connection` will overwrite each other.
	// To make the result more determinate, they are held back and set as the last step in order:
	// `charset_connection` first, then `collation_connection`.
	var charset, collation *[2]string
	for name, val := range vars {
		lowered := strings.ToLower(name)
		if lowered == vardef.CharacterSetConnection {
			charset = &[2]string{name, val}
			continue
		}
		if lowered == vardef.CollationConnection {
			collation = &[2]string{name, val}
			continue
		}
		if err := sessionVars.SetSystemVar(name, val); err != nil {
			return nil, err
		}
	}

	for _, pending := range []*[2]string{charset, collation} {
		if pending == nil {
			continue
		}
		if err := sessionVars.SetSystemVar(pending[0], pending[1]); err != nil {
			return nil, err
		}
	}
	return sessionVars, nil
}
// MakeEvalContextStatic converts the `exprctx.StaticConvertibleEvalContext` to an `EvalContext`.
// The values are read from `ctx` while this function runs: the current time (or its error) is
// captured once, the parameter values are copied into a new list, and the user variables reader
// is cloned. The returned context has no optional property providers.
func MakeEvalContextStatic(ctx exprctx.StaticConvertibleEvalContext) *EvalContext {
	typeCtx := ctx.TypeCtx()
	errCtx := ctx.ErrCtx()

	// TODO: at least provide some optional eval prop provider which is suitable to be used in the static context.
	props := make([]exprctx.OptionalEvalPropProvider, 0, exprctx.OptPropsCnt)

	params := variable.NewPlanCacheParamList()
	for _, value := range ctx.AllParamValues() {
		params.Append(value)
	}

	// TODO: use a more structural way to replace the closure.
	// The closure makes sure that fields which may change during the execution of the next statement
	// are not embedded into it, so that it is safe to call after the session continues to execute
	// other statements.
	opts := []EvalCtxOption{
		WithWarnHandler(ctx.GetWarnHandler()),
		WithSQLMode(ctx.SQLMode()),
		WithTypeFlags(typeCtx.Flags()),
		WithLocation(typeCtx.Location()),
		WithErrLevelMap(errCtx.LevelMap()),
		WithCurrentDB(ctx.CurrentDB()),
		WithCurrentTime(func() func() (time.Time, error) {
			// Capture the result now; the returned function only replays it.
			now, nowErr := ctx.CurrentTime()
			return func() (time.Time, error) {
				return now, nowErr
			}
		}()),
		WithMaxAllowedPacket(ctx.GetMaxAllowedPacket()),
		WithDefaultWeekFormatMode(ctx.GetDefaultWeekFormatMode()),
		WithDivPrecisionIncrement(ctx.GetDivPrecisionIncrement()),
		WithParamList(params),
		WithUserVarsReader(ctx.GetUserVarsReader().Clone()),
		WithOptionalProperty(props...),
		WithEnableRedactLog(ctx.GetTiDBRedactLog()),
	}
	return NewEvalContext(opts...)
}