276 lines
7.6 KiB
Go
276 lines
7.6 KiB
Go
// Copyright 2016 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package expression
|
|
|
|
import (
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
"unicode"
|
|
|
|
"github.com/juju/errors"
|
|
"github.com/pingcap/tidb/ast"
|
|
"github.com/pingcap/tidb/context"
|
|
"github.com/pingcap/tidb/mysql"
|
|
"github.com/pingcap/tidb/parser/opcode"
|
|
"github.com/pingcap/tidb/terror"
|
|
"github.com/pingcap/tidb/types"
|
|
)
|
|
|
|
// Filter the input expressions, append the results to result.
|
|
func Filter(result []Expression, input []Expression, filter func(Expression) bool) []Expression {
|
|
for _, e := range input {
|
|
if filter(e) {
|
|
result = append(result, e)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// ExtractColumns extracts all columns from an expression.
|
|
func ExtractColumns(expr Expression) (cols []*Column) {
|
|
// Pre-allocate a slice to reduce allocation, 8 doesn't have special meaning.
|
|
result := make([]*Column, 0, 8)
|
|
return extractColumns(result, expr, nil)
|
|
}
|
|
|
|
// ExtractColumnsFromExpressions is a more effecient version of ExtractColumns for batch operation.
|
|
// filter can be nil, or a function to filter the result column.
|
|
// It's often observed that the pattern of the caller like this:
|
|
//
|
|
// cols := ExtractColumns(...)
|
|
// for _, col := range cols {
|
|
// if xxx(col) {...}
|
|
// }
|
|
//
|
|
// Provide an additional filter argument, this can be done in one step.
|
|
// To avoid allocation for cols that not need.
|
|
func ExtractColumnsFromExpressions(result []*Column, exprs []Expression, filter func(*Column) bool) []*Column {
|
|
for _, expr := range exprs {
|
|
result = extractColumns(result, expr, filter)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func extractColumns(result []*Column, expr Expression, filter func(*Column) bool) []*Column {
|
|
switch v := expr.(type) {
|
|
case *Column:
|
|
if filter == nil || filter(v) {
|
|
result = append(result, v)
|
|
}
|
|
case *ScalarFunction:
|
|
for _, arg := range v.GetArgs() {
|
|
result = extractColumns(result, arg, filter)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// ColumnSubstitute substitutes the columns in filter to expressions in select fields.
|
|
// e.g. select * from (select b as a from t) k where a < 10 => select * from (select b as a from t where b < 10) k.
|
|
func ColumnSubstitute(expr Expression, schema *Schema, newExprs []Expression) Expression {
|
|
switch v := expr.(type) {
|
|
case *Column:
|
|
id := schema.ColumnIndex(v)
|
|
if id == -1 {
|
|
return v
|
|
}
|
|
return newExprs[id].Clone()
|
|
case *ScalarFunction:
|
|
if v.FuncName.L == ast.Cast {
|
|
newFunc := v.Clone().(*ScalarFunction)
|
|
newFunc.GetArgs()[0] = ColumnSubstitute(newFunc.GetArgs()[0], schema, newExprs)
|
|
return newFunc
|
|
}
|
|
newArgs := make([]Expression, 0, len(v.GetArgs()))
|
|
for _, arg := range v.GetArgs() {
|
|
newArgs = append(newArgs, ColumnSubstitute(arg, schema, newExprs))
|
|
}
|
|
return NewFunctionInternal(v.GetCtx(), v.FuncName.L, v.RetType, newArgs...)
|
|
}
|
|
return expr
|
|
}
|
|
|
|
// getValidPrefix returns the longest prefix of s that can be parsed as an integer in
// the given base. base must be in [2, 36]; any other base yields "".
//
// A single leading '+' or '-' is accepted. A leading '+' is dropped from the result,
// a leading '-' is kept. Letters are matched case-insensitively, so in base 16 both
// "a" and "A" are valid digits.
func getValidPrefix(s string, base int64) string {
	// upper is the first (upper-cased) character that is no longer a valid digit.
	var upper rune
	switch {
	case base >= 2 && base <= 9:
		upper = rune('0' + base)
	case base >= 10 && base <= 36:
		upper = rune('A' + base - 10)
	default:
		// Outside [2, 36] there is no meaningful digit set. Previously base < 2 fell
		// into the letter branch and produced a bogus upper bound.
		return ""
	}

	validLen := 0
Loop:
	for i := 0; i < len(s); i++ {
		c := rune(s[i])
		switch {
		case unicode.IsDigit(c) || unicode.IsLower(c) || unicode.IsUpper(c):
			if unicode.ToUpper(c) >= upper {
				break Loop
			}
			validLen = i + 1
		case c == '+' || c == '-':
			// A sign is only valid as the very first character.
			if i != 0 {
				break Loop
			}
		default:
			break Loop
		}
	}

	// validLen > 1 means at least one digit follows the leading sign.
	if validLen > 1 && s[0] == '+' {
		return s[1:validLen]
	}
	return s[:validLen]
}
|
|
|
|
// SubstituteCorCol2Constant will substitute correlated column to constant value which it contains.
|
|
// If the args of one scalar function are all constant, we will substitute it to constant.
|
|
func SubstituteCorCol2Constant(expr Expression) (Expression, error) {
|
|
switch x := expr.(type) {
|
|
case *ScalarFunction:
|
|
allConstant := true
|
|
newArgs := make([]Expression, 0, len(x.GetArgs()))
|
|
for _, arg := range x.GetArgs() {
|
|
newArg, err := SubstituteCorCol2Constant(arg)
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
_, ok := newArg.(*Constant)
|
|
newArgs = append(newArgs, newArg)
|
|
allConstant = allConstant && ok
|
|
}
|
|
if allConstant {
|
|
val, err := x.Eval(nil)
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
return &Constant{Value: val, RetType: x.GetType()}, nil
|
|
}
|
|
var newSf Expression
|
|
if x.FuncName.L == ast.Cast {
|
|
newSf = BuildCastFunction(x.GetCtx(), newArgs[0], x.RetType)
|
|
} else {
|
|
newSf = NewFunctionInternal(x.GetCtx(), x.FuncName.L, x.GetType(), newArgs...)
|
|
}
|
|
return newSf, nil
|
|
case *CorrelatedColumn:
|
|
return &Constant{Value: *x.Data, RetType: x.GetType()}, nil
|
|
case *Constant:
|
|
if x.DeferredExpr != nil {
|
|
newExpr := FoldConstant(x)
|
|
return &Constant{Value: newExpr.(*Constant).Value, RetType: x.GetType()}, nil
|
|
}
|
|
}
|
|
return expr.Clone(), nil
|
|
}
|
|
|
|
// timeZone2Duration converts timezone whose format should satisfy the regular condition
|
|
// `(^(+|-)(0?[0-9]|1[0-2]):[0-5]?\d$)|(^+13:00$)` to time.Duration.
|
|
func timeZone2Duration(tz string) time.Duration {
|
|
sign := 1
|
|
if strings.HasPrefix(tz, "-") {
|
|
sign = -1
|
|
}
|
|
|
|
i := strings.Index(tz, ":")
|
|
h, err := strconv.Atoi(tz[1:i])
|
|
terror.Log(errors.Trace(err))
|
|
m, err := strconv.Atoi(tz[i+1:])
|
|
terror.Log(errors.Trace(err))
|
|
return time.Duration(sign) * (time.Duration(h)*time.Hour + time.Duration(m)*time.Minute)
|
|
}
|
|
|
|
var oppositeOp = map[string]string{
|
|
ast.LT: ast.GE,
|
|
ast.GE: ast.LT,
|
|
ast.GT: ast.LE,
|
|
ast.LE: ast.GT,
|
|
ast.EQ: ast.NE,
|
|
ast.NE: ast.EQ,
|
|
}
|
|
|
|
// a op b is equal to b symmetricOp a
|
|
var symmetricOp = map[opcode.Op]opcode.Op{
|
|
opcode.LT: opcode.GT,
|
|
opcode.GE: opcode.LE,
|
|
opcode.GT: opcode.LT,
|
|
opcode.LE: opcode.GE,
|
|
opcode.EQ: opcode.EQ,
|
|
opcode.NE: opcode.NE,
|
|
opcode.NullEQ: opcode.NullEQ,
|
|
}
|
|
|
|
// PushDownNot pushes the `not` function down to the expression's arguments.
|
|
func PushDownNot(expr Expression, not bool, ctx context.Context) Expression {
|
|
if f, ok := expr.(*ScalarFunction); ok {
|
|
switch f.FuncName.L {
|
|
case ast.UnaryNot:
|
|
return PushDownNot(f.GetArgs()[0], !not, f.GetCtx())
|
|
case ast.LT, ast.GE, ast.GT, ast.LE, ast.EQ, ast.NE:
|
|
if not {
|
|
return NewFunctionInternal(f.GetCtx(), oppositeOp[f.FuncName.L], f.GetType(), f.GetArgs()...)
|
|
}
|
|
for i, arg := range f.GetArgs() {
|
|
f.GetArgs()[i] = PushDownNot(arg, false, f.GetCtx())
|
|
}
|
|
return f
|
|
case ast.LogicAnd:
|
|
if not {
|
|
args := f.GetArgs()
|
|
for i, a := range args {
|
|
args[i] = PushDownNot(a, true, f.GetCtx())
|
|
}
|
|
return NewFunctionInternal(f.GetCtx(), ast.LogicOr, f.GetType(), args...)
|
|
}
|
|
for i, arg := range f.GetArgs() {
|
|
f.GetArgs()[i] = PushDownNot(arg, false, f.GetCtx())
|
|
}
|
|
return f
|
|
case ast.LogicOr:
|
|
if not {
|
|
args := f.GetArgs()
|
|
for i, a := range args {
|
|
args[i] = PushDownNot(a, true, f.GetCtx())
|
|
}
|
|
return NewFunctionInternal(f.GetCtx(), ast.LogicAnd, f.GetType(), args...)
|
|
}
|
|
for i, arg := range f.GetArgs() {
|
|
f.GetArgs()[i] = PushDownNot(arg, false, f.GetCtx())
|
|
}
|
|
return f
|
|
}
|
|
}
|
|
if not {
|
|
expr = NewFunctionInternal(ctx, ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), expr)
|
|
}
|
|
return expr
|
|
}
|
|
|
|
// Contains tests if `exprs` contains `e`.
|
|
func Contains(exprs []Expression, e Expression) bool {
|
|
for _, expr := range exprs {
|
|
if e == expr {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|