// Copyright 2016 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package expression

import (
	"bytes"
	"slices"
	"unsafe"

	"github.com/pingcap/errors"
	"github.com/pingcap/tidb/pkg/expression/exprctx"
	"github.com/pingcap/tidb/pkg/parser/ast"
	"github.com/pingcap/tidb/pkg/parser/mysql"
	"github.com/pingcap/tidb/pkg/parser/terror"
	"github.com/pingcap/tidb/pkg/planner/cascades/base"
	"github.com/pingcap/tidb/pkg/sessionctx/variable"
	"github.com/pingcap/tidb/pkg/types"
	"github.com/pingcap/tidb/pkg/util/chunk"
	"github.com/pingcap/tidb/pkg/util/codec"
	"github.com/pingcap/tidb/pkg/util/dbterror/plannererrors"
	"github.com/pingcap/tidb/pkg/util/hack"
	"github.com/pingcap/tidb/pkg/util/intest"
)

// Compile-time check that *ScalarFunction satisfies base.HashEquals.
var _ base.HashEquals = &ScalarFunction{}

// ScalarFunction is the function that returns a value.
type ScalarFunction struct {
	// FuncName is the name of the function. The methods in this file compare
	// its lower-cased form (FuncName.L) against the ast function-name constants.
	FuncName ast.CIStr
	// RetType is the type that ScalarFunction returns.
	// TODO: Implement type inference here, now we use ast's return type temporarily.
	RetType *types.FieldType `plan-cache-clone:"shallow"`
	// Function is the builtin implementation that the Eval*/VecEval* methods,
	// GetArgs and the collation accessors delegate to.
	Function builtinFunc
	// hashcode caches the bytes produced by ReHashCode and returned by HashCode.
	hashcode []byte
	// canonicalhashcode caches the bytes produced by simpleCanonicalizedHashCode
	// and returned by CanonicalHashCode.
	canonicalhashcode []byte
}

// SafeToShareAcrossSession returns if the function can be shared across different sessions.
// The answer is delegated to the underlying builtin function.
func (sf *ScalarFunction) SafeToShareAcrossSession() bool {
	return sf.Function.SafeToShareAcrossSession()
}

// VecEvalInt evaluates this expression in a vectorized manner.
func (sf *ScalarFunction) VecEvalInt(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error { intest.Assert(ctx != nil) if intest.EnableAssert { ctx = wrapEvalAssert(ctx, sf.Function) } return sf.Function.vecEvalInt(ctx, input, result) } // VecEvalReal evaluates this expression in a vectorized manner. func (sf *ScalarFunction) VecEvalReal(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error { intest.Assert(ctx != nil) if intest.EnableAssert { ctx = wrapEvalAssert(ctx, sf.Function) } return sf.Function.vecEvalReal(ctx, input, result) } // VecEvalString evaluates this expression in a vectorized manner. func (sf *ScalarFunction) VecEvalString(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error { intest.Assert(ctx != nil) if intest.EnableAssert { ctx = wrapEvalAssert(ctx, sf.Function) } return sf.Function.vecEvalString(ctx, input, result) } // VecEvalDecimal evaluates this expression in a vectorized manner. func (sf *ScalarFunction) VecEvalDecimal(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error { intest.Assert(ctx != nil) if intest.EnableAssert { ctx = wrapEvalAssert(ctx, sf.Function) } return sf.Function.vecEvalDecimal(ctx, input, result) } // VecEvalTime evaluates this expression in a vectorized manner. func (sf *ScalarFunction) VecEvalTime(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error { intest.Assert(ctx != nil) if intest.EnableAssert { ctx = wrapEvalAssert(ctx, sf.Function) } return sf.Function.vecEvalTime(ctx, input, result) } // VecEvalDuration evaluates this expression in a vectorized manner. func (sf *ScalarFunction) VecEvalDuration(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error { intest.Assert(ctx != nil) if intest.EnableAssert { ctx = wrapEvalAssert(ctx, sf.Function) } return sf.Function.vecEvalDuration(ctx, input, result) } // VecEvalJSON evaluates this expression in a vectorized manner. 
func (sf *ScalarFunction) VecEvalJSON(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error { intest.Assert(ctx != nil) if intest.EnableAssert { ctx = wrapEvalAssert(ctx, sf.Function) } return sf.Function.vecEvalJSON(ctx, input, result) } // VecEvalVectorFloat32 evaluates this expression in a vectorized manner. func (sf *ScalarFunction) VecEvalVectorFloat32(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error { return sf.Function.vecEvalVectorFloat32(ctx, input, result) } // GetArgs gets arguments of function. func (sf *ScalarFunction) GetArgs() []Expression { return sf.Function.getArgs() } // Vectorized returns if this expression supports vectorized evaluation. func (sf *ScalarFunction) Vectorized() bool { return sf.Function.vectorized() && sf.Function.isChildrenVectorized() } // StringWithCtx implements Expression interface. func (sf *ScalarFunction) StringWithCtx(ctx ParamValues, redact string) string { buffer := bytes.NewBuffer(make([]byte, 0, len(sf.FuncName.L)+8+16*len(sf.GetArgs()))) buffer.WriteString(sf.FuncName.L) buffer.WriteByte('(') switch sf.FuncName.L { case ast.Cast: for _, arg := range sf.GetArgs() { buffer.WriteString(arg.StringWithCtx(ctx, redact)) buffer.WriteString(", ") buffer.WriteString(sf.RetType.String()) } default: for i, arg := range sf.GetArgs() { buffer.WriteString(arg.StringWithCtx(ctx, redact)) if i+1 != len(sf.GetArgs()) { buffer.WriteString(", ") } } } buffer.WriteString(")") return buffer.String() } // String returns the string representation of the function func (sf *ScalarFunction) String() string { return sf.StringWithCtx(exprctx.EmptyParamValues, errors.RedactLogDisable) } // typeInferForNull infers the NULL constants field type and set the field type // of NULL constant same as other non-null operands. 
func typeInferForNull(ctx EvalContext, args []Expression) { if len(args) < 2 { return } var isNull = func(expr Expression) bool { cons, ok := expr.(*Constant) return ok && cons.RetType.GetType() == mysql.TypeNull && cons.Value.IsNull() } // Infer the actual field type of the NULL constant. var retFieldTp *types.FieldType var hasNullArg bool for i := len(args) - 1; i >= 0; i-- { isNullArg := isNull(args[i]) if !isNullArg && retFieldTp == nil { retFieldTp = args[i].GetType(ctx) } hasNullArg = hasNullArg || isNullArg // Break if there are both NULL and non-NULL expression if hasNullArg && retFieldTp != nil { break } } if !hasNullArg || retFieldTp == nil { return } for i, arg := range args { argflags := arg.GetType(ctx) if isNull(arg) && !(argflags.Equals(retFieldTp) && mysql.HasNotNullFlag(retFieldTp.GetFlag())) { newarg := arg.Clone() *newarg.GetType(ctx) = *retFieldTp.Clone() newarg.GetType(ctx).DelFlag(mysql.NotNullFlag) // Remove NotNullFlag of NullConst args[i] = newarg } } } // newFunctionImpl creates a new scalar function or constant. // fold: 1 means folding constants, while 0 means not, // -1 means try to fold constants if without errors/warnings, otherwise not. 
func newFunctionImpl(ctx BuildContext, fold int, funcName string, retType *types.FieldType, checkOrInit ScalarFunctionCallBack, args ...Expression) (ret Expression, err error) {
	if retType == nil {
		return nil, errors.Errorf("RetType cannot be nil for ScalarFunction")
	}
	// These functions are built by dedicated builders and return early, so
	// they skip the lookup, noop check, checkOrInit callback and folding below.
	// Each of them reads args[0], so callers must pass at least one argument.
	switch funcName {
	case ast.Cast:
		return BuildCastFunction(ctx, args[0], retType), nil
	case ast.GetVar:
		return BuildGetVarFunction(ctx, args[0], retType)
	case InternalFuncFromBinary:
		return BuildFromBinaryFunction(ctx, args[0], retType, false), nil
	case InternalFuncToBinary:
		return BuildToBinaryFunction(ctx, args[0]), nil
	case ast.Sysdate:
		// SYSDATE is built as NOW when the build context says so.
		if ctx.GetSysdateIsNow() {
			funcName = ast.Now
		}
	}
	// Look the name up among the builtin functions first, then among the
	// registered extension functions.
	fc, ok := funcs[funcName]
	if !ok {
		if extFunc, exist := extensionFuncs.Load(funcName); exist {
			fc = extFunc.(functionClass)
			ok = true
		}
	}
	if !ok {
		// Unknown function: the error mentions "db.funcName", so a current
		// database is required to build it.
		db := ctx.GetEvalCtx().CurrentDB()
		if db == "" {
			return nil, errors.Trace(plannererrors.ErrNoDB)
		}
		return nil, ErrFunctionNotExists.GenWithStackByArgs("FUNCTION", db+"."+funcName)
	}
	// Functions listed in noopFuncs are rejected when the mode is Off,
	// reported as a warning for any other mode except On.
	noopFuncsMode := ctx.GetNoopFuncsMode()
	if noopFuncsMode != variable.OnInt {
		if _, ok := noopFuncs[funcName]; ok {
			err := ErrFunctionsNoopImpl.FastGenByArgs(funcName)
			if noopFuncsMode == variable.OffInt {
				return nil, errors.Trace(err)
			}
			// NoopFuncsMode is Warn, append the error as a warning
			ctx.GetEvalCtx().AppendWarning(err)
		}
	}
	// Work on a copy: typeInferForNull replaces elements of the slice in place
	// and the caller's slice must not be modified.
	funcArgs := slices.Clone(args)
	switch funcName {
	case ast.If, ast.Ifnull, ast.Nullif:
		// Do nothing. Because it will call InferType4ControlFuncs.
	case ast.RowFunc:
		// Do nothing. Because it shouldn't use ROW's args to infer null type.
		// For example, expression ('abc', 1) = (null, 0). Null's type should be STRING, not INT.
		// The type infer happens when converting the expression to ('abc' = null) and (1 = 0).
	default:
		typeInferForNull(ctx.GetEvalCtx(), funcArgs)
	}
	f, err := fc.getFunction(ctx, funcArgs)
	if err != nil {
		return nil, err
	}
	// Use the builtin's own return type unless it is unspecified while the
	// caller supplied a specified one.
	if builtinRetTp := f.getRetTp(); builtinRetTp.GetType() != mysql.TypeUnspecified || retType.GetType() == mysql.TypeUnspecified {
		retType = builtinRetTp
	}
	sf := &ScalarFunction{
		FuncName: ast.NewCIStr(funcName),
		RetType:  retType,
		Function: f,
	}
	// The callback may validate the function or substitute another one.
	if checkOrInit != nil {
		sf2, err := checkOrInit(sf)
		if err != nil {
			return nil, errors.Trace(err)
		}
		sf = sf2
	}
	if fold == 1 {
		return FoldConstant(ctx, sf), nil
	} else if fold == -1 {
		// try to fold constants, and return the original function if errors/warnings occur
		evalCtx := ctx.GetEvalCtx()
		beforeWarns := evalCtx.WarningCount()
		newSf := FoldConstant(ctx, sf)
		afterWarns := evalCtx.WarningCount()
		if afterWarns > beforeWarns {
			// Folding produced warnings: drop them and keep the unfolded function.
			evalCtx.TruncateWarnings(beforeWarns)
			return sf, nil
		}
		return newSf, nil
	}
	return sf, nil
}

// ScalarFunctionCallBack is the definition of callback of calling a newFunction.
// It receives the freshly built function and returns the function to use, or an error.
type ScalarFunctionCallBack func(function *ScalarFunction) (*ScalarFunction, error)

// defaultScalarFunctionCheck is the callback used by NewFunction, NewFunctionBase
// and NewFunctionTryFold. It rejects a grouping function whose meta data has
// not been initialized and returns every other function unchanged.
func defaultScalarFunctionCheck(function *ScalarFunction) (*ScalarFunction, error) {
	// todo: more scalar function init actions can be added here, or setting up with customized init callback.
	if function.FuncName.L == ast.Grouping {
		if !function.Function.(*BuiltinGroupingImplSig).isMetaInited {
			return function, errors.Errorf("grouping meta data hasn't been initialized, try use function clone instead")
		}
	}
	return function, nil
}

// NewFunctionWithInit creates a new scalar function with callback init function.
// Constant folding is always applied (fold == 1).
func NewFunctionWithInit(ctx BuildContext, funcName string, retType *types.FieldType, init ScalarFunctionCallBack, args ...Expression) (Expression, error) {
	return newFunctionImpl(ctx, 1, funcName, retType, init, args...)
}

// NewFunction creates a new scalar function or constant via a constant folding.
func NewFunction(ctx BuildContext, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
	const alwaysFold = 1
	return newFunctionImpl(ctx, alwaysFold, funcName, retType, defaultScalarFunctionCheck, args...)
}

// NewFunctionBase builds a scalar function and never folds constants.
func NewFunctionBase(ctx BuildContext, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
	const neverFold = 0
	return newFunctionImpl(ctx, neverFold, funcName, retType, defaultScalarFunctionCheck, args...)
}

// NewFunctionTryFold builds a scalar function and keeps the folded result only
// when folding raised no warnings.
func NewFunctionTryFold(ctx BuildContext, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
	const tryFold = -1
	return newFunctionImpl(ctx, tryFold, funcName, retType, defaultScalarFunctionCheck, args...)
}

// NewFunctionInternal is similar to NewFunction, but does not return the error; it should only be used internally.
// Deprecated: use NewFunction instead, the old logic here is for the convenience of the go linter error check.
// Creating a function can fail in several ways, for example args verification errors, collation derivation
// errors, or a special function whose meta data has not been initialized. Returning those errors lets us
// debug them quickly instead of running into a confusing index out of range / nil pointer / function
// execution error later. Here the error is only logged.
func NewFunctionInternal(ctx BuildContext, funcName string, retType *types.FieldType, args ...Expression) Expression {
	built, buildErr := NewFunction(ctx, funcName, retType, args...)
	terror.Log(errors.Trace(buildErr))
	return built
}

// ScalarFuncs2Exprs returns a new []Expression holding the given scalar
// functions in the same order.
func ScalarFuncs2Exprs(funcs []*ScalarFunction) []Expression {
	exprs := make([]Expression, len(funcs))
	for i := range funcs {
		exprs[i] = funcs[i]
	}
	return exprs
}

// Clone implements Expression interface.
func (sf *ScalarFunction) Clone() Expression { c := &ScalarFunction{ FuncName: sf.FuncName, RetType: sf.RetType, Function: sf.Function.Clone(), } // An implicit assumption: ScalarFunc.RetType == ScalarFunc.builtinFunc.RetType if sf.canonicalhashcode != nil { c.canonicalhashcode = slices.Clone(sf.canonicalhashcode) } c.SetCharsetAndCollation(sf.CharsetAndCollation()) c.SetCoercibility(sf.Coercibility()) c.SetRepertoire(sf.Repertoire()) return c } // GetType implements Expression interface. func (sf *ScalarFunction) GetType(_ EvalContext) *types.FieldType { return sf.GetStaticType() } // GetStaticType returns the static type of the scalar function. func (sf *ScalarFunction) GetStaticType() *types.FieldType { return sf.RetType } // Equal implements Expression interface. func (sf *ScalarFunction) Equal(ctx EvalContext, e Expression) bool { intest.Assert(ctx != nil) fun, ok := e.(*ScalarFunction) if !ok { return false } if sf.FuncName.L != fun.FuncName.L { return false } if !sf.RetType.Equal(fun.RetType) { return false } return sf.Function.equal(ctx, fun.Function) } // IsCorrelated implements Expression interface. func (sf *ScalarFunction) IsCorrelated() bool { for _, arg := range sf.GetArgs() { if arg.IsCorrelated() { return true } } return false } // ConstLevel returns the const level for the expression func (sf *ScalarFunction) ConstLevel() ConstLevel { // Note: some unfoldable functions are deterministic, we use unFoldableFunctions here for simplification. if _, ok := unFoldableFunctions[sf.FuncName.L]; ok { return ConstNone } if _, ok := sf.Function.(*extensionFuncSig); ok { // we should return `ConstNone` for extension functions for safety, because it may have a side effect. return ConstNone } level := ConstStrict for _, arg := range sf.GetArgs() { argLevel := arg.ConstLevel() if argLevel == ConstNone { return ConstNone } if argLevel < level { level = argLevel } } return level } // Decorrelate implements Expression interface. 
func (sf *ScalarFunction) Decorrelate(schema *Schema) Expression { for i, arg := range sf.GetArgs() { sf.GetArgs()[i] = arg.Decorrelate(schema) } return sf } // Traverse implements the TraverseDown interface. func (sf *ScalarFunction) Traverse(action TraverseAction) Expression { return action.Transform(sf) } // Eval implements Expression interface. func (sf *ScalarFunction) Eval(ctx EvalContext, row chunk.Row) (d types.Datum, err error) { var ( res any isNull bool ) intest.AssertNotNil(ctx) switch tp, evalType := sf.GetType(ctx), sf.GetType(ctx).EvalType(); evalType { case types.ETInt: var intRes int64 intRes, isNull, err = sf.EvalInt(ctx, row) if mysql.HasUnsignedFlag(tp.GetFlag()) { res = uint64(intRes) } else { res = intRes } case types.ETReal: res, isNull, err = sf.EvalReal(ctx, row) case types.ETDecimal: res, isNull, err = sf.EvalDecimal(ctx, row) case types.ETDatetime, types.ETTimestamp: res, isNull, err = sf.EvalTime(ctx, row) case types.ETDuration: res, isNull, err = sf.EvalDuration(ctx, row) case types.ETJson: res, isNull, err = sf.EvalJSON(ctx, row) case types.ETVectorFloat32: res, isNull, err = sf.EvalVectorFloat32(ctx, row) case types.ETString: var str string str, isNull, err = sf.EvalString(ctx, row) if !isNull && err == nil && tp.GetType() == mysql.TypeEnum { res, err = types.ParseEnum(tp.GetElems(), str, tp.GetCollate()) tc := typeCtx(ctx) err = tc.HandleTruncate(err) } else { res = str } } if isNull || err != nil { d.SetNull() return d, err } d.SetValue(res, sf.RetType) return } // EvalInt implements Expression interface. func (sf *ScalarFunction) EvalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { intest.Assert(ctx != nil) if intest.EnableAssert { ctx = wrapEvalAssert(ctx, sf.Function) } return sf.Function.evalInt(ctx, row) } // EvalReal implements Expression interface. 
func (sf *ScalarFunction) EvalReal(ctx EvalContext, row chunk.Row) (float64, bool, error) { intest.Assert(ctx != nil) if intest.EnableAssert { ctx = wrapEvalAssert(ctx, sf.Function) } return sf.Function.evalReal(ctx, row) } // EvalDecimal implements Expression interface. func (sf *ScalarFunction) EvalDecimal(ctx EvalContext, row chunk.Row) (*types.MyDecimal, bool, error) { intest.Assert(ctx != nil) if intest.EnableAssert { ctx = wrapEvalAssert(ctx, sf.Function) } return sf.Function.evalDecimal(ctx, row) } // EvalString implements Expression interface. func (sf *ScalarFunction) EvalString(ctx EvalContext, row chunk.Row) (string, bool, error) { intest.Assert(ctx != nil) if intest.EnableAssert { ctx = wrapEvalAssert(ctx, sf.Function) } return sf.Function.evalString(ctx, row) } // EvalTime implements Expression interface. func (sf *ScalarFunction) EvalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { intest.Assert(ctx != nil) if intest.EnableAssert { ctx = wrapEvalAssert(ctx, sf.Function) } return sf.Function.evalTime(ctx, row) } // EvalDuration implements Expression interface. func (sf *ScalarFunction) EvalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { intest.Assert(ctx != nil) if intest.EnableAssert { ctx = wrapEvalAssert(ctx, sf.Function) } return sf.Function.evalDuration(ctx, row) } // EvalJSON implements Expression interface. func (sf *ScalarFunction) EvalJSON(ctx EvalContext, row chunk.Row) (types.BinaryJSON, bool, error) { intest.Assert(ctx != nil) if intest.EnableAssert { ctx = wrapEvalAssert(ctx, sf.Function) } return sf.Function.evalJSON(ctx, row) } // EvalVectorFloat32 implements Expression interface. func (sf *ScalarFunction) EvalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) { return sf.Function.evalVectorFloat32(ctx, row) } // HashCode implements Expression interface. 
func (sf *ScalarFunction) HashCode() []byte { if len(sf.hashcode) > 0 { if intest.InTest { copyhashcode := make([]byte, len(sf.hashcode)) copy(copyhashcode, sf.hashcode) ReHashCode(sf) intest.Assert(bytes.Equal(sf.hashcode, copyhashcode), "HashCode should not change after ReHashCode is called") } return sf.hashcode } ReHashCode(sf) return sf.hashcode } // CanonicalHashCode implements Expression interface. func (sf *ScalarFunction) CanonicalHashCode() []byte { if len(sf.canonicalhashcode) > 0 { return sf.canonicalhashcode } simpleCanonicalizedHashCode(sf) return sf.canonicalhashcode } // ExpressionsSemanticEqual is used to judge whether two expression tree is semantic equivalent. func ExpressionsSemanticEqual(expr1, expr2 Expression) bool { return bytes.Equal(expr1.CanonicalHashCode(), expr2.CanonicalHashCode()) } // simpleCanonicalizedHashCode is used to judge whether two expression is semantically equal. func simpleCanonicalizedHashCode(sf *ScalarFunction) { if sf.canonicalhashcode != nil { sf.canonicalhashcode = sf.canonicalhashcode[:0] } sf.canonicalhashcode = append(sf.canonicalhashcode, scalarFunctionFlag) argsHashCode := make([][]byte, 0, len(sf.GetArgs())) for _, arg := range sf.GetArgs() { argsHashCode = append(argsHashCode, arg.CanonicalHashCode()) } switch sf.FuncName.L { case ast.Plus, ast.Mul, ast.EQ, ast.In, ast.LogicOr, ast.LogicAnd: // encode original function name. sf.canonicalhashcode = codec.EncodeCompactBytes(sf.canonicalhashcode, hack.Slice(sf.FuncName.L)) // reorder parameters hashcode, eg: a+b and b+a should has the same hashcode here. slices.SortFunc(argsHashCode, func(i, j []byte) int { return bytes.Compare(i, j) }) for _, argCode := range argsHashCode { sf.canonicalhashcode = append(sf.canonicalhashcode, argCode...) } case ast.GE, ast.LE: // directed binary OP: a >= b and b <= a should have the same hashcode. // encode GE function name. 
sf.canonicalhashcode = codec.EncodeCompactBytes(sf.canonicalhashcode, hack.Slice(ast.GE)) // encode GE function name and switch the args order. if sf.FuncName.L == ast.GE { for _, argCode := range argsHashCode { sf.canonicalhashcode = append(sf.canonicalhashcode, argCode...) } } else { for i := len(argsHashCode) - 1; i >= 0; i-- { sf.canonicalhashcode = append(sf.canonicalhashcode, argsHashCode[i]...) } } case ast.GT, ast.LT: sf.canonicalhashcode = codec.EncodeCompactBytes(sf.canonicalhashcode, hack.Slice(ast.GT)) if sf.FuncName.L == ast.GT { for _, argCode := range argsHashCode { sf.canonicalhashcode = append(sf.canonicalhashcode, argCode...) } } else { for i := len(argsHashCode) - 1; i >= 0; i-- { sf.canonicalhashcode = append(sf.canonicalhashcode, argsHashCode[i]...) } } case ast.UnaryNot: child, ok := sf.GetArgs()[0].(*ScalarFunction) if !ok { // encode original function name. sf.canonicalhashcode = codec.EncodeCompactBytes(sf.canonicalhashcode, hack.Slice(sf.FuncName.L)) // use the origin arg hash code. for _, argCode := range argsHashCode { sf.canonicalhashcode = append(sf.canonicalhashcode, argCode...) } } else { childArgsHashCode := make([][]byte, 0, len(child.GetArgs())) for _, arg := range child.GetArgs() { childArgsHashCode = append(childArgsHashCode, arg.CanonicalHashCode()) } switch child.FuncName.L { case ast.GT: // not GT ==> LE ==> use GE and switch args sf.canonicalhashcode = codec.EncodeCompactBytes(sf.canonicalhashcode, hack.Slice(ast.GE)) for i := len(childArgsHashCode) - 1; i >= 0; i-- { sf.canonicalhashcode = append(sf.canonicalhashcode, childArgsHashCode[i]...) } case ast.LT: // not LT ==> GE sf.canonicalhashcode = codec.EncodeCompactBytes(sf.canonicalhashcode, hack.Slice(ast.GE)) for _, argCode := range childArgsHashCode { sf.canonicalhashcode = append(sf.canonicalhashcode, argCode...) 
} case ast.GE: // not GE ==> LT ==> use GT and switch args sf.canonicalhashcode = codec.EncodeCompactBytes(sf.canonicalhashcode, hack.Slice(ast.GT)) for i := len(childArgsHashCode) - 1; i >= 0; i-- { sf.canonicalhashcode = append(sf.canonicalhashcode, childArgsHashCode[i]...) } case ast.LE: // not LE ==> GT sf.canonicalhashcode = codec.EncodeCompactBytes(sf.canonicalhashcode, hack.Slice(ast.GT)) for _, argCode := range childArgsHashCode { sf.canonicalhashcode = append(sf.canonicalhashcode, argCode...) } } } default: // encode original function name. sf.canonicalhashcode = codec.EncodeCompactBytes(sf.canonicalhashcode, hack.Slice(sf.FuncName.L)) for _, argCode := range argsHashCode { sf.canonicalhashcode = append(sf.canonicalhashcode, argCode...) } // Cast is a special case. The RetType should also be considered as an argument. // Please see `newFunctionImpl()` for detail. if sf.FuncName.L == ast.Cast { evalTp := sf.RetType.EvalType() sf.canonicalhashcode = append(sf.canonicalhashcode, byte(evalTp)) } } } // Hash64 implements HashEquals.<0th> interface. func (sf *ScalarFunction) Hash64(h base.Hasher) { h.HashByte(scalarFunctionFlag) h.HashString(sf.FuncName.L) if sf.RetType == nil { h.HashByte(base.NilFlag) } else { h.HashByte(base.NotNilFlag) sf.RetType.Hash64(h) } // hash the arg length to avoid hash collision. h.HashInt(len(sf.GetArgs())) for _, arg := range sf.GetArgs() { arg.Hash64(h) } } // Equals implements HashEquals.<1th> interface. 
func (sf *ScalarFunction) Equals(other any) bool { sf2, ok := other.(*ScalarFunction) if !ok { return false } if sf == nil { return sf2 == nil } if sf2 == nil { return false } ok = sf.FuncName.L == sf2.FuncName.L ok = ok && (sf.RetType == nil && sf2.RetType == nil || sf.RetType != nil && sf2.RetType != nil && sf.RetType.Equals(sf2.RetType)) if len(sf.GetArgs()) != len(sf2.GetArgs()) { return false } for i, arg := range sf.GetArgs() { ok = ok && arg.Equals(sf2.GetArgs()[i]) if !ok { return false } } return ok } // ReHashCode is used after we change the argument in place. func ReHashCode(sf *ScalarFunction) { sf.hashcode = sf.hashcode[:0] sf.hashcode = slices.Grow(sf.hashcode, 1+len(sf.FuncName.L)+len(sf.GetArgs())*8+1) sf.hashcode = append(sf.hashcode, scalarFunctionFlag) sf.hashcode = codec.EncodeCompactBytes(sf.hashcode, hack.Slice(sf.FuncName.L)) for _, arg := range sf.GetArgs() { sf.hashcode = append(sf.hashcode, arg.HashCode()...) } // Cast is a special case. The RetType should also be considered as an argument. // Please see `newFunctionImpl()` for detail. if sf.FuncName.L == ast.Cast { evalTp := sf.RetType.EvalType() sf.hashcode = append(sf.hashcode, byte(evalTp)) } if sf.FuncName.L == ast.Grouping { sf.hashcode = codec.EncodeInt(sf.hashcode, int64(sf.Function.(*BuiltinGroupingImplSig).GetGroupingMode())) marks := sf.Function.(*BuiltinGroupingImplSig).GetMetaGroupingMarks() sf.hashcode = codec.EncodeInt(sf.hashcode, int64(len(marks))) for _, mark := range marks { sf.hashcode = codec.EncodeInt(sf.hashcode, int64(len(mark))) // we need to sort map keys to ensure the hashcode is deterministic. keys := make([]uint64, 0, len(mark)) for k := range mark { keys = append(keys, k) } slices.Sort(keys) for _, k := range keys { sf.hashcode = codec.EncodeInt(sf.hashcode, int64(k)) } } } } // ResolveIndices implements Expression interface. 
func (sf *ScalarFunction) ResolveIndices(schema *Schema) (Expression, error) { newSf := sf.Clone() err := newSf.resolveIndices(schema) return newSf, err } func (sf *ScalarFunction) resolveIndices(schema *Schema) error { for _, arg := range sf.GetArgs() { err := arg.resolveIndices(schema) if err != nil { return err } } return nil } // ResolveIndicesByVirtualExpr implements Expression interface. func (sf *ScalarFunction) ResolveIndicesByVirtualExpr(ctx EvalContext, schema *Schema) (Expression, bool) { newSf := sf.Clone() isOK := newSf.resolveIndicesByVirtualExpr(ctx, schema) return newSf, isOK } func (sf *ScalarFunction) resolveIndicesByVirtualExpr(ctx EvalContext, schema *Schema) bool { for _, arg := range sf.GetArgs() { isOk := arg.resolveIndicesByVirtualExpr(ctx, schema) if !isOk { return false } } return true } // RemapColumn remaps columns with provided mapping and returns new expression func (sf *ScalarFunction) RemapColumn(m map[int64]*Column) (Expression, error) { newSf, ok := sf.Clone().(*ScalarFunction) if !ok { return nil, errors.New("failed to cast to scalar function") } for i, arg := range sf.GetArgs() { newArg, err := arg.RemapColumn(m) if err != nil { return nil, err } newSf.GetArgs()[i] = newArg } // clear hash code newSf.hashcode = nil return newSf, nil } // GetSingleColumn returns (Col, Desc) when the ScalarFunction is equivalent to (Col, Desc) // when used as a sort key, otherwise returns (nil, false). 
// // Can only handle: // - ast.Plus // - ast.Minus // - ast.UnaryMinus func (sf *ScalarFunction) GetSingleColumn(reverse bool) (*Column, bool) { switch sf.FuncName.String() { case ast.Plus: args := sf.GetArgs() switch tp := args[0].(type) { case *Column: if _, ok := args[1].(*Constant); !ok { return nil, false } return tp, reverse case *ScalarFunction: if _, ok := args[1].(*Constant); !ok { return nil, false } return tp.GetSingleColumn(reverse) case *Constant: switch rtp := args[1].(type) { case *Column: return rtp, reverse case *ScalarFunction: return rtp.GetSingleColumn(reverse) } } return nil, false case ast.Minus: args := sf.GetArgs() switch tp := args[0].(type) { case *Column: if _, ok := args[1].(*Constant); !ok { return nil, false } return tp, reverse case *ScalarFunction: if _, ok := args[1].(*Constant); !ok { return nil, false } return tp.GetSingleColumn(reverse) case *Constant: switch rtp := args[1].(type) { case *Column: return rtp, !reverse case *ScalarFunction: return rtp.GetSingleColumn(!reverse) } } return nil, false case ast.UnaryMinus: args := sf.GetArgs() switch tp := args[0].(type) { case *Column: return tp, !reverse case *ScalarFunction: return tp.GetSingleColumn(!reverse) } return nil, false } return nil, false } // Coercibility returns the coercibility value which is used to check collations. func (sf *ScalarFunction) Coercibility() Coercibility { if !sf.Function.HasCoercibility() { sf.SetCoercibility(deriveCoercibilityForScalarFunc(sf)) } return sf.Function.Coercibility() } // HasCoercibility ... func (sf *ScalarFunction) HasCoercibility() bool { return sf.Function.HasCoercibility() } // SetCoercibility sets a specified coercibility for this expression. func (sf *ScalarFunction) SetCoercibility(val Coercibility) { sf.Function.SetCoercibility(val) } // CharsetAndCollation gets charset and collation. 
func (sf *ScalarFunction) CharsetAndCollation() (string, string) {
	return sf.Function.CharsetAndCollation()
}

// SetCharsetAndCollation sets charset and collation.
func (sf *ScalarFunction) SetCharsetAndCollation(chs, coll string) {
	sf.Function.SetCharsetAndCollation(chs, coll)
}

// Repertoire returns the repertoire value which is used to check collations.
func (sf *ScalarFunction) Repertoire() Repertoire {
	return sf.Function.Repertoire()
}

// SetRepertoire sets a specified repertoire for this expression.
func (sf *ScalarFunction) SetRepertoire(r Repertoire) {
	sf.Function.SetRepertoire(r)
}

// IsExplicitCharset returns whether the charset is explicitly set or not.
func (sf *ScalarFunction) IsExplicitCharset() bool {
	return sf.Function.IsExplicitCharset()
}

// SetExplicitCharset sets whether the charset is explicit or not.
func (sf *ScalarFunction) SetExplicitCharset(explicit bool) {
	sf.Function.SetExplicitCharset(explicit)
}

// emptyScalarFunctionSize is the size of a zero-valued ScalarFunction struct.
const emptyScalarFunctionSize = int64(unsafe.Sizeof(ScalarFunction{}))

// MemoryUsage returns the memory usage of ScalarFunction.
// It sums the struct itself, both forms of the function name, the capacity of
// both cached hash buffers, and the usage reported by RetType and Function
// when they are set. A nil receiver reports 0.
func (sf *ScalarFunction) MemoryUsage() (sum int64) {
	if sf == nil {
		return
	}

	sum = emptyScalarFunctionSize +
		int64(len(sf.FuncName.L)+len(sf.FuncName.O)) +
		int64(cap(sf.hashcode)) +
		int64(cap(sf.canonicalhashcode))
	if sf.RetType != nil {
		sum += sf.RetType.MemoryUsage()
	}
	if sf.Function != nil {
		sum += sf.Function.MemoryUsage()
	}
	return sum
}