// Copyright 2016 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // See the License for the specific language governing permissions and // limitations under the License. package expression import ( "strconv" "strings" "time" "unicode" "github.com/juju/errors" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/hack" ) // Filter the input expressions, append the results to result. func Filter(result []Expression, input []Expression, filter func(Expression) bool) []Expression { for _, e := range input { if filter(e) { result = append(result, e) } } return result } // ExtractColumns extracts all columns from an expression. func ExtractColumns(expr Expression) (cols []*Column) { // Pre-allocate a slice to reduce allocation, 8 doesn't have special meaning. result := make([]*Column, 0, 8) return extractColumns(result, expr, nil) } // ExtractColumnsFromExpressions is a more efficient version of ExtractColumns for batch operation. // filter can be nil, or a function to filter the result column. // It's often observed that the pattern of the caller like this: // // cols := ExtractColumns(...) // for _, col := range cols { // if xxx(col) {...} // } // // Provide an additional filter argument, this can be done in one step. // To avoid allocation for cols that not need. func ExtractColumnsFromExpressions(result []*Column, exprs []Expression, filter func(*Column) bool) []*Column { for _, expr := range exprs { result = extractColumns(result, expr, filter) } return result } func extractColumns(result []*Column, expr Expression, filter func(*Column) bool) []*Column { switch v := expr.(type) { case *Column: if filter == nil || filter(v) { result = append(result, v) } case *ScalarFunction: for _, arg := range v.GetArgs() { result = extractColumns(result, arg, filter) } } return result } // ColumnSubstitute substitutes the columns in filter to expressions in select fields. // e.g. select * from (select b as a from t) k where a < 10 => select * from (select b as a from t where b < 10) k. func ColumnSubstitute(expr Expression, schema *Schema, newExprs []Expression) Expression { switch v := expr.(type) { case *Column: id := schema.ColumnIndex(v) if id == -1 { return v } return newExprs[id].Clone() case *ScalarFunction: if v.FuncName.L == ast.Cast { newFunc := v.Clone().(*ScalarFunction) newFunc.GetArgs()[0] = ColumnSubstitute(newFunc.GetArgs()[0], schema, newExprs) return newFunc } newArgs := make([]Expression, 0, len(v.GetArgs())) for _, arg := range v.GetArgs() { newArgs = append(newArgs, ColumnSubstitute(arg, schema, newExprs)) } return NewFunctionInternal(v.GetCtx(), v.FuncName.L, v.RetType, newArgs...) } return expr } // getValidPrefix gets a prefix of string which can parsed to a number with base. the minimum base is 2 and the maximum is 36. func getValidPrefix(s string, base int64) string { var ( validLen int upper rune ) switch { case base >= 2 && base <= 9: upper = rune('0' + base) case base <= 36: upper = rune('A' + base - 10) default: return "" } Loop: for i := 0; i < len(s); i++ { c := rune(s[i]) switch { case unicode.IsDigit(c) || unicode.IsLower(c) || unicode.IsUpper(c): c = unicode.ToUpper(c) if c < upper { validLen = i + 1 } else { break Loop } case c == '+' || c == '-': if i != 0 { break Loop } default: break Loop } } if validLen > 1 && s[0] == '+' { return s[1:validLen] } return s[:validLen] } // SubstituteCorCol2Constant will substitute correlated column to constant value which it contains. // If the args of one scalar function are all constant, we will substitute it to constant. func SubstituteCorCol2Constant(expr Expression) (Expression, error) { switch x := expr.(type) { case *ScalarFunction: allConstant := true newArgs := make([]Expression, 0, len(x.GetArgs())) for _, arg := range x.GetArgs() { newArg, err := SubstituteCorCol2Constant(arg) if err != nil { return nil, errors.Trace(err) } _, ok := newArg.(*Constant) newArgs = append(newArgs, newArg) allConstant = allConstant && ok } if allConstant { val, err := x.Eval(nil) if err != nil { return nil, errors.Trace(err) } return &Constant{Value: val, RetType: x.GetType()}, nil } var newSf Expression if x.FuncName.L == ast.Cast { newSf = BuildCastFunction(x.GetCtx(), newArgs[0], x.RetType) } else { newSf = NewFunctionInternal(x.GetCtx(), x.FuncName.L, x.GetType(), newArgs...) } return newSf, nil case *CorrelatedColumn: return &Constant{Value: *x.Data, RetType: x.GetType()}, nil case *Constant: if x.DeferredExpr != nil { newExpr := FoldConstant(x) return &Constant{Value: newExpr.(*Constant).Value, RetType: x.GetType()}, nil } } return expr.Clone(), nil } // timeZone2Duration converts timezone whose format should satisfy the regular condition // `(^(+|-)(0?[0-9]|1[0-2]):[0-5]?\d$)|(^+13:00$)` to time.Duration. func timeZone2Duration(tz string) time.Duration { sign := 1 if strings.HasPrefix(tz, "-") { sign = -1 } i := strings.Index(tz, ":") h, err := strconv.Atoi(tz[1:i]) terror.Log(errors.Trace(err)) m, err := strconv.Atoi(tz[i+1:]) terror.Log(errors.Trace(err)) return time.Duration(sign) * (time.Duration(h)*time.Hour + time.Duration(m)*time.Minute) } var oppositeOp = map[string]string{ ast.LT: ast.GE, ast.GE: ast.LT, ast.GT: ast.LE, ast.LE: ast.GT, ast.EQ: ast.NE, ast.NE: ast.EQ, } // a op b is equal to b symmetricOp a var symmetricOp = map[opcode.Op]opcode.Op{ opcode.LT: opcode.GT, opcode.GE: opcode.LE, opcode.GT: opcode.LT, opcode.LE: opcode.GE, opcode.EQ: opcode.EQ, opcode.NE: opcode.NE, opcode.NullEQ: opcode.NullEQ, } // PushDownNot pushes the `not` function down to the expression's arguments. func PushDownNot(ctx sessionctx.Context, expr Expression, not bool) Expression { if f, ok := expr.(*ScalarFunction); ok { switch f.FuncName.L { case ast.UnaryNot: return PushDownNot(f.GetCtx(), f.GetArgs()[0], !not) case ast.LT, ast.GE, ast.GT, ast.LE, ast.EQ, ast.NE: if not { return NewFunctionInternal(f.GetCtx(), oppositeOp[f.FuncName.L], f.GetType(), f.GetArgs()...) } for i, arg := range f.GetArgs() { f.GetArgs()[i] = PushDownNot(f.GetCtx(), arg, false) } return f case ast.LogicAnd: if not { args := f.GetArgs() for i, a := range args { args[i] = PushDownNot(f.GetCtx(), a, true) } return NewFunctionInternal(f.GetCtx(), ast.LogicOr, f.GetType(), args...) } for i, arg := range f.GetArgs() { f.GetArgs()[i] = PushDownNot(f.GetCtx(), arg, false) } return f case ast.LogicOr: if not { args := f.GetArgs() for i, a := range args { args[i] = PushDownNot(f.GetCtx(), a, true) } return NewFunctionInternal(f.GetCtx(), ast.LogicAnd, f.GetType(), args...) } for i, arg := range f.GetArgs() { f.GetArgs()[i] = PushDownNot(f.GetCtx(), arg, false) } return f } } if not { expr = NewFunctionInternal(ctx, ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), expr) } return expr } // Contains tests if `exprs` contains `e`. func Contains(exprs []Expression, e Expression) bool { for _, expr := range exprs { if e == expr { return true } } return false } // ExtractFiltersFromDNFs checks whether the cond is DNF. If so, it will get the extracted part and the remained part. // The original DNF will be replaced by the remained part or just be deleted if remained part is nil. // And the extracted part will be appended to the end of the orignal slice. func ExtractFiltersFromDNFs(ctx sessionctx.Context, conditions []Expression) []Expression { var allExtracted []Expression for i := len(conditions) - 1; i >= 0; i-- { if sf, ok := conditions[i].(*ScalarFunction); ok && sf.FuncName.L == ast.LogicOr { extracted, remained := extractFiltersFromDNF(ctx, sf) allExtracted = append(allExtracted, extracted...) if remained == nil { conditions = append(conditions[:i], conditions[i+1:]...) } else { conditions[i] = remained } } } return append(conditions, allExtracted...) } // extractFiltersFromDNF extracts the same condition that occurs in every DNF item and remove them from dnf leaves. func extractFiltersFromDNF(ctx sessionctx.Context, dnfFunc *ScalarFunction) ([]Expression, Expression) { dnfItems := FlattenDNFConditions(dnfFunc) sc := ctx.GetSessionVars().StmtCtx codeMap := make(map[string]int) hashcode2Expr := make(map[string]Expression) for i, dnfItem := range dnfItems { innerMap := make(map[string]struct{}) cnfItems := SplitCNFItems(dnfItem) for _, cnfItem := range cnfItems { code := cnfItem.HashCode(sc) if i == 0 { codeMap[hack.String(code)] = 1 hashcode2Expr[hack.String(code)] = cnfItem } else if _, ok := codeMap[hack.String(code)]; ok { // We need this check because there may be the case like `select * from t, t1 where (t.a=t1.a and t.a=t1.a) or (something). // We should make sure that the two `t.a=t1.a` contributes only once. // TODO: do this out of this function. if _, ok = innerMap[hack.String(code)]; !ok { codeMap[hack.String(code)]++ innerMap[hack.String(code)] = struct{}{} } } } } // We should make sure that this item occurs in every DNF item. for hashcode, cnt := range codeMap { if cnt < len(dnfItems) { delete(hashcode2Expr, hashcode) } } if len(hashcode2Expr) == 0 { return nil, dnfFunc } newDNFItems := make([]Expression, 0, len(dnfItems)) onlyNeedExtracted := false for _, dnfItem := range dnfItems { cnfItems := SplitCNFItems(dnfItem) newCNFItems := make([]Expression, 0, len(cnfItems)) for _, cnfItem := range cnfItems { code := cnfItem.HashCode(sc) _, ok := hashcode2Expr[hack.String(code)] if !ok { newCNFItems = append(newCNFItems, cnfItem) } } // If the extracted part is just one leaf of the DNF expression. Then the value of the total DNF expression is // always the same with the value of the extracted part. if len(newCNFItems) == 0 { onlyNeedExtracted = true break } newDNFItems = append(newDNFItems, ComposeCNFCondition(ctx, newCNFItems...)) } extractedExpr := make([]Expression, 0, len(hashcode2Expr)) for _, expr := range hashcode2Expr { extractedExpr = append(extractedExpr, expr) } if onlyNeedExtracted { return extractedExpr, nil } return extractedExpr, ComposeDNFCondition(ctx, newDNFItems...) }