// Copyright 2016 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package expression import ( "bytes" "cmp" "context" "encoding/binary" "fmt" "maps" "math" "slices" "strconv" "strings" "sync" "unicode" "unicode/utf8" "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/expression/expropt" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/param" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/charset" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/opcode" "github.com/pingcap/tidb/pkg/parser/terror" "github.com/pingcap/tidb/pkg/types" driver "github.com/pingcap/tidb/pkg/types/parser_driver" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/collate" "github.com/pingcap/tidb/pkg/util/hack" "github.com/pingcap/tidb/pkg/util/intest" "github.com/pingcap/tidb/pkg/util/intset" "github.com/pingcap/tidb/pkg/util/logutil" "go.uber.org/zap" ) // cowExprRef is a copy-on-write slice ref util using in `ColumnSubstitute` // to reduce unnecessary allocation for Expression arguments array type cowExprRef struct { ref []Expression new []Expression } // Set will allocate new array if changed flag true func (c *cowExprRef) Set(i int, changed bool, val Expression) { if c.new != nil { c.new[i] = val return } if !changed { return } c.new = slices.Clone(c.ref) c.new[i] = val } // Result return the final reference func (c *cowExprRef) Result() 
[]Expression { if c.new != nil { return c.new } return c.ref } // Filter the input expressions, append the results to result. func Filter(result []Expression, input []Expression, filter func(Expression) bool) []Expression { for _, e := range input { if filter(e) { result = append(result, e) } } return result } // FilterOutInPlace do the filtering out in place. // The remained are the ones who doesn't match the filter, storing in the original slice. // The filteredOut are the ones match the filter, storing in a new slice. func FilterOutInPlace(input []Expression, filter func(Expression) bool) (remained, filteredOut []Expression) { for i := len(input) - 1; i >= 0; i-- { if filter(input[i]) { filteredOut = append(filteredOut, input[i]) input = slices.Delete(input, i, i+1) } } return input, filteredOut } // ExtractDependentColumns extracts all dependent columns from a virtual column. func ExtractDependentColumns(expr Expression) []*Column { // Pre-allocate a slice to reduce allocation, 8 doesn't have special meaning. result := make([]*Column, 0, 8) return extractDependentColumns(result, expr) } func extractDependentColumns(result []*Column, expr Expression) []*Column { switch v := expr.(type) { case *Column: result = append(result, v) if v.VirtualExpr != nil { result = extractDependentColumns(result, v.VirtualExpr) } case *ScalarFunction: for _, arg := range v.GetArgs() { result = extractDependentColumns(result, arg) } } return result } // ExtractColumns extracts all columns from an expression. func ExtractColumns(expr Expression) []*Column { // Pre-allocate a slice to reduce allocation, 8 doesn't have special meaning. tmp := make(map[int64]*Column, 8) extractColumns(tmp, expr, nil) result := slices.Collect(maps.Values(tmp)) // The keys in a map are unordered, so to ensure stability, we need to sort them here. 
slices.SortFunc(result, func(a, b *Column) int { return cmp.Compare(a.UniqueID, b.UniqueID) }) return result } // ExtractCorColumns extracts correlated column from given expression. func ExtractCorColumns(expr Expression) (cols []*CorrelatedColumn) { switch v := expr.(type) { case *CorrelatedColumn: return []*CorrelatedColumn{v} case *ScalarFunction: for _, arg := range v.GetArgs() { cols = append(cols, ExtractCorColumns(arg)...) } } return } // ExtractColumnsFromExpressions is a more efficient version of ExtractColumns for batch operation. // filter can be nil, or a function to filter the result column. // It's often observed that the pattern of the caller like this: // // cols := ExtractColumns(...) // // for _, col := range cols { // if xxx(col) {...} // } // // Provide an additional filter argument, this can be done in one step. // To avoid allocation for cols that not need. func ExtractColumnsFromExpressions(exprs []Expression, filter func(*Column) bool) []*Column { if len(exprs) == 0 { return nil } m := make(map[int64]*Column, len(exprs)) for _, expr := range exprs { extractColumns(m, expr, filter) } result := slices.Collect(maps.Values(m)) // The keys in a map are unordered, so to ensure stability, we need to sort them here. slices.SortFunc(result, func(a, b *Column) int { return cmp.Compare(a.UniqueID, b.UniqueID) }) return result } // ExtractColumnsMapFromExpressions it the same as ExtractColumnsFromExpressions, but return a map func ExtractColumnsMapFromExpressions(filter func(*Column) bool, exprs ...Expression) map[int64]*Column { if len(exprs) == 0 { return nil } m := make(map[int64]*Column, len(exprs)) for _, expr := range exprs { extractColumns(m, expr, filter) } return m } var uniqueIDToColumnMapPool = sync.Pool{ New: func() any { return make(map[int64]*Column, 4) }, } // GetUniqueIDToColumnMap gets map[int64]*Column map from the pool. 
func GetUniqueIDToColumnMap() map[int64]*Column {
	return uniqueIDToColumnMapPool.Get().(map[int64]*Column)
}

// PutUniqueIDToColumnMap empties m and puts it back into the pool so that a
// later GetUniqueIDToColumnMap call can reuse it. The caller must not use m
// after this call.
func PutUniqueIDToColumnMap(m map[int64]*Column) {
	clear(m)
	uniqueIDToColumnMapPool.Put(m)
}

// ExtractColumnsMapFromExpressionsWithReusedMap behaves like
// ExtractColumnsMapFromExpressions, but it stores the columns into the
// caller-supplied map m instead of allocating a new map. The stored columns
// are those accepted by filter, or all columns when filter is nil, keyed by
// UniqueID. This lets callers recycle maps obtained from
// GetUniqueIDToColumnMap.
//
// The function has no return value, so m must be non-nil for the result to be
// observable. With a nil m the call does nothing.
func ExtractColumnsMapFromExpressionsWithReusedMap(m map[int64]*Column, filter func(*Column) bool, exprs ...Expression) {
	// A map allocated here could never reach the caller: m is passed by value
	// and nothing is returned. Skip the walk instead of filling a throwaway map.
	if m == nil || len(exprs) == 0 {
		return
	}
	for _, expr := range exprs {
		extractColumns(m, expr, filter)
	}
}

// ExtractAllColumnsFromExpressionsInUsedSlices appends to reuse the columns of
// exprs accepted by filter, or all columns when filter is nil. It then sorts
// the whole slice by UniqueID and drops duplicate UniqueIDs in place, so the
// backing array of reuse is recycled. Columns already present in reuse also
// take part in the sort and de-duplication. It returns nil when exprs is
// empty.
func ExtractAllColumnsFromExpressionsInUsedSlices(reuse []*Column, filter func(*Column) bool, exprs ...Expression) []*Column {
	if len(exprs) == 0 {
		return nil
	}
	for _, expr := range exprs {
		reuse = extractColumnsSlices(reuse, expr, filter)
	}
	slices.SortFunc(reuse, func(a, b *Column) int {
		return cmp.Compare(a.UniqueID, b.UniqueID)
	})
	// After sorting, equal UniqueIDs are adjacent, so CompactFunc removes all duplicates.
	return slices.CompactFunc(reuse, func(a, b *Column) bool {
		return a.UniqueID == b.UniqueID
	})
}

// ExtractAllColumnsFromExpressions is like ExtractColumnsFromExpressions but
// returns the columns in traversal order and keeps duplicates. It returns nil
// when exprs is empty.
func ExtractAllColumnsFromExpressions(exprs []Expression, filter func(*Column) bool) []*Column {
	if len(exprs) == 0 {
		return nil
	}
	// The capacity of 8 is only a hint; it has no special meaning.
	result := make([]*Column, 0, 8)
	for _, expr := range exprs {
		result = extractColumnsSlices(result, expr, filter)
	}
	return result
}

// ExtractColumnsSetFromExpressions is like ExtractColumnsFromExpressions, but
// it inserts the UniqueIDs of the matching columns into the FastIntSet m.
func ExtractColumnsSetFromExpressions(m *intset.FastIntSet, filter func(*Column) bool, exprs ...Expression) { if len(exprs) == 0 { return } intest.Assert(m != nil) for _, expr := range exprs { extractColumnsSet(m, expr, filter) } } func extractColumns(result map[int64]*Column, expr Expression, filter func(*Column) bool) { switch v := expr.(type) { case *Column: if filter == nil || filter(v) { result[v.UniqueID] = v } case *ScalarFunction: for _, arg := range v.GetArgs() { extractColumns(result, arg, filter) } } } func extractColumnsSlices(result []*Column, expr Expression, filter func(*Column) bool) []*Column { switch v := expr.(type) { case *Column: if filter == nil || filter(v) { result = append(result, v) } case *ScalarFunction: for _, arg := range v.GetArgs() { result = extractColumnsSlices(result, arg, filter) } } return result } func extractColumnsSet(result *intset.FastIntSet, expr Expression, filter func(*Column) bool) { switch v := expr.(type) { case *Column: if filter == nil || filter(v) { result.Insert(int(v.UniqueID)) } case *ScalarFunction: for _, arg := range v.GetArgs() { extractColumnsSet(result, arg, filter) } } } // ExtractEquivalenceColumns detects the equivalence from CNF exprs. func ExtractEquivalenceColumns(result [][]Expression, exprs []Expression) [][]Expression { // exprs are CNF expressions, EQ condition only make sense in the top level of every expr. for _, expr := range exprs { result = extractEquivalenceColumns(result, expr) } return result } // FindUpperBound looks for column < constant or column <= constant and returns both the column // and constant. It return nil, 0 if the expression is not of this form. // It is used by derived Top N pattern and it is put here since it looks like // a general purpose routine. Similar routines can be added to find lower bound as well. 
func FindUpperBound(expr Expression) (*Column, int64) {
	scalarFunction, ok := expr.(*ScalarFunction)
	if !ok {
		return nil, 0
	}
	isLT := scalarFunction.FuncName.L == ast.LT
	if !isLT && scalarFunction.FuncName.L != ast.LE {
		return nil, 0
	}
	args := scalarFunction.GetArgs()
	if len(args) != 2 {
		return nil, 0
	}
	col, colOk := args[0].(*Column)
	constant, constantOk := args[1].(*Constant)
	if !colOk || !constantOk {
		return nil, 0
	}
	value, valueOk := constant.Value.GetValue().(int64)
	if !valueOk {
		return nil, 0
	}
	if !isLT {
		return col, value
	}
	// `col < MinInt64` has no inclusive int64 upper bound: value-1 would
	// silently wrap around to MaxInt64 and report the opposite of the truth.
	if value == math.MinInt64 {
		return nil, 0
	}
	return col, value - 1
}

// extractEquivalenceColumns appends to result the equivalence pair asserted by
// expr, if any. This applies when expr is a top-level `a = b`, `a <=> b` or
// single-element `a IN (b)` whose operands are a column and either another
// column or a scalar function. The column is placed first in the pair; when
// both operands are columns, the left one comes first.
func extractEquivalenceColumns(result [][]Expression, expr Expression) [][]Expression {
	sf, ok := expr.(*ScalarFunction)
	if !ok {
		return result
	}
	switch sf.FuncName.L {
	case ast.EQ, ast.NullEQ, ast.In:
		// a=b and a<=>b both assert equivalence (the latter is also true when
		// both sides are NULL). `col IN (x)` only does so with exactly one
		// element, i.e. exactly two arguments.
	default:
		// For other functions we don't traverse down: e.g. (a=b or c=d)
		// doesn't make any definite equivalence assertion.
		return result
	}
	args := sf.GetArgs()
	if len(args) != 2 {
		return result
	}
	lhs, rhs := args[0], args[1]
	_, lhsIsCol := lhs.(*Column)
	_, rhsIsCol := rhs.(*Column)
	_, lhsIsFunc := lhs.(*ScalarFunction)
	_, rhsIsFunc := rhs.(*ScalarFunction)
	switch {
	case lhsIsCol && (rhsIsCol || rhsIsFunc):
		result = append(result, []Expression{lhs, rhs})
	case rhsIsCol && lhsIsFunc:
		result = append(result, []Expression{rhs, lhs})
	}
	return result
}

// extractColumnsAndCorColumns appends to result every column and correlated
// column reachable from expr through scalar-function arguments. A correlated
// column is appended as a pointer to its Column field.
func extractColumnsAndCorColumns(result []*Column, expr Expression) []*Column {
	switch v := expr.(type) {
	case *Column:
		result = append(result, v)
	case *CorrelatedColumn:
		result = append(result, &v.Column)
	case *ScalarFunction:
		for _, arg := range v.GetArgs() {
			result = extractColumnsAndCorColumns(result, arg)
		}
	}
	return result
}

// ExtractConstantEqColumnsOrScalar appends to result every column or scalar
// function that the top-level CNF items of exprs pin to a constant. The
// recognized forms are `x = const`, `x <=> const` (correlated columns count as
// constants) and `x IN (c, c, ...)` whose list items are equal constants.
func ExtractConstantEqColumnsOrScalar(ctx BuildContext, result []Expression, exprs []Expression) []Expression {
	// exprs are CNF expressions; an EQ condition only makes sense at the top level of each item.
	for _, expr := range exprs {
		result = extractConstantEqColumnsOrScalar(ctx, result, expr)
	}
	return result
}

// extractConstantEqColumnsOrScalar handles a single CNF item for
// ExtractConstantEqColumnsOrScalar.
func extractConstantEqColumnsOrScalar(ctx BuildContext, result []Expression, expr Expression) []Expression {
	sf, ok := expr.(*ScalarFunction)
	if !ok {
		return result
	}
	// isColumnOrScalarFunc reports whether e may be reported as pinned.
	isColumnOrScalarFunc := func(e Expression) bool {
		switch e.(type) {
		case *Column, *ScalarFunction:
			return true
		}
		return false
	}
	// isConstantOrCorCol reports whether e counts as a constant; the
	// correlated column is taken as a constant here.
	isConstantOrCorCol := func(e Expression) bool {
		switch e.(type) {
		case *Constant, *CorrelatedColumn:
			return true
		}
		return false
	}
	switch sf.FuncName.L {
	case ast.EQ, ast.NullEQ:
		args := sf.GetArgs()
		if len(args) != 2 {
			return result
		}
		lhs, rhs := args[0], args[1]
		// At most one case can match: each requires the opposite operand to be
		// a Constant or CorrelatedColumn, which is never a Column or function.
		switch {
		case isColumnOrScalarFunc(lhs) && isConstantOrCorCol(rhs):
			result = append(result, lhs)
		case isColumnOrScalarFunc(rhs) && isConstantOrCorCol(lhs):
			result = append(result, rhs)
		}
	case ast.In:
		args := sf.GetArgs()
		// Only for `col in (all same const)` can col be the constant column.
		// eg: a in (1, "1") does, while a in (1, '2') doesn't.
		guard := args[1]
		allSameConst := true
		for i, arg := range args[1:] {
			if _, isConst := arg.(*Constant); !isConst {
				allSameConst = false
				break
			}
			if i > 0 && !guard.Equal(ctx.GetEvalCtx(), arg) {
				allSameConst = false
				break
			}
		}
		if allSameConst && isColumnOrScalarFunc(args[0]) {
			result = append(result, args[0])
		}
	}
	// For non-EQ functions we don't have to traverse down.
	return result
}

// ExtractColumnsAndCorColumnsFromExpressions appends to result the columns and
// correlated columns (as their Column field) of every expression in list.
func ExtractColumnsAndCorColumnsFromExpressions(result []*Column, list []Expression) []*Column {
	for _, expr := range list {
		result = extractColumnsAndCorColumns(result, expr)
	}
	return result
}

// ExtractColumnSet returns the set of distinct column UniqueIDs referenced by exprs.
func ExtractColumnSet(exprs ...Expression) intset.FastIntSet {
	set := intset.NewFastIntSet()
	for _, expr := range exprs {
		extractColumnSet(expr, &set)
	}
	return set
}

// extractColumnSet inserts into set the UniqueIDs of the columns reachable
// from expr through scalar-function arguments.
func extractColumnSet(expr Expression, set *intset.FastIntSet) {
	switch v := expr.(type) {
	case *Column:
		set.Insert(int(v.UniqueID))
	case *ScalarFunction:
		for _, arg := range v.GetArgs() {
			extractColumnSet(arg, set)
		}
	}
}

// SetExprColumnInOperand marks the columns in expr as InOperand.
func SetExprColumnInOperand(expr Expression) Expression { switch v := expr.(type) { case *Column: col := v.Clone().(*Column) col.InOperand = true return col case *ScalarFunction: args := v.GetArgs() for i, arg := range args { args[i] = SetExprColumnInOperand(arg) } } return expr } // ColumnSubstitute substitutes the columns in filter to expressions in select fields. // e.g. select * from (select b as a from t) k where a < 10 => select * from (select b as a from t where b < 10) k. // TODO: remove this function and only use ColumnSubstituteImpl since this function swallows the error, which seems unsafe. func ColumnSubstitute(ctx BuildContext, expr Expression, schema *Schema, newExprs []Expression) Expression { _, _, resExpr := ColumnSubstituteImpl(ctx, expr, schema, newExprs, false) return resExpr } // ColumnSubstituteAll substitutes the columns just like ColumnSubstitute, but we don't accept partial substitution. // Only accept: // // 1: substitute them all once find col in schema. // 2: nothing in expr can be substituted. func ColumnSubstituteAll(ctx BuildContext, expr Expression, schema *Schema, newExprs []Expression) (bool, Expression) { _, hasFail, resExpr := ColumnSubstituteImpl(ctx, expr, schema, newExprs, true) return hasFail, resExpr } // ColumnSubstituteImpl tries to substitute column expr using newExprs, // the newFunctionInternal is only called if its child is substituted // @return bool means whether the expr has changed. // @return bool means whether the expr should change (has the dependency in schema, while the corresponding expr has some compatibility), but finally fallback. // @return Expression, the original expr or the changed expr, it depends on the first @return bool. 
func ColumnSubstituteImpl(ctx BuildContext, expr Expression, schema *Schema, newExprs []Expression, fail1Return bool) (bool, bool, Expression) {
	switch v := expr.(type) {
	case *Column:
		// A column of schema is replaced by the expression at the same offset.
		id := schema.ColumnIndex(v)
		if id == -1 {
			return false, false, v
		}
		newExpr := newExprs[id]
		if v.InOperand {
			newExpr = SetExprColumnInOperand(newExpr)
		}
		return true, false, newExpr
	case *ScalarFunction:
		substituted := false
		hasFail := false
		// Cast and Grouping are rebuilt directly from their first argument
		// instead of going through NewFunction.
		if v.FuncName.L == ast.Cast || v.FuncName.L == ast.Grouping {
			var newArg Expression
			substituted, hasFail, newArg = ColumnSubstituteImpl(ctx, v.GetArgs()[0], schema, newExprs, fail1Return)
			if fail1Return && hasFail {
				return substituted, hasFail, v
			}
			if substituted {
				flag := v.RetType.GetFlag()
				var e Expression
				var err error
				if v.FuncName.L == ast.Cast {
					// If the newArg is a ScalarFunction(cast), BuildCastFunctionWithCheck will modify the newArg.RetType,
					// So we need to deep copy RetType.
					// TODO: Expression interface needs a deep copy method.
					if newArgFunc, ok := newArg.(*ScalarFunction); ok {
						newArgFunc.RetType = newArgFunc.RetType.DeepCopy()
						newArg = newArgFunc
					}
					e, err = BuildCastFunctionWithCheck(ctx, newArg, v.RetType, false, v.Function.IsExplicitCharset())
					terror.Log(err)
				} else {
					// for grouping function recreation, use clone (meta included) instead of newFunction
					e = v.Clone()
					e.(*ScalarFunction).Function.getArgs()[0] = newArg
				}
				e.SetCoercibility(v.Coercibility())
				e.GetType(ctx.GetEvalCtx()).SetFlag(flag)
				return true, false, e
			}
			return false, false, v
		}
		// If the collation of the column is PAD SPACE,
		// we can't propagate the constant to the length function.
		// For example, schema = ['name'], newExprs = ['a'], v = length(name).
		// We can't substitute name with 'a' in length(name) because the collation of name is PAD SPACE.
		// TODO: We will fix it here temporarily, and redesign the logic if we encounter more similar functions or situations later.
		// Fixed issue #53730
		if ctx.IsConstantPropagateCheck() && v.FuncName.L == ast.Length {
			arg0, isColumn := v.GetArgs()[0].(*Column)
			if isColumn {
				id := schema.ColumnIndex(arg0)
				if id != -1 {
					_, isConstant := newExprs[id].(*Constant)
					if isConstant {
						mappedNewColumnCollate := schema.Columns[id].GetStaticType().GetCollate()
						if mappedNewColumnCollate == charset.CollationUTF8MB4 || mappedNewColumnCollate == charset.CollationUTF8 {
							return false, false, v
						}
					}
				}
			}
		}

		// cowExprRef is a copy-on-write util, args array allocation happens only
		// when expr in args is changed
		refExprArr := cowExprRef{v.GetArgs(), nil}
		oldCollEt, err := CheckAndDeriveCollationFromExprs(ctx, v.FuncName.L, v.RetType.EvalType(), v.GetArgs()...)
		if err != nil {
			logutil.BgLogger().Warn("Unexpected error happened during ColumnSubstitution", zap.Stack("stack"), zap.Error(err))
			return false, false, v
		}
		// Scratch buffer for re-deriving the collation with one argument replaced.
		var tmpArgForCollCheck []Expression
		if collate.NewCollationEnabled() {
			tmpArgForCollCheck = make([]Expression, len(v.GetArgs()))
		}
		for idx, arg := range v.GetArgs() {
			changed, failed, newFuncExpr := ColumnSubstituteImpl(ctx, arg, schema, newExprs, fail1Return)
			if fail1Return && failed {
				return changed, failed, v
			}
			oldChanged := changed
			if collate.NewCollationEnabled() && changed {
				// Make sure the collation used by the ScalarFunction isn't changed and its result collation is not weaker than the collation used by the ScalarFunction.
				changed = false
				copy(tmpArgForCollCheck, refExprArr.Result())
				tmpArgForCollCheck[idx] = newFuncExpr
				newCollEt, err := CheckAndDeriveCollationFromExprs(ctx, v.FuncName.L, v.RetType.EvalType(), tmpArgForCollCheck...)
				if err != nil {
					logutil.BgLogger().Warn("Unexpected error happened during ColumnSubstitution", zap.Stack("stack"), zap.Error(err))
					return false, failed, v
				}
				if oldCollEt.Collation == newCollEt.Collation {
					if newFuncExpr.GetType(ctx.GetEvalCtx()).GetCollate() == arg.GetType(ctx.GetEvalCtx()).GetCollate() && newFuncExpr.Coercibility() == arg.Coercibility() {
						// It's safe to use the new expression, otherwise some cases in projection push-down will be wrong.
						changed = true
					} else {
						changed = checkCollationStrictness(oldCollEt.Collation, newFuncExpr.GetType(ctx.GetEvalCtx()).GetCollate())
					}
				}
			}
			hasFail = hasFail || failed || oldChanged != changed
			if fail1Return && oldChanged != changed {
				// Only when the oldChanged is true and changed is false, we will get here.
				// And this means there some dependency in this arg can be substituted with
				// given expressions, while it has some collation compatibility, finally we
				// fall back to use the origin args. (commonly used in projection elimination
				// in which fallback usage is unacceptable)
				return changed, true, v
			}
			refExprArr.Set(idx, changed, newFuncExpr)
			if changed {
				substituted = true
			}
		}
		if substituted {
			var newFunc Expression
			var err error
			switch v.FuncName.L {
			case ast.EQ:
				// keep order as col=value to avoid flaky test.
				args := refExprArr.Result()
				switch args[0].(type) {
				case *Constant:
					newFunc, err = NewFunction(ctx, v.FuncName.L, v.RetType, args[1], args[0])
				default:
					newFunc, err = NewFunction(ctx, v.FuncName.L, v.RetType, args[0], args[1])
				}
			default:
				newFunc, err = NewFunction(ctx, v.FuncName.L, v.RetType, refExprArr.Result()...)
			}
			if err != nil {
				return true, true, v
			}
			return true, hasFail, newFunc
		}
	}
	return false, false, expr
}

// checkCollationStrictness reports whether newFuncColl is not weaker than
// coll. That holds when both collations have a strictness group and either the
// groups are equal or newFuncColl's group is listed in
// CollationStrictness[group of coll]. Unknown collations yield false.
func checkCollationStrictness(coll, newFuncColl string) bool {
	collGroupID, ok1 := CollationStrictnessGroup[coll]
	newFuncCollGroupID, ok2 := CollationStrictnessGroup[newFuncColl]
	if !ok1 || !ok2 {
		return false
	}
	return collGroupID == newFuncCollGroupID ||
		slices.Contains(CollationStrictness[collGroupID], newFuncCollGroupID)
}

// getValidPrefix returns the longest prefix of s that can be parsed as a
// number in the given base, which must be between 2 and 36 inclusive; any
// other base yields "".
//
// A single leading '+' or '-' is accepted. A leading '+' is dropped from the
// result and a leading '-' is kept. Letters are matched case-insensitively.
func getValidPrefix(s string, base int64) string {
	// upper is the first (upper-cased) character that is NOT a valid digit.
	var upper rune
	switch {
	case base >= 2 && base <= 9:
		upper = rune('0' + base)
	case base >= 10 && base <= 36:
		// Digits 0-9 all sort below 'A'; letters are valid up to 'A'+base-10.
		// The lower bound matters: without it, bases below 2 would fall into
		// this case and produce a meaningless digit bound.
		upper = rune('A' + base - 10)
	default:
		return ""
	}
	validLen := 0
Loop:
	for i := range len(s) {
		c := rune(s[i])
		switch {
		case unicode.IsDigit(c) || unicode.IsLower(c) || unicode.IsUpper(c):
			c = unicode.ToUpper(c)
			if c >= upper {
				break Loop
			}
			validLen = i + 1
		case c == '+' || c == '-':
			// A sign is only allowed as the very first character.
			if i != 0 {
				break Loop
			}
		default:
			break Loop
		}
	}
	if validLen > 1 && s[0] == '+' {
		return s[1:validLen]
	}
	return s[:validLen]
}

// SubstituteCorCol2Constant replaces every correlated column in expr with a
// Constant holding the column's current value (*Data). A scalar function
// whose arguments all become constants is evaluated and replaced by the
// resulting Constant, and a Constant with a DeferredExpr is folded.
func SubstituteCorCol2Constant(ctx BuildContext, expr Expression) (Expression, error) {
	switch x := expr.(type) {
	case *CorrelatedColumn:
		return &Constant{Value: *x.Data, RetType: x.GetType(ctx.GetEvalCtx())}, nil
	case *Constant:
		if x.DeferredExpr == nil {
			return expr, nil
		}
		folded := FoldConstant(ctx, x).(*Constant)
		return &Constant{Value: folded.Value, RetType: x.GetType(ctx.GetEvalCtx())}, nil
	case *ScalarFunction:
		oldArgs := x.GetArgs()
		newArgs := make([]Expression, 0, len(oldArgs))
		allConstant := true
		for _, arg := range oldArgs {
			newArg, err := SubstituteCorCol2Constant(ctx, arg)
			if err != nil {
				return nil, err
			}
			if _, isConst := newArg.(*Constant); !isConst {
				allConstant = false
			}
			newArgs = append(newArgs, newArg)
		}
		if allConstant {
			// Every argument is now constant, so evaluate the function once.
			val, err := x.Eval(ctx.GetEvalCtx(), chunk.Row{})
			if err != nil {
				return nil, err
			}
			return &Constant{Value: val, RetType: x.GetType(ctx.GetEvalCtx())}, nil
		}
		// Otherwise rebuild the function over the substituted arguments.
		switch x.FuncName.L {
		case ast.Cast:
			return BuildCastFunction(ctx, newArgs[0], x.RetType), nil
		case ast.Grouping:
			cloned := x.Clone()
			cloned.(*ScalarFunction).GetArgs()[0] = newArgs[0]
			return cloned, nil
		default:
			return NewFunction(ctx, x.FuncName.L, x.GetType(ctx.GetEvalCtx()), newArgs...)
		}
	}
	return expr, nil
}

// locateStringWithCollation returns the 1-based rune position of the first
// occurrence of substr in str when both are compared by the sort keys of
// collation coll. It returns 0 when there is no match.
func locateStringWithCollation(str, substr, coll string) int64 {
	collator := collate.GetCollator(coll)
	pos := bytes.Index(collator.KeyWithoutTrimRightSpace(str), collator.KeyWithoutTrimRightSpace(substr))
	if pos <= 0 {
		// No match (-1) maps to 0; a match at byte 0 of the key maps to position 1.
		return int64(pos + 1)
	}
	// Translate the byte offset inside the key back into a rune count by
	// consuming one rune's key at a time.
	// todo: we can use binary search to make it faster.
	var runes int64
	for rest := str; ; {
		r, width := utf8.DecodeRuneInString(rest)
		runes++
		pos -= len(collator.KeyWithoutTrimRightSpace(string(r)))
		if pos <= 0 {
			return runes + 1
		}
		rest = rest[width:]
	}
}

// timeZone2int converts a UTC offset string into an offset in seconds that can
// be passed to time.FixedZone(); a leading '-' gives a negative offset. The
// string must satisfy
// `(^(+|-)(0?[0-9]|1[0-2]):[0-5]?\d$)|(^+13:00$)`.
func timeZone2int(tz string) int { sign := 1 if strings.HasPrefix(tz, "-") { sign = -1 } i := strings.Index(tz, ":") h, err := strconv.Atoi(tz[1:i]) terror.Log(err) m, err := strconv.Atoi(tz[i+1:]) terror.Log(err) return sign * ((h * 3600) + (m * 60)) } var logicalOps = map[string]struct{}{ ast.LT: {}, ast.GE: {}, ast.GT: {}, ast.LE: {}, ast.EQ: {}, ast.NE: {}, ast.UnaryNot: {}, ast.Like: {}, ast.LogicAnd: {}, ast.LogicOr: {}, ast.LogicXor: {}, ast.In: {}, ast.IsNull: {}, ast.IsFalsity: {}, ast.IsTruthWithoutNull: {}, ast.IsTruthWithNull: {}, ast.NullEQ: {}, ast.Regexp: {}, } var oppositeOp = map[string]string{ ast.LT: ast.GE, ast.GE: ast.LT, ast.GT: ast.LE, ast.LE: ast.GT, ast.EQ: ast.NE, ast.NE: ast.EQ, ast.LogicOr: ast.LogicAnd, ast.LogicAnd: ast.LogicOr, } // a op b is equal to b symmetricOp a var symmetricOp = map[opcode.Op]opcode.Op{ opcode.LT: opcode.GT, opcode.GE: opcode.LE, opcode.GT: opcode.LT, opcode.LE: opcode.GE, opcode.EQ: opcode.EQ, opcode.NE: opcode.NE, opcode.NullEQ: opcode.NullEQ, } func pushNotAcrossArgs(ctx BuildContext, exprs []Expression, not bool) ([]Expression, bool) { newExprs := make([]Expression, 0, len(exprs)) flag := false for _, expr := range exprs { newExpr, changed := pushNotAcrossExpr(ctx, expr, not) flag = changed || flag newExprs = append(newExprs, newExpr) } return newExprs, flag } // todo: consider more no precision-loss downcast cases. func noPrecisionLossCastCompatible(cast, argCol *types.FieldType) bool { // now only consider varchar type and integer. if !(types.IsTypeVarchar(cast.GetType()) && types.IsTypeVarchar(argCol.GetType())) && !(mysql.IsIntegerType(cast.GetType()) && mysql.IsIntegerType(argCol.GetType())) { // varchar type and integer on the storage layer is quite same, while the char type has its padding suffix. return false } if types.IsTypeVarchar(cast.GetType()) { // cast varchar function only bear the flen extension. 
if cast.GetFlen() < argCol.GetFlen() { return false } if !collate.CompatibleCollate(cast.GetCollate(), argCol.GetCollate()) { return false } } else { // For integers, we should ignore the potential display length represented by flen, using the default flen of the type. castFlen, _ := mysql.GetDefaultFieldLengthAndDecimal(cast.GetType()) originFlen, _ := mysql.GetDefaultFieldLengthAndDecimal(argCol.GetType()) // cast integer function only bear the flen extension and signed symbol unchanged. if castFlen < originFlen { return false } if mysql.HasUnsignedFlag(cast.GetFlag()) != mysql.HasUnsignedFlag(argCol.GetFlag()) { return false } } return true } func unwrapCast(sctx BuildContext, parentF *ScalarFunction, castOffset int) (Expression, bool) { _, collation := parentF.CharsetAndCollation() cast, ok := parentF.GetArgs()[castOffset].(*ScalarFunction) if !ok || cast.FuncName.L != ast.Cast { return parentF, false } // eg: if (cast(A) EQ const) with incompatible collation, even if cast is eliminated, the condition still can not be used to build range. if cast.RetType.EvalType() == types.ETString && !collate.CompatibleCollate(cast.RetType.GetCollate(), collation) { return parentF, false } // 1-castOffset should be constant if _, ok := parentF.GetArgs()[1-castOffset].(*Constant); !ok { return parentF, false } // the direct args of cast function should be column. c, ok := cast.GetArgs()[0].(*Column) if !ok { return parentF, false } // current only consider varchar and integer if !noPrecisionLossCastCompatible(cast.RetType, c.RetType) { return parentF, false } // the column is covered by indexes, deconstructing it out. 
if castOffset == 0 { return NewFunctionInternal(sctx, parentF.FuncName.L, parentF.RetType, c, parentF.GetArgs()[1]), true } return NewFunctionInternal(sctx, parentF.FuncName.L, parentF.RetType, parentF.GetArgs()[0], c), true } // eliminateCastFunction will detect the original arg before and the cast type after, once upon // there is no precision loss between them, current cast wrapper can be eliminated. For string // type, collation is also taken into consideration. (mainly used to build range or point) func eliminateCastFunction(sctx BuildContext, expr Expression) (_ Expression, changed bool) { f, ok := expr.(*ScalarFunction) if !ok { return expr, false } _, collation := expr.CharsetAndCollation() switch f.FuncName.L { case ast.LogicOr: dnfItems := FlattenDNFConditions(f) rmCast := false rmCastItems := make([]Expression, len(dnfItems)) for i, dnfItem := range dnfItems { newExpr, curDowncast := eliminateCastFunction(sctx, dnfItem) rmCastItems[i] = newExpr if curDowncast { rmCast = true } } if rmCast { // compose the new DNF expression. return ComposeDNFCondition(sctx, rmCastItems...), true } return expr, false case ast.LogicAnd: cnfItems := FlattenCNFConditions(f) rmCast := false rmCastItems := make([]Expression, len(cnfItems)) for i, cnfItem := range cnfItems { newExpr, curDowncast := eliminateCastFunction(sctx, cnfItem) rmCastItems[i] = newExpr if curDowncast { rmCast = true } } if rmCast { // compose the new CNF expression. return ComposeCNFCondition(sctx, rmCastItems...), true } return expr, false case ast.EQ, ast.NullEQ, ast.LE, ast.GE, ast.LT, ast.GT: // for case: eq(cast(test.t2.a, varchar(100), "aaaaa"), once t2.a is covered by index or pk, try deconstructing it out. if newF, ok := unwrapCast(sctx, f, 0); ok { return newF, true } // for case: eq("aaaaa", cast(test.t2.a, varchar(100)), once t2.a is covered by index or pk, try deconstructing it out. 
if newF, ok := unwrapCast(sctx, f, 1); ok { return newF, true } case ast.In: // case for: cast(a as bigint) in (1,2,3), we could deconstruct column 'a out directly. cast, ok := f.GetArgs()[0].(*ScalarFunction) if !ok || cast.FuncName.L != ast.Cast { return expr, false } // eg: if (cast(A) IN {const}) with incompatible collation, even if cast is eliminated, the condition still can not be used to build range. if cast.RetType.EvalType() == types.ETString && !collate.CompatibleCollate(cast.RetType.GetCollate(), collation) { return expr, false } for _, arg := range f.GetArgs()[1:] { if _, ok := arg.(*Constant); !ok { return expr, false } } // the direct args of cast function should be column. c, ok := cast.GetArgs()[0].(*Column) if !ok { return expr, false } // current only consider varchar and integer if !noPrecisionLossCastCompatible(cast.RetType, c.RetType) { return expr, false } newArgs := []Expression{c} newArgs = append(newArgs, f.GetArgs()[1:]...) return NewFunctionInternal(sctx, f.FuncName.L, f.RetType, newArgs...), true } return expr, false } // pushNotAcrossExpr try to eliminate the NOT expr in expression tree. // Input `not` indicates whether there's a `NOT` be pushed down. // Output `changed` indicates whether the output expression differs from the // input `expr` because of the pushed-down-not. 
func pushNotAcrossExpr(ctx BuildContext, expr Expression, not bool) (_ Expression, changed bool) { if f, ok := expr.(*ScalarFunction); ok { switch f.FuncName.L { case ast.UnaryNot: child, err := wrapWithIsTrue(ctx, true, f.GetArgs()[0], true) if err != nil { return expr, false } var childExpr Expression childExpr, changed = pushNotAcrossExpr(ctx, child, !not) if !changed && !not { return expr, false } return childExpr, true case ast.LT, ast.GE, ast.GT, ast.LE, ast.EQ, ast.NE: if not { return NewFunctionInternal(ctx, oppositeOp[f.FuncName.L], f.GetType(ctx.GetEvalCtx()), f.GetArgs()...), true } newArgs, changed := pushNotAcrossArgs(ctx, f.GetArgs(), false) if !changed { return f, false } return NewFunctionInternal(ctx, f.FuncName.L, f.GetType(ctx.GetEvalCtx()), newArgs...), true case ast.LogicAnd, ast.LogicOr: var ( newArgs []Expression changed bool ) funcName := f.FuncName.L if not { newArgs, _ = pushNotAcrossArgs(ctx, f.GetArgs(), true) funcName = oppositeOp[f.FuncName.L] changed = true } else { newArgs, changed = pushNotAcrossArgs(ctx, f.GetArgs(), false) } if !changed { return f, false } return NewFunctionInternal(ctx, funcName, f.GetType(ctx.GetEvalCtx()), newArgs...), true } } if not { expr = NewFunctionInternal(ctx, ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), expr) } return expr, not } // GetExprInsideIsTruth get the expression inside the `istrue_with_null` and `istrue`. // This is useful when handling expressions from "not" or "!", because we might wrap `istrue_with_null` or `istrue` // when handling them. See pushNotAcrossExpr() and wrapWithIsTrue() for details. func GetExprInsideIsTruth(expr Expression) Expression { if f, ok := expr.(*ScalarFunction); ok { switch f.FuncName.L { case ast.IsTruthWithNull, ast.IsTruthWithoutNull: return GetExprInsideIsTruth(f.GetArgs()[0]) default: return expr } } return expr } // PushDownNot pushes the `not` function down to the expression's arguments. 
func PushDownNot(ctx BuildContext, expr Expression) Expression { newExpr, _ := pushNotAcrossExpr(ctx, expr, false) return newExpr } // EliminateNoPrecisionLossCast remove the redundant cast function for range build convenience. // 1: deeper cast embedded in other complicated function will not be considered. // 2: cast args should be one for original base column and one for constant. // 3: some collation compatibility and precision loss will be considered when remove this cast func. func EliminateNoPrecisionLossCast(sctx BuildContext, expr Expression) Expression { newExpr, _ := eliminateCastFunction(sctx, expr) return newExpr } // ContainOuterNot checks if there is an outer `not`. func ContainOuterNot(expr Expression) bool { return containOuterNot(expr, false) } // containOuterNot checks if there is an outer `not`. // Input `not` means whether there is `not` outside `expr` // // eg. // // not(0+(t.a == 1 and t.b == 2)) returns true // not(t.a) and not(t.b) returns false func containOuterNot(expr Expression, not bool) bool { if f, ok := expr.(*ScalarFunction); ok { switch f.FuncName.L { case ast.UnaryNot: return containOuterNot(f.GetArgs()[0], true) case ast.IsTruthWithNull, ast.IsNull: return containOuterNot(f.GetArgs()[0], not) default: if not { return true } hasNot := false for _, expr := range f.GetArgs() { hasNot = hasNot || containOuterNot(expr, not) if hasNot { return hasNot } } return hasNot } } return false } // Contains tests if `exprs` contains `e`. func Contains(ectx EvalContext, exprs []Expression, e Expression) bool { return slices.ContainsFunc(exprs, func(expr Expression) bool { if expr == nil { return e == nil } return e == expr || expr.Equal(ectx, e) }) } // ExtractFiltersFromDNFs checks whether the cond is DNF. If so, it will get the extracted part and the remained part. // The original DNF will be replaced by the remained part or just be deleted if remained part is nil. // And the extracted part will be appended to the end of the original slice. 
func ExtractFiltersFromDNFs(ctx BuildContext, conditions []Expression) []Expression { var allExtracted []Expression for i := len(conditions) - 1; i >= 0; i-- { if sf, ok := conditions[i].(*ScalarFunction); ok && sf.FuncName.L == ast.LogicOr { extracted, remained := extractFiltersFromDNF(ctx, sf) allExtracted = append(allExtracted, extracted...) if remained == nil { conditions = slices.Delete(conditions, i, i+1) } else { conditions[i] = remained } } } return append(conditions, allExtracted...) } // extractFiltersFromDNF extracts the same condition that occurs in every DNF item and remove them from dnf leaves. func extractFiltersFromDNF(ctx BuildContext, dnfFunc *ScalarFunction) ([]Expression, Expression) { dnfItems := FlattenDNFConditions(dnfFunc) codeMap := make(map[string]int) hashcode2Expr := make(map[string]Expression) for i, dnfItem := range dnfItems { innerMap := make(map[string]struct{}) cnfItems := SplitCNFItems(dnfItem) for _, cnfItem := range cnfItems { code := cnfItem.HashCode() if i == 0 { codeMap[string(code)] = 1 hashcode2Expr[string(code)] = cnfItem } else if _, ok := codeMap[string(code)]; ok { // We need this check because there may be the case like `select * from t, t1 where (t.a=t1.a and t.a=t1.a) or (something). // We should make sure that the two `t.a=t1.a` contributes only once. // TODO: do this out of this function. if _, ok = innerMap[string(code)]; !ok { codeMap[string(code)]++ innerMap[string(code)] = struct{}{} } } } } // We should make sure that this item occurs in every DNF item. 
for hashcode, cnt := range codeMap { if cnt < len(dnfItems) { delete(hashcode2Expr, hashcode) } } if len(hashcode2Expr) == 0 { return nil, dnfFunc } newDNFItems := make([]Expression, 0, len(dnfItems)) onlyNeedExtracted := false for _, dnfItem := range dnfItems { cnfItems := SplitCNFItems(dnfItem) newCNFItems := make([]Expression, 0, len(cnfItems)) for _, cnfItem := range cnfItems { code := cnfItem.HashCode() _, ok := hashcode2Expr[string(code)] if !ok { newCNFItems = append(newCNFItems, cnfItem) } } // If the extracted part is just one leaf of the DNF expression. Then the value of the total DNF expression is // always the same with the value of the extracted part. if len(newCNFItems) == 0 { onlyNeedExtracted = true break } newDNFItems = append(newDNFItems, ComposeCNFCondition(ctx, newCNFItems...)) } extractedExpr := make([]Expression, 0, len(hashcode2Expr)) for _, expr := range hashcode2Expr { extractedExpr = append(extractedExpr, expr) } if onlyNeedExtracted { return extractedExpr, nil } return extractedExpr, ComposeDNFCondition(ctx, newDNFItems...) } // DeriveRelaxedFiltersFromDNF given a DNF expression, derive a relaxed DNF expression which only contains columns // in specified schema; the derived expression is a superset of original expression, i.e, any tuple satisfying // the original expression must satisfy the derived expression. Return nil when the derived expression is universal set. // A running example is: for schema of t1, `(t1.a=1 and t2.a=1) or (t1.a=2 and t2.a=2)` would be derived as // `t1.a=1 or t1.a=2`, while `t1.a=1 or t2.a=1` would get nil. 
func DeriveRelaxedFiltersFromDNF(ctx BuildContext, expr Expression, schema *Schema) Expression { sf, ok := expr.(*ScalarFunction) if !ok || sf.FuncName.L != ast.LogicOr { return nil } dnfItems := FlattenDNFConditions(sf) newDNFItems := make([]Expression, 0, len(dnfItems)) for _, dnfItem := range dnfItems { cnfItems := SplitCNFItems(dnfItem) newCNFItems := make([]Expression, 0, len(cnfItems)) for _, cnfItem := range cnfItems { if itemSF, ok := cnfItem.(*ScalarFunction); ok && itemSF.FuncName.L == ast.LogicOr { relaxedCNFItem := DeriveRelaxedFiltersFromDNF(ctx, cnfItem, schema) if relaxedCNFItem != nil { newCNFItems = append(newCNFItems, relaxedCNFItem) } // If relaxed expression for embedded DNF is universal set, just drop this CNF item continue } // This cnfItem must be simple expression now // If it cannot be fully covered by schema, just drop this CNF item if ExprFromSchema(cnfItem, schema) { newCNFItems = append(newCNFItems, cnfItem) } } // If this DNF item involves no column of specified schema, the relaxed expression must be universal set if len(newCNFItems) == 0 { return nil } newDNFItems = append(newDNFItems, ComposeCNFCondition(ctx, newCNFItems...)) } return ComposeDNFCondition(ctx, newDNFItems...) } // GetRowLen gets the length if the func is row, returns 1 if not row. func GetRowLen(e Expression) int { if f, ok := e.(*ScalarFunction); ok && f.FuncName.L == ast.RowFunc { return len(f.GetArgs()) } return 1 } // CheckArgsNotMultiColumnRow checks the args are not multi-column row. func CheckArgsNotMultiColumnRow(args ...Expression) error { for _, arg := range args { if GetRowLen(arg) != 1 { return ErrOperandColumns.GenWithStackByArgs(1) } } return nil } // GetFuncArg gets the argument of the function at idx. func GetFuncArg(e Expression, idx int) Expression { if f, ok := e.(*ScalarFunction); ok { return f.GetArgs()[idx] } return nil } // PopRowFirstArg pops the first element and returns the rest of row. // e.g. After this function (1, 2, 3) becomes (2, 3). 
func PopRowFirstArg(ctx BuildContext, e Expression) (ret Expression, err error) { if f, ok := e.(*ScalarFunction); ok && f.FuncName.L == ast.RowFunc { args := f.GetArgs() if len(args) == 2 { return args[1], nil } ret, err = NewFunction(ctx, ast.RowFunc, f.GetType(ctx.GetEvalCtx()), args[1:]...) return ret, err } return } // DatumToConstant generates a Constant expression from a Datum. func DatumToConstant(d types.Datum, tp byte, flag uint) *Constant { t := types.NewFieldType(tp) t.AddFlag(flag) return &Constant{Value: d, RetType: t} } // ParamMarkerExpression generate a getparam function expression. func ParamMarkerExpression(ctx BuildContext, v *driver.ParamMarkerExpr, needParam bool) (*Constant, error) { useCache := ctx.IsUseCache() tp := types.NewFieldType(mysql.TypeUnspecified) types.InferParamTypeFromDatum(&v.Datum, tp) value := &Constant{Value: v.Datum, RetType: tp} if useCache || needParam { value.ParamMarker = &ParamMarker{ order: v.Order, } } return value, nil } // ParamMarkerInPrepareChecker checks whether the given ast tree has paramMarker and is in prepare statement. type ParamMarkerInPrepareChecker struct { InPrepareStmt bool } // Enter implements Visitor Interface. func (pc *ParamMarkerInPrepareChecker) Enter(in ast.Node) (out ast.Node, skipChildren bool) { switch v := in.(type) { case *driver.ParamMarkerExpr: pc.InPrepareStmt = !v.InExecute return v, true } return in, false } // Leave implements Visitor Interface. func (pc *ParamMarkerInPrepareChecker) Leave(in ast.Node) (out ast.Node, ok bool) { return in, true } // DisableParseJSONFlag4Expr disables ParseToJSONFlag for `expr` except Column. // We should not *PARSE* a string as JSON under some scenarios. ParseToJSONFlag // is 0 for JSON column yet(as well as JSON correlated column), so we can skip // it. Moreover, Column.RetType refers to the infoschema, if we modify it, data // race may happen if another goroutine read from the infoschema at the same // time. 
func DisableParseJSONFlag4Expr(ctx EvalContext, expr Expression) { if _, isColumn := expr.(*Column); isColumn { return } if _, isCorCol := expr.(*CorrelatedColumn); isCorCol { return } expr.GetType(ctx).SetFlag(expr.GetType(ctx).GetFlag() & ^mysql.ParseToJSONFlag) } // ConstructPositionExpr constructs PositionExpr with the given ParamMarkerExpr. func ConstructPositionExpr(p *driver.ParamMarkerExpr) *ast.PositionExpr { return &ast.PositionExpr{P: p} } // PosFromPositionExpr generates a position value from PositionExpr. func PosFromPositionExpr(ctx BuildContext, v *ast.PositionExpr) (int, bool, error) { if v.P == nil { return v.N, false, nil } value, err := ParamMarkerExpression(ctx, v.P.(*driver.ParamMarkerExpr), false) if err != nil { return 0, true, err } pos, isNull, err := GetIntFromConstant(ctx.GetEvalCtx(), value) if err != nil || isNull { return 0, true, err } return pos, false, nil } // GetStringFromConstant gets a string value from the Constant expression. func GetStringFromConstant(ctx EvalContext, value Expression) (string, bool, error) { con, ok := value.(*Constant) if !ok { err := errors.Errorf("Not a Constant expression %+v", value) return "", true, err } str, isNull, err := con.EvalString(ctx, chunk.Row{}) if err != nil || isNull { return "", true, err } return str, false, nil } // GetIntFromConstant gets an integer value from the Constant expression. func GetIntFromConstant(ctx EvalContext, value Expression) (int, bool, error) { str, isNull, err := GetStringFromConstant(ctx, value) if err != nil || isNull { return 0, true, err } intNum, err := strconv.Atoi(str) if err != nil { return 0, true, nil } return intNum, false, nil } // BuildNotNullExpr wraps up `not(isnull())` for given expression. 
func BuildNotNullExpr(ctx BuildContext, expr Expression) Expression { isNull := NewFunctionInternal(ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), expr) notNull := NewFunctionInternal(ctx, ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), isNull) return notNull } // IsRuntimeConstExpr checks if a expr can be treated as a constant in **executor**. func IsRuntimeConstExpr(expr Expression) bool { switch x := expr.(type) { case *ScalarFunction: if _, ok := unFoldableFunctions[x.FuncName.L]; ok { return false } for _, arg := range x.GetArgs() { if !IsRuntimeConstExpr(arg) { return false } } return true case *Column: return false case *Constant, *CorrelatedColumn: return true } return false } // CheckNonDeterministic checks whether the current expression contains a non-deterministic func. func CheckNonDeterministic(e Expression) bool { switch x := e.(type) { case *Constant, *Column, *CorrelatedColumn: return false case *ScalarFunction: if _, ok := unFoldableFunctions[x.FuncName.L]; ok { return true } if slices.ContainsFunc(x.GetArgs(), CheckNonDeterministic) { return true } } return false } // CheckFuncInExpr checks whether there's a given function in the expression. func CheckFuncInExpr(e Expression, funcName string) bool { switch x := e.(type) { case *Constant, *Column, *CorrelatedColumn: return false case *ScalarFunction: if x.FuncName.L == funcName { return true } for _, arg := range x.GetArgs() { if CheckFuncInExpr(arg, funcName) { return true } } } return false } // IsMutableEffectsExpr checks if expr contains function which is mutable or has side effects. 
func IsMutableEffectsExpr(expr Expression) bool { switch x := expr.(type) { case *ScalarFunction: if _, ok := mutableEffectsFunctions[x.FuncName.L]; ok { return true } if slices.ContainsFunc(x.GetArgs(), IsMutableEffectsExpr) { return true } case *Column: case *Constant: if x.DeferredExpr != nil { return IsMutableEffectsExpr(x.DeferredExpr) } } return false } // IsImmutableFunc checks whether this expression only consists of foldable functions. // This expression can be evaluated by using `expr.Eval(chunk.Row{})` directly and the result won't change if it's immutable. func IsImmutableFunc(expr Expression) bool { switch x := expr.(type) { case *ScalarFunction: if _, ok := unFoldableFunctions[x.FuncName.L]; ok { return false } if _, ok := mutableEffectsFunctions[x.FuncName.L]; ok { return false } for _, arg := range x.GetArgs() { if !IsImmutableFunc(arg) { return false } } return true default: return true } } // RemoveDupExprs removes identical exprs. Not that if expr contains functions which // are mutable or have side effects, we cannot remove it even if it has duplicates; // if the plan is going to be cached, we cannot remove expressions containing `?` neither. func RemoveDupExprs(exprs []Expression) []Expression { if len(exprs) <= 1 { return exprs } exists := make(map[string]struct{}, len(exprs)) return slices.DeleteFunc(exprs, func(expr Expression) bool { key := string(expr.HashCode()) if _, ok := exists[key]; !ok || IsMutableEffectsExpr(expr) { exists[key] = struct{}{} return false } return true }) } // GetUint64FromConstant gets a uint64 from constant expression. 
func GetUint64FromConstant(ctx EvalContext, expr Expression) (uint64, bool, bool) { con, ok := expr.(*Constant) if !ok { logutil.BgLogger().Warn("not a constant expression", zap.String("expression", expr.ExplainInfo(ctx))) return 0, false, false } dt := con.Value if con.ParamMarker != nil { var err error dt, err = con.ParamMarker.GetUserVar(ctx) if err != nil { logutil.BgLogger().Warn("get param failed", zap.Error(err)) return 0, false, false } } else if con.DeferredExpr != nil { var err error dt, err = con.DeferredExpr.Eval(ctx, chunk.Row{}) if err != nil { logutil.BgLogger().Warn("eval deferred expr failed", zap.Error(err)) return 0, false, false } } switch dt.Kind() { case types.KindNull: return 0, true, true case types.KindInt64: val := dt.GetInt64() if val < 0 { return 0, false, false } return uint64(val), false, true case types.KindUint64: return dt.GetUint64(), false, true } return 0, false, false } // ContainVirtualColumn checks if the expressions contain a virtual column func ContainVirtualColumn(exprs []Expression) bool { for _, expr := range exprs { switch v := expr.(type) { case *Column: if v.VirtualExpr != nil { return true } case *ScalarFunction: if ContainVirtualColumn(v.GetArgs()) { return true } } } return false } // ContainCorrelatedColumn checks if the expressions contain a correlated column func ContainCorrelatedColumn(exprs ...Expression) bool { for _, expr := range exprs { switch v := expr.(type) { case *CorrelatedColumn: return true case *ScalarFunction: if ContainCorrelatedColumn(v.GetArgs()...) 
{ return true } } } return false } func jsonUnquoteFunctionBenefitsFromPushedDown(sf *ScalarFunction) bool { arg0 := sf.GetArgs()[0] // Only `->>` which parsed to JSONUnquote(CAST(JSONExtract() AS string)) can be pushed down to tikv if fChild, ok := arg0.(*ScalarFunction); ok { if fChild.FuncName.L == ast.Cast { if fGrand, ok := fChild.GetArgs()[0].(*ScalarFunction); ok { if fGrand.FuncName.L == ast.JSONExtract { return true } } } } return false } // ProjectionBenefitsFromPushedDown evaluates if the expressions can improve performance when pushed down to TiKV // Projections are not pushed down to tikv by default, thus we need to check strictly here to avoid potential performance degradation. // Note: virtual column is not considered here, since this function cares performance instead of functionality func ProjectionBenefitsFromPushedDown(exprs []Expression, inputSchemaLen int) bool { // In debug usage, we need to force push down projections to tikv to check tikv expression behavior. failpoint.Inject("forcePushDownTiKV", func() { failpoint.Return(true) }) allColRef := true colRefCount := 0 for _, expr := range exprs { switch v := expr.(type) { case *Column: colRefCount = colRefCount + 1 continue case *ScalarFunction: allColRef = false switch v.FuncName.L { case ast.JSONDepth, ast.JSONLength, ast.JSONType, ast.JSONValid, ast.JSONContains, ast.JSONContainsPath, ast.JSONExtract, ast.JSONKeys, ast.JSONSearch, ast.JSONMemberOf, ast.JSONOverlaps: continue case ast.JSONUnquote: if jsonUnquoteFunctionBenefitsFromPushedDown(v) { continue } return false default: return false } default: return false } } // For all col refs, only push down column pruning projections if allColRef { return colRefCount < inputSchemaLen } return true } // MaybeOverOptimized4PlanCache used to check whether an optimization can work // for the statement when we enable the plan cache. // In some situations, some optimizations maybe over-optimize and cache an // overOptimized plan. 
The cached plan may not get the correct result when we // reuse the plan for other statements. // For example, `pk>=$a and pk<=$b` can be optimized to a PointGet when // `$a==$b`, but it will cause wrong results when `$a!=$b`. // So we need to do the check here. The check includes the following aspects: // 1. Whether the plan cache switch is enable. // 2. Whether the statement can be cached. // 3. Whether the expressions contain a lazy constant. // TODO: Do more careful check here. func MaybeOverOptimized4PlanCache(ctx BuildContext, exprs ...Expression) bool { // If we do not enable plan cache, all the optimization can work correctly. if !ctx.IsUseCache() { return false } return containMutableConst(ctx.GetEvalCtx(), exprs) } // containMutableConst checks if the expressions contain a lazy constant. func containMutableConst(ctx EvalContext, exprs []Expression) bool { for _, expr := range exprs { switch v := expr.(type) { case *Constant: if v.ParamMarker != nil || v.DeferredExpr != nil { return true } case *ScalarFunction: if containMutableConst(ctx, v.GetArgs()) { return true } } } return false } // RemoveMutableConst used to remove the `ParamMarker` and `DeferredExpr` in the `Constant` expr. func RemoveMutableConst(ctx BuildContext, exprs ...Expression) (err error) { for _, expr := range exprs { switch v := expr.(type) { case *Constant: v.ParamMarker = nil if v.DeferredExpr != nil { // evaluate and update v.Value to convert v to a complete immutable constant. // TODO: remove or hide DeferredExpr since it's too dangerous (hard to be consistent with v.Value all the time). v.Value, err = v.DeferredExpr.Eval(ctx.GetEvalCtx(), chunk.Row{}) if err != nil { return err } v.DeferredExpr = nil } v.DeferredExpr = nil // do nothing since v.Value has already been evaluated in this case. case *ScalarFunction: err := RemoveMutableConst(ctx, v.GetArgs()...) 
if err != nil { return err } } } return nil } const ( _ = iota kib = 1 << (10 * iota) mib = 1 << (10 * iota) gib = 1 << (10 * iota) tib = 1 << (10 * iota) pib = 1 << (10 * iota) eib = 1 << (10 * iota) ) const ( nano = 1 micro = 1000 * nano milli = 1000 * micro sec = 1000 * milli minute = 60 * sec hour = 60 * minute dayTime = 24 * hour ) // GetFormatBytes convert byte count to value with units. func GetFormatBytes(bytes float64) string { var divisor float64 var unit string bytesAbs := math.Abs(bytes) if bytesAbs >= eib { divisor = eib unit = "EiB" } else if bytesAbs >= pib { divisor = pib unit = "PiB" } else if bytesAbs >= tib { divisor = tib unit = "TiB" } else if bytesAbs >= gib { divisor = gib unit = "GiB" } else if bytesAbs >= mib { divisor = mib unit = "MiB" } else if bytesAbs >= kib { divisor = kib unit = "KiB" } else { divisor = 1 unit = "bytes" } if divisor == 1 { return strconv.FormatFloat(bytes, 'f', 0, 64) + " " + unit } value := bytes / divisor if math.Abs(value) >= 100000.0 { return strconv.FormatFloat(value, 'e', 2, 64) + " " + unit } return strconv.FormatFloat(value, 'f', 2, 64) + " " + unit } // GetFormatNanoTime convert time in nanoseconds to value with units. 
func GetFormatNanoTime(time float64) string { var divisor float64 var unit string timeAbs := math.Abs(time) if timeAbs >= dayTime { divisor = dayTime unit = "d" } else if timeAbs >= hour { divisor = hour unit = "h" } else if timeAbs >= minute { divisor = minute unit = "min" } else if timeAbs >= sec { divisor = sec unit = "s" } else if timeAbs >= milli { divisor = milli unit = "ms" } else if timeAbs >= micro { divisor = micro unit = "us" } else { divisor = 1 unit = "ns" } if divisor == 1 { return strconv.FormatFloat(time, 'f', 0, 64) + " " + unit } value := time / divisor if math.Abs(value) >= 100000.0 { return strconv.FormatFloat(value, 'e', 2, 64) + " " + unit } return strconv.FormatFloat(value, 'f', 2, 64) + " " + unit } // SQLDigestTextRetriever is used to find the normalized SQL statement text by SQL digests in statements_summary table. // It's exported for test purposes. It's used by the `tidb_decode_sql_digests` builtin function, but also exposed to // be used in other modules. type SQLDigestTextRetriever struct { // SQLDigestsMap is the place to put the digests that's requested for getting SQL text and also the place to put // the query result. SQLDigestsMap map[string]string // Replace querying for test purposes. mockLocalData map[string]string mockGlobalData map[string]string // There are two ways for querying information: 1) query specified digests by WHERE IN query, or 2) query all // information to avoid the too long WHERE IN clause. If there are more than `fetchAllLimit` digests needs to be // queried, the second way will be chosen; otherwise, the first way will be chosen. fetchAllLimit int } // NewSQLDigestTextRetriever creates a new SQLDigestTextRetriever. 
func NewSQLDigestTextRetriever() *SQLDigestTextRetriever { return &SQLDigestTextRetriever{ SQLDigestsMap: make(map[string]string), fetchAllLimit: 512, } } func (r *SQLDigestTextRetriever) runMockQuery(data map[string]string, inValues []any) (map[string]string, error) { if len(inValues) == 0 { return data, nil } res := make(map[string]string, len(inValues)) for _, digest := range inValues { if text, ok := data[digest.(string)]; ok { res[digest.(string)] = text } } return res, nil } // runFetchDigestQuery runs query to the system tables to fetch the kv mapping of SQL digests and normalized SQL texts // of the given SQL digests, if `inValues` is given, or all these mappings otherwise. If `queryGlobal` is false, it // queries information_schema.statements_summary and information_schema.statements_summary_history; otherwise, it // queries the cluster version of these two tables. func (r *SQLDigestTextRetriever) runFetchDigestQuery(ctx context.Context, exec expropt.SQLExecutor, queryGlobal bool, inValues []any) (map[string]string, error) { ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnOthers) // If mock data is set, query the mock data instead of the real statements_summary tables. if !queryGlobal && r.mockLocalData != nil { return r.runMockQuery(r.mockLocalData, inValues) } else if queryGlobal && r.mockGlobalData != nil { return r.runMockQuery(r.mockGlobalData, inValues) } // Information in statements_summary will be periodically moved to statements_summary_history. Union them together // to avoid missing information when statements_summary is just cleared. 
stmt := "select digest, digest_text from information_schema.statements_summary union distinct " + "select digest, digest_text from information_schema.statements_summary_history" if queryGlobal { stmt = "select digest, digest_text from information_schema.cluster_statements_summary union distinct " + "select digest, digest_text from information_schema.cluster_statements_summary_history" } // Add the where clause if `inValues` is specified. if len(inValues) > 0 { stmt += " where digest in (" + strings.Repeat("%?,", len(inValues)-1) + "%?)" } rows, _, err := exec.ExecRestrictedSQL(ctx, nil, stmt, inValues...) if err != nil { return nil, err } res := make(map[string]string, len(rows)) for _, row := range rows { res[row.GetString(0)] = row.GetString(1) } return res, nil } func (r *SQLDigestTextRetriever) updateDigestInfo(queryResult map[string]string) { for digest, text := range r.SQLDigestsMap { if len(text) > 0 { // The text of this digest is already known continue } sqlText, ok := queryResult[digest] if ok { r.SQLDigestsMap[digest] = sqlText } } } // RetrieveLocal tries to retrieve the SQL text of the SQL digests from local information. func (r *SQLDigestTextRetriever) RetrieveLocal(ctx context.Context, exec expropt.SQLExecutor) error { if len(r.SQLDigestsMap) == 0 { return nil } var queryResult map[string]string if len(r.SQLDigestsMap) <= r.fetchAllLimit { inValues := make([]any, 0, len(r.SQLDigestsMap)) for key := range r.SQLDigestsMap { inValues = append(inValues, key) } var err error queryResult, err = r.runFetchDigestQuery(ctx, exec, false, inValues) if err != nil { return errors.Trace(err) } if len(queryResult) == len(r.SQLDigestsMap) { r.SQLDigestsMap = queryResult return nil } } else { var err error queryResult, err = r.runFetchDigestQuery(ctx, exec, false, nil) if err != nil { return errors.Trace(err) } } r.updateDigestInfo(queryResult) return nil } // RetrieveGlobal tries to retrieve the SQL text of the SQL digests from the information of the whole cluster. 
func (r *SQLDigestTextRetriever) RetrieveGlobal(ctx context.Context, exec expropt.SQLExecutor) error { err := r.RetrieveLocal(ctx, exec) if err != nil { return errors.Trace(err) } // In some unit test environments it's unable to retrieve global info, and this function blocks it for tens of // seconds, which wastes much time during unit test. In this case, enable this failpoint to bypass retrieving // globally. failpoint.Inject("sqlDigestRetrieverSkipRetrieveGlobal", func() { failpoint.Return(nil) }) var unknownDigests []any for k, v := range r.SQLDigestsMap { if len(v) == 0 { unknownDigests = append(unknownDigests, k) } } if len(unknownDigests) == 0 { return nil } var queryResult map[string]string if len(r.SQLDigestsMap) <= r.fetchAllLimit { queryResult, err = r.runFetchDigestQuery(ctx, exec, true, unknownDigests) if err != nil { return errors.Trace(err) } } else { queryResult, err = r.runFetchDigestQuery(ctx, exec, true, nil) if err != nil { return errors.Trace(err) } } r.updateDigestInfo(queryResult) return nil } // ExprsToStringsForDisplay convert a slice of Expression to a slice of string using Expression.String(), and // to make it better for display and debug, it also escapes the string to corresponding golang string literal, // which means using \t, \n, \x??, \u????, ... to represent newline, control character, non-printable character, // invalid utf-8 bytes and so on. func ExprsToStringsForDisplay(ctx EvalContext, exprs []Expression) []string { strs := make([]string, len(exprs)) for i, cond := range exprs { quote := `"` // We only need the escape functionality of strconv.Quote, the quoting is not needed, // so we trim the \" prefix and suffix here. strs[i] = strings.TrimSuffix( strings.TrimPrefix( strconv.Quote(cond.StringWithCtx(ctx, errors.RedactLogDisable)), quote), quote) } return strs } // HasColumnWithCondition tries to retrieve the expression (column or function) if it contains the target column. 
func HasColumnWithCondition(e Expression, cond func(*Column) bool) bool { return hasColumnWithCondition(e, cond) } func hasColumnWithCondition(e Expression, cond func(*Column) bool) bool { switch v := e.(type) { case *Column: return cond(v) case *ScalarFunction: for _, arg := range v.GetArgs() { if hasColumnWithCondition(arg, cond) { return true } } } return false } // ConstExprConsiderPlanCache indicates whether the expression can be considered as a constant expression considering planCache. // If the expression is in plan cache, it should have a const level `ConstStrict` because it can be shared across statements. // If the expression is not in plan cache, `ConstOnlyInContext` is enough because it is only used in one statement. // Please notice that if the expression may be cached in other ways except plan cache, we should not use this function. func ConstExprConsiderPlanCache(expr Expression, inPlanCache bool) bool { switch expr.ConstLevel() { case ConstStrict: return true case ConstOnlyInContext: return !inPlanCache default: return false } } // ExprsHasSideEffects checks if any of the expressions has side effects. func ExprsHasSideEffects(exprs []Expression) bool { return slices.ContainsFunc(exprs, ExprHasSetVarOrSleep) } // ExprHasSetVarOrSleep checks if the expression has SetVar function or Sleep function. func ExprHasSetVarOrSleep(expr Expression) bool { scalaFunc, isScalaFunc := expr.(*ScalarFunction) if !isScalaFunc { return false } if scalaFunc.FuncName.L == ast.SetVar || scalaFunc.FuncName.L == ast.Sleep { return true } return slices.ContainsFunc(scalaFunc.GetArgs(), ExprHasSetVarOrSleep) } // ExecBinaryParam parse execute binary param arguments to datum slice. 
func ExecBinaryParam(typectx types.Context, binaryParams []param.BinaryParam) (params []Expression, err error) { var ( tmp any ) params = make([]Expression, len(binaryParams)) args := make([]types.Datum, len(binaryParams)) for i := range args { tp := binaryParams[i].Tp isUnsigned := binaryParams[i].IsUnsigned switch tp { case mysql.TypeNull: var nilDatum types.Datum nilDatum.SetNull() args[i] = nilDatum continue case mysql.TypeTiny: if isUnsigned { args[i] = types.NewUintDatum(uint64(binaryParams[i].Val[0])) } else { args[i] = types.NewIntDatum(int64(int8(binaryParams[i].Val[0]))) } continue case mysql.TypeShort, mysql.TypeYear: valU16 := binary.LittleEndian.Uint16(binaryParams[i].Val) if isUnsigned { args[i] = types.NewUintDatum(uint64(valU16)) } else { args[i] = types.NewIntDatum(int64(int16(valU16))) } continue case mysql.TypeInt24, mysql.TypeLong: valU32 := binary.LittleEndian.Uint32(binaryParams[i].Val) if isUnsigned { args[i] = types.NewUintDatum(uint64(valU32)) } else { args[i] = types.NewIntDatum(int64(int32(valU32))) } continue case mysql.TypeLonglong: valU64 := binary.LittleEndian.Uint64(binaryParams[i].Val) if isUnsigned { args[i] = types.NewUintDatum(valU64) } else { args[i] = types.NewIntDatum(int64(valU64)) } continue case mysql.TypeFloat: args[i] = types.NewFloat32Datum(math.Float32frombits(binary.LittleEndian.Uint32(binaryParams[i].Val))) continue case mysql.TypeDouble: args[i] = types.NewFloat64Datum(math.Float64frombits(binary.LittleEndian.Uint64(binaryParams[i].Val))) continue case mysql.TypeDate, mysql.TypeTimestamp, mysql.TypeDatetime: switch len(binaryParams[i].Val) { case 0: tmp = types.ZeroDatetimeStr case 4: _, tmp = binaryDate(0, binaryParams[i].Val) case 7: _, tmp = binaryDateTime(0, binaryParams[i].Val) case 11: _, tmp = binaryTimestamp(0, binaryParams[i].Val) case 13: _, tmp = binaryTimestampWithTZ(0, binaryParams[i].Val) default: err = mysql.ErrMalformPacket return } // TODO: generate the time datum directly var parseTime 
func(types.Context, string) (types.Time, error) switch tp { case mysql.TypeDate: parseTime = types.ParseDate case mysql.TypeDatetime: parseTime = types.ParseDatetime case mysql.TypeTimestamp: // To be compatible with MySQL, even the type of parameter is // TypeTimestamp, the return type should also be `Datetime`. parseTime = types.ParseDatetime } var time types.Time time, err = parseTime(typectx, tmp.(string)) err = typectx.HandleTruncate(err) if err != nil { return } args[i] = types.NewDatum(time) continue case mysql.TypeDuration: fsp := 0 switch len(binaryParams[i].Val) { case 0: tmp = "0" case 8: isNegative := binaryParams[i].Val[0] if isNegative > 1 { err = mysql.ErrMalformPacket return } _, tmp = binaryDuration(1, binaryParams[i].Val, isNegative) case 12: isNegative := binaryParams[i].Val[0] if isNegative > 1 { err = mysql.ErrMalformPacket return } _, tmp = binaryDurationWithMS(1, binaryParams[i].Val, isNegative) fsp = types.MaxFsp default: err = mysql.ErrMalformPacket return } // TODO: generate the duration datum directly var dur types.Duration dur, _, err = types.ParseDuration(typectx, tmp.(string), fsp) err = typectx.HandleTruncate(err) if err != nil { return } args[i] = types.NewDatum(dur) continue case mysql.TypeNewDecimal: if binaryParams[i].IsNull { args[i] = types.NewDecimalDatum(nil) } else { var dec types.MyDecimal if err := typectx.HandleTruncate(dec.FromString(binaryParams[i].Val)); err != nil && err != types.ErrTruncated { return nil, err } args[i] = types.NewDecimalDatum(&dec) } continue case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: if binaryParams[i].IsNull { args[i] = types.NewBytesDatum(nil) } else { args[i] = types.NewBytesDatum(binaryParams[i].Val) } continue case mysql.TypeUnspecified, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, mysql.TypeEnum, mysql.TypeSet, mysql.TypeGeometry, mysql.TypeBit: if !binaryParams[i].IsNull { tmp = string(hack.String(binaryParams[i].Val)) } else { tmp = nil } 
args[i] = types.NewDatum(tmp) continue default: err = param.ErrUnknownFieldType.GenWithStack("stmt unknown field type %d", tp) return } } for i := range params { ft := new(types.FieldType) types.InferParamTypeFromUnderlyingValue(args[i].GetValue(), ft) params[i] = &Constant{Value: args[i], RetType: ft} } return } func binaryDate(pos int, paramValues []byte) (int, string) { year := binary.LittleEndian.Uint16(paramValues[pos : pos+2]) pos += 2 month := paramValues[pos] pos++ day := paramValues[pos] pos++ return pos, fmt.Sprintf("%04d-%02d-%02d", year, month, day) } func binaryDateTime(pos int, paramValues []byte) (int, string) { pos, date := binaryDate(pos, paramValues) hour := paramValues[pos] pos++ minute := paramValues[pos] pos++ second := paramValues[pos] pos++ return pos, fmt.Sprintf("%s %02d:%02d:%02d", date, hour, minute, second) } func binaryTimestamp(pos int, paramValues []byte) (int, string) { pos, dateTime := binaryDateTime(pos, paramValues) microSecond := binary.LittleEndian.Uint32(paramValues[pos : pos+4]) pos += 4 return pos, fmt.Sprintf("%s.%06d", dateTime, microSecond) } func binaryTimestampWithTZ(pos int, paramValues []byte) (int, string) { pos, timestamp := binaryTimestamp(pos, paramValues) tzShiftInMin := int16(binary.LittleEndian.Uint16(paramValues[pos : pos+2])) tzShiftHour := tzShiftInMin / 60 tzShiftAbsMin := tzShiftInMin % 60 if tzShiftAbsMin < 0 { tzShiftAbsMin = -tzShiftAbsMin } pos += 2 return pos, fmt.Sprintf("%s%+02d:%02d", timestamp, tzShiftHour, tzShiftAbsMin) } func binaryDuration(pos int, paramValues []byte, isNegative uint8) (int, string) { sign := "" if isNegative == 1 { sign = "-" } days := binary.LittleEndian.Uint32(paramValues[pos : pos+4]) pos += 4 hours := paramValues[pos] pos++ minutes := paramValues[pos] pos++ seconds := paramValues[pos] pos++ return pos, fmt.Sprintf("%s%d %02d:%02d:%02d", sign, days, hours, minutes, seconds) } func binaryDurationWithMS(pos int, paramValues []byte, isNegative uint8) (int, string) { pos, dur 
:= binaryDuration(pos, paramValues, isNegative) microSecond := binary.LittleEndian.Uint32(paramValues[pos : pos+4]) pos += 4 return pos, fmt.Sprintf("%s.%06d", dur, microSecond) } // IsConstNull is used to check whether the expression is a constant null expression. // For example, `1 > NULL` is a constant null expression. // Now we just assume that the first argrument is a column, // the second argument is a constant null. func IsConstNull(expr Expression) bool { if e, ok := expr.(*ScalarFunction); ok { switch e.FuncName.L { case ast.LT, ast.LE, ast.GT, ast.GE, ast.EQ, ast.NE: if constExpr, ok := e.GetArgs()[1].(*Constant); ok && constExpr.Value.IsNull() && constExpr.DeferredExpr == nil { return true } } } return false } // IsColOpCol is to whether ScalarFunction meets col op col condition. func IsColOpCol(sf *ScalarFunction) (_, _ *Column, _ bool) { args := sf.GetArgs() if len(args) == 2 { col2, ok2 := args[1].(*Column) col1, ok1 := args[0].(*Column) return col1, col2, ok1 && ok2 } return nil, nil, false } // ExtractColumnsFromColOpCol is to extract columns from col op col condition. func ExtractColumnsFromColOpCol(sf *ScalarFunction) (_, _ *Column) { args := sf.GetArgs() if len(args) == 2 { col2 := args[1].(*Column) col1 := args[0].(*Column) return col1, col2 } return nil, nil }