384 lines
11 KiB
Go
384 lines
11 KiB
Go
// Copyright 2024 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package contextstatic
|
|
|
|
import (
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/pingcap/tidb/pkg/errctx"
|
|
exprctx "github.com/pingcap/tidb/pkg/expression/context"
|
|
"github.com/pingcap/tidb/pkg/expression/contextopt"
|
|
"github.com/pingcap/tidb/pkg/parser/mysql"
|
|
"github.com/pingcap/tidb/pkg/sessionctx/variable"
|
|
"github.com/pingcap/tidb/pkg/types"
|
|
contextutil "github.com/pingcap/tidb/pkg/util/context"
|
|
"github.com/pingcap/tidb/pkg/util/intest"
|
|
)
|
|
|
|
// StaticEvalContext should implement `EvalContext`
|
|
var _ exprctx.EvalContext = &StaticEvalContext{}
|
|
|
|
// timeOnce lazily resolves a timestamp the first time it is successfully
// requested and keeps returning that same value afterwards.
//
// The zero value is usable: when timeFn is nil, time.Now is the source.
type timeOnce struct {
	// Mutex serializes the slow path so the source is consulted by at most one
	// caller at a time.
	sync.Mutex
	// time holds the resolved timestamp once a lookup has succeeded; nil until then.
	time atomic.Pointer[time.Time]
	// timeFn is the optional timestamp source; nil means time.Now.
	timeFn func() (time.Time, error)
}

// getTime returns the cached timestamp, resolving it on the first successful call.
//
// The freshly produced value is converted to loc before being cached, so loc only
// matters for the call that performs the resolution; later calls return the cached
// value as-is. A failed resolution is not cached, and the next call tries again.
// On failure, the error is returned together with the time value the source
// returned alongside it.
func (t *timeOnce) getTime(loc *time.Location) (tm time.Time, err error) {
	// Fast path: once published, the value is read without locking.
	if cached := t.time.Load(); cached != nil {
		return *cached, nil
	}

	t.Lock()
	defer t.Unlock()

	// Re-check under the lock: another caller may have resolved it while we waited.
	if cached := t.time.Load(); cached != nil {
		return *cached, nil
	}

	source := t.timeFn
	if source == nil {
		source = func() (time.Time, error) { return time.Now(), nil }
	}

	resolved, resolveErr := source()
	if resolveErr != nil {
		return resolved, resolveErr
	}

	resolved = resolved.In(loc)
	t.time.Store(&resolved)
	return resolved, nil
}
|
|
|
|
// staticEvalCtxState is the internal state for `StaticEvalContext`.
|
|
// We make it as a standalone private struct here to make sure `StaticEvalCtxOption` can only be called in constructor.
|
|
type staticEvalCtxState struct {
|
|
warnHandler contextutil.WarnHandler
|
|
sqlMode mysql.SQLMode
|
|
typeCtx types.Context
|
|
errCtx errctx.Context
|
|
currentDB string
|
|
currentTime *timeOnce
|
|
maxAllowedPacket uint64
|
|
defaultWeekFormatMode string
|
|
divPrecisionIncrement int
|
|
requestVerificationFn func(db, table, column string, priv mysql.PrivilegeType) bool
|
|
requestDynamicVerificationFn func(privName string, grantable bool) bool
|
|
paramList []types.Datum
|
|
props contextopt.OptionalEvalPropProviders
|
|
}
|
|
|
|
// StaticEvalCtxOption is the option to set `StaticEvalContext`.
|
|
type StaticEvalCtxOption func(*staticEvalCtxState)
|
|
|
|
// WithWarnHandler sets the warn handler for the `StaticEvalContext`.
|
|
func WithWarnHandler(h contextutil.WarnHandler) StaticEvalCtxOption {
|
|
intest.AssertNotNil(h)
|
|
if h == nil {
|
|
// this should not happen, just to keep code safe
|
|
h = contextutil.IgnoreWarn
|
|
}
|
|
|
|
return func(s *staticEvalCtxState) {
|
|
s.warnHandler = h
|
|
}
|
|
}
|
|
|
|
// WithSQLMode sets the sql mode for the `StaticEvalContext`.
|
|
func WithSQLMode(sqlMode mysql.SQLMode) StaticEvalCtxOption {
|
|
return func(s *staticEvalCtxState) {
|
|
s.sqlMode = sqlMode
|
|
}
|
|
}
|
|
|
|
// WithTypeFlags sets the type flags for the `StaticEvalContext`.
|
|
func WithTypeFlags(flags types.Flags) StaticEvalCtxOption {
|
|
return func(s *staticEvalCtxState) {
|
|
s.typeCtx = s.typeCtx.WithFlags(flags)
|
|
}
|
|
}
|
|
|
|
// WithLocation sets the timezone info for the `StaticEvalContext`.
|
|
func WithLocation(loc *time.Location) StaticEvalCtxOption {
|
|
intest.AssertNotNil(loc)
|
|
if loc == nil {
|
|
// this should not happen, just to keep code safe
|
|
loc = time.UTC
|
|
}
|
|
return func(s *staticEvalCtxState) {
|
|
s.typeCtx = s.typeCtx.WithLocation(loc)
|
|
}
|
|
}
|
|
|
|
// WithErrLevelMap sets the error level map for the `StaticEvalContext`.
|
|
func WithErrLevelMap(level errctx.LevelMap) StaticEvalCtxOption {
|
|
return func(s *staticEvalCtxState) {
|
|
s.errCtx = s.errCtx.WithErrGroupLevels(level)
|
|
}
|
|
}
|
|
|
|
// WithCurrentDB sets the current database name for the `StaticEvalContext`.
|
|
func WithCurrentDB(db string) StaticEvalCtxOption {
|
|
return func(s *staticEvalCtxState) {
|
|
s.currentDB = db
|
|
}
|
|
}
|
|
|
|
// WithCurrentTime sets the current time for the `StaticEvalContext`.
|
|
func WithCurrentTime(fn func() (time.Time, error)) StaticEvalCtxOption {
|
|
intest.AssertNotNil(fn)
|
|
return func(s *staticEvalCtxState) {
|
|
s.currentTime = &timeOnce{timeFn: fn}
|
|
}
|
|
}
|
|
|
|
// WithMaxAllowedPacket sets the value of the 'max_allowed_packet' system variable.
|
|
func WithMaxAllowedPacket(size uint64) StaticEvalCtxOption {
|
|
return func(s *staticEvalCtxState) {
|
|
s.maxAllowedPacket = size
|
|
}
|
|
}
|
|
|
|
// WithDefaultWeekFormatMode sets the value of the 'default_week_format' system variable.
|
|
func WithDefaultWeekFormatMode(mode string) StaticEvalCtxOption {
|
|
return func(s *staticEvalCtxState) {
|
|
s.defaultWeekFormatMode = mode
|
|
}
|
|
}
|
|
|
|
// WithDivPrecisionIncrement sets the value of the 'div_precision_increment' system variable.
|
|
func WithDivPrecisionIncrement(inc int) StaticEvalCtxOption {
|
|
return func(s *staticEvalCtxState) {
|
|
s.divPrecisionIncrement = inc
|
|
}
|
|
}
|
|
|
|
// WithPrivCheck sets the requestVerificationFn
|
|
func WithPrivCheck(fn func(db, table, column string, priv mysql.PrivilegeType) bool) StaticEvalCtxOption {
|
|
return func(s *staticEvalCtxState) {
|
|
s.requestVerificationFn = fn
|
|
}
|
|
}
|
|
|
|
// WithDynamicPrivCheck sets the requestDynamicVerificationFn
|
|
func WithDynamicPrivCheck(fn func(privName string, grantable bool) bool) StaticEvalCtxOption {
|
|
return func(s *staticEvalCtxState) {
|
|
s.requestDynamicVerificationFn = fn
|
|
}
|
|
}
|
|
|
|
// WithOptionalProperty sets the optional property providers
|
|
func WithOptionalProperty(providers ...exprctx.OptionalEvalPropProvider) StaticEvalCtxOption {
|
|
return func(s *staticEvalCtxState) {
|
|
s.props = contextopt.OptionalEvalPropProviders{}
|
|
for _, p := range providers {
|
|
s.props.Add(p)
|
|
}
|
|
}
|
|
}
|
|
|
|
// WithParamList sets the param list for the `StaticEvalContext`.
|
|
func WithParamList(params *variable.PlanCacheParamList) StaticEvalCtxOption {
|
|
return func(s *staticEvalCtxState) {
|
|
s.paramList = make([]types.Datum, len(params.AllParamValues()))
|
|
for i, v := range params.AllParamValues() {
|
|
s.paramList[i] = v
|
|
}
|
|
}
|
|
}
|
|
|
|
var defaultSQLMode = func() mysql.SQLMode {
|
|
mode, err := mysql.GetSQLMode(mysql.DefaultSQLMode)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return mode
|
|
}()
|
|
|
|
// StaticEvalContext implements `EvalContext` to provide a static context for expression evaluation.
|
|
// The "static" means comparing with `SessionEvalContext`, its internal state does not relay on the session or other
|
|
// complex contexts that keeps immutable for most fields.
|
|
type StaticEvalContext struct {
|
|
id uint64
|
|
staticEvalCtxState
|
|
}
|
|
|
|
// NewStaticEvalContext creates a new `StaticEvalContext` with the given options.
|
|
func NewStaticEvalContext(opt ...StaticEvalCtxOption) *StaticEvalContext {
|
|
ctx := &StaticEvalContext{
|
|
id: contextutil.GenContextID(),
|
|
staticEvalCtxState: staticEvalCtxState{
|
|
currentTime: &timeOnce{},
|
|
sqlMode: defaultSQLMode,
|
|
maxAllowedPacket: variable.DefMaxAllowedPacket,
|
|
defaultWeekFormatMode: variable.DefDefaultWeekFormat,
|
|
divPrecisionIncrement: variable.DefDivPrecisionIncrement,
|
|
},
|
|
}
|
|
|
|
ctx.typeCtx = types.NewContext(types.StrictFlags, time.UTC, ctx)
|
|
ctx.errCtx = errctx.NewContext(ctx)
|
|
|
|
for _, o := range opt {
|
|
o(&ctx.staticEvalCtxState)
|
|
}
|
|
|
|
if ctx.warnHandler == nil {
|
|
ctx.warnHandler = contextutil.NewStaticWarnHandler(0)
|
|
}
|
|
|
|
return ctx
|
|
}
|
|
|
|
// CtxID returns the context ID.
|
|
func (ctx *StaticEvalContext) CtxID() uint64 {
|
|
return ctx.id
|
|
}
|
|
|
|
// SQLMode returns the sql mode
|
|
func (ctx *StaticEvalContext) SQLMode() mysql.SQLMode {
|
|
return ctx.sqlMode
|
|
}
|
|
|
|
// TypeCtx returns the types.Context
|
|
func (ctx *StaticEvalContext) TypeCtx() types.Context {
|
|
return ctx.typeCtx
|
|
}
|
|
|
|
// ErrCtx returns the errctx.Context
|
|
func (ctx *StaticEvalContext) ErrCtx() errctx.Context {
|
|
return ctx.errCtx
|
|
}
|
|
|
|
// Location returns the timezone info
|
|
func (ctx *StaticEvalContext) Location() *time.Location {
|
|
return ctx.typeCtx.Location()
|
|
}
|
|
|
|
// AppendWarning append warnings to the context.
|
|
func (ctx *StaticEvalContext) AppendWarning(err error) {
|
|
if h := ctx.warnHandler; h != nil {
|
|
h.AppendWarning(err)
|
|
}
|
|
}
|
|
|
|
// WarningCount gets warning count.
|
|
func (ctx *StaticEvalContext) WarningCount() int {
|
|
if h := ctx.warnHandler; h != nil {
|
|
return h.WarningCount()
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// TruncateWarnings truncates warnings begin from start and returns the truncated warnings.
|
|
func (ctx *StaticEvalContext) TruncateWarnings(start int) []contextutil.SQLWarn {
|
|
if h := ctx.warnHandler; h != nil {
|
|
return h.TruncateWarnings(start)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// CopyWarnings copies warnings to dst slice.
|
|
func (ctx *StaticEvalContext) CopyWarnings(dst []contextutil.SQLWarn) []contextutil.SQLWarn {
|
|
if h := ctx.warnHandler; h != nil {
|
|
return h.CopyWarnings(dst)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// CurrentDB return the current database name
|
|
func (ctx *StaticEvalContext) CurrentDB() string {
|
|
return ctx.currentDB
|
|
}
|
|
|
|
// CurrentTime return the current time
|
|
func (ctx *StaticEvalContext) CurrentTime() (tm time.Time, err error) {
|
|
return ctx.currentTime.getTime(ctx.Location())
|
|
}
|
|
|
|
// GetMaxAllowedPacket returns the value of the 'max_allowed_packet' system variable.
|
|
func (ctx *StaticEvalContext) GetMaxAllowedPacket() uint64 {
|
|
return ctx.maxAllowedPacket
|
|
}
|
|
|
|
// GetDefaultWeekFormatMode returns the value of the 'default_week_format' system variable.
|
|
func (ctx *StaticEvalContext) GetDefaultWeekFormatMode() string {
|
|
return ctx.defaultWeekFormatMode
|
|
}
|
|
|
|
// GetDivPrecisionIncrement returns the specified value of DivPrecisionIncrement.
|
|
func (ctx *StaticEvalContext) GetDivPrecisionIncrement() int {
|
|
return ctx.divPrecisionIncrement
|
|
}
|
|
|
|
// RequestVerification verifies user privilege
|
|
func (ctx *StaticEvalContext) RequestVerification(db, table, column string, priv mysql.PrivilegeType) bool {
|
|
if fn := ctx.requestVerificationFn; fn != nil {
|
|
return fn(db, table, column, priv)
|
|
}
|
|
return true
|
|
}
|
|
|
|
// RequestDynamicVerification verifies user privilege for a DYNAMIC privilege.
|
|
func (ctx *StaticEvalContext) RequestDynamicVerification(privName string, grantable bool) bool {
|
|
if fn := ctx.requestDynamicVerificationFn; fn != nil {
|
|
return fn(privName, grantable)
|
|
}
|
|
return true
|
|
}
|
|
|
|
// GetOptionalPropSet gets the optional property set from context
|
|
func (ctx *StaticEvalContext) GetOptionalPropSet() exprctx.OptionalEvalPropKeySet {
|
|
return ctx.props.PropKeySet()
|
|
}
|
|
|
|
// GetOptionalPropProvider gets the optional property provider by key
|
|
func (ctx *StaticEvalContext) GetOptionalPropProvider(key exprctx.OptionalEvalPropKey) (exprctx.OptionalEvalPropProvider, bool) {
|
|
return ctx.props.Get(key)
|
|
}
|
|
|
|
// Apply returns a new `StaticEvalContext` with the fields updated according to the given options.
|
|
func (ctx *StaticEvalContext) Apply(opt ...StaticEvalCtxOption) *StaticEvalContext {
|
|
newCtx := &StaticEvalContext{
|
|
id: contextutil.GenContextID(),
|
|
staticEvalCtxState: ctx.staticEvalCtxState,
|
|
}
|
|
|
|
// current time should use the previous one by default
|
|
newCtx.currentTime = &timeOnce{timeFn: ctx.CurrentTime}
|
|
|
|
// typeCtx and errCtx should be reset because warn handler changed
|
|
newCtx.typeCtx = types.NewContext(ctx.typeCtx.Flags(), ctx.typeCtx.Location(), newCtx)
|
|
newCtx.errCtx = errctx.NewContextWithLevels(ctx.errCtx.LevelMap(), newCtx)
|
|
|
|
// Apply options
|
|
for _, o := range opt {
|
|
o(&newCtx.staticEvalCtxState)
|
|
}
|
|
|
|
return newCtx
|
|
}
|
|
|
|
// GetParamValue returns the value of the parameter by index.
|
|
func (ctx *StaticEvalContext) GetParamValue(idx int) types.Datum {
|
|
if idx < 0 || idx >= len(ctx.paramList) {
|
|
return types.Datum{}
|
|
}
|
|
return ctx.paramList[idx]
|
|
}
|