378 lines
10 KiB
Go
378 lines
10 KiB
Go
// Copyright 2016 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package expression
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/juju/errors"
|
|
"github.com/pingcap/tidb/ast"
|
|
"github.com/pingcap/tidb/context"
|
|
"github.com/pingcap/tidb/evaluator"
|
|
"github.com/pingcap/tidb/model"
|
|
"github.com/pingcap/tidb/parser/opcode"
|
|
"github.com/pingcap/tidb/util/types"
|
|
)
|
|
|
|
// Rewrite rewrites ast to Expression.
|
|
func Rewrite(expr ast.ExprNode, schema Schema, AggrMapper map[*ast.AggregateFuncExpr]int) (newExpr Expression, err error) {
|
|
er := &expressionRewriter{schema: schema, aggrMap: AggrMapper}
|
|
expr.Accept(er)
|
|
if er.err != nil {
|
|
return nil, errors.Trace(er.err)
|
|
}
|
|
if len(er.ctxStack) != 1 {
|
|
return nil, errors.Errorf("context len %v is invalid", len(er.ctxStack))
|
|
}
|
|
return er.ctxStack[0], nil
|
|
}
|
|
|
|
type expressionRewriter struct {
|
|
ctxStack []Expression
|
|
schema Schema
|
|
err error
|
|
aggrMap map[*ast.AggregateFuncExpr]int
|
|
}
|
|
|
|
// Expression represents all scalar expression in SQL.
|
|
type Expression interface {
|
|
// Eval evaluates an expression through a row.
|
|
Eval(row []types.Datum, ctx context.Context) (types.Datum, error)
|
|
|
|
// Get the expression return type.
|
|
GetType() *types.FieldType
|
|
|
|
// DeepCopy copies an expression totally.
|
|
DeepCopy() Expression
|
|
|
|
// ToString converts an expression into a string.
|
|
ToString() string
|
|
}
|
|
|
|
// EvalBool evaluates expression to a boolean value.
|
|
func EvalBool(expr Expression, row []types.Datum, ctx context.Context) (bool, error) {
|
|
data, err := expr.Eval(row, ctx)
|
|
if err != nil {
|
|
return false, errors.Trace(err)
|
|
}
|
|
if data.Kind() == types.KindNull {
|
|
return false, nil
|
|
}
|
|
|
|
i, err := data.ToBool()
|
|
if err != nil {
|
|
return false, errors.Trace(err)
|
|
}
|
|
return i != 0, nil
|
|
}
|
|
|
|
// Column represents a column.
|
|
type Column struct {
|
|
FromID string
|
|
ColName model.CIStr
|
|
DbName model.CIStr
|
|
TblName model.CIStr
|
|
RetType *types.FieldType
|
|
|
|
// only used during execution
|
|
Index int
|
|
}
|
|
|
|
// ToString implements Expression interface.
|
|
func (col *Column) ToString() string {
|
|
result := col.ColName.L
|
|
if col.TblName.L != "" {
|
|
result = col.TblName.L + "." + result
|
|
}
|
|
if col.DbName.L != "" {
|
|
result = col.DbName.L + "." + result
|
|
}
|
|
return result
|
|
}
|
|
|
|
// GetType implements Expression interface.
|
|
func (col *Column) GetType() *types.FieldType {
|
|
return col.RetType
|
|
}
|
|
|
|
// Eval implements Expression interface.
|
|
func (col *Column) Eval(row []types.Datum, _ context.Context) (d types.Datum, err error) {
|
|
return row[col.Index], nil
|
|
}
|
|
|
|
// DeepCopy implements Expression interface.
|
|
func (col *Column) DeepCopy() Expression {
|
|
newCol := *col
|
|
return &newCol
|
|
}
|
|
|
|
// Schema stands for the row schema get from input.
|
|
type Schema []*Column
|
|
|
|
// DeepCopy copies the total schema.
|
|
func (s Schema) DeepCopy() Schema {
|
|
result := make(Schema, 0, len(s))
|
|
for _, col := range s {
|
|
newCol := *col
|
|
result = append(result, &newCol)
|
|
}
|
|
return result
|
|
}
|
|
|
|
// FindColumn replaces an ast column with an expression column.
|
|
func (s Schema) FindColumn(astCol *ast.ColumnName) (*Column, error) {
|
|
dbName, tblName, colName := astCol.Schema, astCol.Table, astCol.Name
|
|
idx := -1
|
|
for i, col := range s {
|
|
if (dbName.L == "" || dbName.L == col.DbName.L) && (tblName.L == "" || tblName.L == col.TblName.L) && (colName.L == col.ColName.L) {
|
|
if idx != -1 {
|
|
return nil, errors.Errorf("Column '%s' is ambiguous", colName.L)
|
|
}
|
|
idx = i
|
|
}
|
|
}
|
|
if idx == -1 {
|
|
return nil, errors.Errorf("Unknown column %s %s %s.", dbName.L, tblName.L, colName.L)
|
|
}
|
|
return s[idx], nil
|
|
}
|
|
|
|
// InitIndices sets indices for columns in schema.
|
|
func (s Schema) InitIndices() {
|
|
for i, c := range s {
|
|
c.Index = i
|
|
}
|
|
}
|
|
|
|
// RetrieveColumn replaces column in expression with column in schema.
|
|
func (s Schema) RetrieveColumn(col *Column) *Column {
|
|
for _, c := range s {
|
|
if c.FromID == col.FromID && c.ColName.L == col.ColName.L {
|
|
return c
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetIndex finds the index for a column.
|
|
func (s Schema) GetIndex(col *Column) int {
|
|
for i, c := range s {
|
|
if c.FromID == col.FromID && c.ColName.L == col.ColName.L {
|
|
return i
|
|
}
|
|
}
|
|
return -1
|
|
}
|
|
|
|
// ScalarFunction is the function that returns a value.
|
|
type ScalarFunction struct {
|
|
Args []Expression
|
|
FuncName model.CIStr
|
|
// TODO: Implement type inference here, now we use ast's return type temporarily.
|
|
retType *types.FieldType
|
|
function evaluator.BuiltinFunc
|
|
}
|
|
|
|
// ToString implements Expression interface.
|
|
func (sf *ScalarFunction) ToString() string {
|
|
result := sf.FuncName.L + "("
|
|
for _, arg := range sf.Args {
|
|
result += arg.ToString()
|
|
result += ","
|
|
}
|
|
result += ")"
|
|
return result
|
|
}
|
|
|
|
// NewFunction creates a new scalar function.
|
|
func NewFunction(funcName model.CIStr, args []Expression) *ScalarFunction {
|
|
return &ScalarFunction{Args: args, FuncName: funcName, function: evaluator.Funcs[funcName.L].F}
|
|
}
|
|
|
|
//Schema2Exprs converts []*Column to []Expression.
|
|
func Schema2Exprs(schema Schema) []Expression {
|
|
result := make([]Expression, 0, len(schema))
|
|
for _, col := range schema {
|
|
result = append(result, col)
|
|
}
|
|
return result
|
|
}
|
|
|
|
//ScalarFuncs2Exprs converts []*ScalarFunction to []Expression.
|
|
func ScalarFuncs2Exprs(funcs []*ScalarFunction) []Expression {
|
|
result := make([]Expression, 0, len(funcs))
|
|
for _, col := range funcs {
|
|
result = append(result, col)
|
|
}
|
|
return result
|
|
}
|
|
|
|
// DeepCopy implements Expression interface.
|
|
func (sf *ScalarFunction) DeepCopy() Expression {
|
|
newFunc := &ScalarFunction{FuncName: sf.FuncName, function: sf.function, retType: sf.retType}
|
|
for _, arg := range sf.Args {
|
|
newFunc.Args = append(newFunc.Args, arg.DeepCopy())
|
|
}
|
|
return newFunc
|
|
}
|
|
|
|
// GetType implements Expression interface.
|
|
func (sf *ScalarFunction) GetType() *types.FieldType {
|
|
return sf.retType
|
|
}
|
|
|
|
// Eval implements Expression interface.
|
|
func (sf *ScalarFunction) Eval(row []types.Datum, ctx context.Context) (types.Datum, error) {
|
|
args := make([]types.Datum, 0, len(sf.Args))
|
|
for _, arg := range sf.Args {
|
|
result, err := arg.Eval(row, ctx)
|
|
if err == nil {
|
|
args = append(args, result)
|
|
} else {
|
|
return types.Datum{}, errors.Trace(err)
|
|
}
|
|
}
|
|
return sf.function(args, ctx)
|
|
}
|
|
|
|
// Constant stands for a constant value.
|
|
type Constant struct {
|
|
Value types.Datum
|
|
RetType *types.FieldType
|
|
}
|
|
|
|
// ToString implements Expression interface.
|
|
func (c *Constant) ToString() string {
|
|
return fmt.Sprintf("%v", c.Value.GetValue())
|
|
}
|
|
|
|
// DeepCopy implements Expression interface.
|
|
func (c *Constant) DeepCopy() Expression {
|
|
con := *c
|
|
return &con
|
|
}
|
|
|
|
// GetType implements Expression interface.
|
|
func (c *Constant) GetType() *types.FieldType {
|
|
return c.RetType
|
|
}
|
|
|
|
// Eval implements Expression interface.
|
|
func (c *Constant) Eval(_ []types.Datum, _ context.Context) (types.Datum, error) {
|
|
return c.Value, nil
|
|
}
|
|
|
|
// Enter implements Visitor interface.
|
|
func (er *expressionRewriter) Enter(inNode ast.Node) (retNode ast.Node, skipChildren bool) {
|
|
switch v := inNode.(type) {
|
|
case *ast.AggregateFuncExpr:
|
|
index, ok := -1, false
|
|
if er.aggrMap != nil {
|
|
index, ok = er.aggrMap[v]
|
|
}
|
|
if !ok {
|
|
er.err = errors.New("Can't appear aggrFunctions")
|
|
return inNode, true
|
|
}
|
|
er.ctxStack = append(er.ctxStack, er.schema[index])
|
|
return inNode, true
|
|
}
|
|
return inNode, false
|
|
}
|
|
|
|
// Leave implements Visitor interface.
|
|
func (er *expressionRewriter) Leave(inNode ast.Node) (retNode ast.Node, ok bool) {
|
|
length := len(er.ctxStack)
|
|
switch v := inNode.(type) {
|
|
case *ast.AggregateFuncExpr:
|
|
case *ast.FuncCallExpr:
|
|
function := &ScalarFunction{FuncName: v.FnName}
|
|
for i := length - len(v.Args); i < length; i++ {
|
|
function.Args = append(function.Args, er.ctxStack[i])
|
|
}
|
|
f := evaluator.Funcs[v.FnName.L]
|
|
if len(function.Args) < f.MinArgs || (f.MaxArgs != -1 && len(function.Args) > f.MaxArgs) {
|
|
er.err = evaluator.ErrInvalidOperation.Gen("number of function arguments must in [%d, %d].", f.MinArgs, f.MaxArgs)
|
|
return retNode, false
|
|
}
|
|
function.function = f.F
|
|
function.retType = v.Type
|
|
er.ctxStack = er.ctxStack[:length-len(v.Args)]
|
|
er.ctxStack = append(er.ctxStack, function)
|
|
case *ast.ColumnName:
|
|
column, err := er.schema.FindColumn(v)
|
|
if err != nil {
|
|
er.err = errors.Trace(err)
|
|
return retNode, false
|
|
}
|
|
er.ctxStack = append(er.ctxStack, column)
|
|
case *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.WhenClause:
|
|
case *ast.ValueExpr:
|
|
value := &Constant{Value: v.Datum, RetType: v.Type}
|
|
er.ctxStack = append(er.ctxStack, value)
|
|
case *ast.IsNullExpr:
|
|
function := &ScalarFunction{
|
|
Args: []Expression{er.ctxStack[length-1]},
|
|
FuncName: model.NewCIStr("isnull"),
|
|
retType: v.Type,
|
|
}
|
|
f, ok := evaluator.Funcs[function.FuncName.L]
|
|
if !ok {
|
|
er.err = errors.New("Can't find function!")
|
|
return retNode, false
|
|
}
|
|
function.function = f.F
|
|
er.ctxStack = er.ctxStack[:length-1]
|
|
er.ctxStack = append(er.ctxStack, function)
|
|
case *ast.BinaryOperationExpr:
|
|
function := &ScalarFunction{Args: []Expression{er.ctxStack[length-2], er.ctxStack[length-1]}, retType: v.Type}
|
|
funcName, ok := opcode.Ops[v.Op]
|
|
if !ok {
|
|
er.err = errors.Errorf("Unknown opcode %v", v.Op)
|
|
return retNode, false
|
|
}
|
|
function.FuncName = model.NewCIStr(funcName)
|
|
f, ok := evaluator.Funcs[function.FuncName.L]
|
|
if !ok {
|
|
er.err = errors.New("Can't find function!")
|
|
return retNode, false
|
|
}
|
|
function.function = f.F
|
|
er.ctxStack = er.ctxStack[:length-2]
|
|
er.ctxStack = append(er.ctxStack, function)
|
|
case *ast.UnaryOperationExpr:
|
|
function := &ScalarFunction{Args: []Expression{er.ctxStack[length-1]}, retType: v.Type}
|
|
switch v.Op {
|
|
case opcode.Not:
|
|
function.FuncName = model.NewCIStr("not")
|
|
case opcode.BitNeg:
|
|
function.FuncName = model.NewCIStr("bitneg")
|
|
case opcode.Plus:
|
|
function.FuncName = model.NewCIStr("unaryplus")
|
|
case opcode.Minus:
|
|
function.FuncName = model.NewCIStr("unaryminus")
|
|
}
|
|
f, ok := evaluator.Funcs[function.FuncName.L]
|
|
if !ok {
|
|
er.err = errors.New("Can't find function!")
|
|
return retNode, false
|
|
}
|
|
function.function = f.F
|
|
er.ctxStack = er.ctxStack[:length-1]
|
|
er.ctxStack = append(er.ctxStack, function)
|
|
default:
|
|
er.err = errors.Errorf("UnkownType: %T", v)
|
|
}
|
|
return inNode, true
|
|
}
|