Files
tidb/expression/expression.go
Han Fei 10359b0052 rewrite executor. (#1294)
* rewrite executor.
2016-06-07 20:32:15 +08:00

378 lines
10 KiB
Go

// Copyright 2016 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package expression
import (
"fmt"
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/evaluator"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
)
// Rewrite converts an ast expression into an Expression by walking it with an
// expressionRewriter. Column names are resolved against schema, and aggregate
// function nodes are replaced by the schema column whose position AggrMapper
// records for them. It returns an error when the walk recorded one, or when
// the walk did not leave exactly one expression on the rewriter's stack.
func Rewrite(expr ast.ExprNode, schema Schema, AggrMapper map[*ast.AggregateFuncExpr]int) (newExpr Expression, err error) {
	rewriter := &expressionRewriter{aggrMap: AggrMapper, schema: schema}
	expr.Accept(rewriter)
	if rewriter.err != nil {
		return nil, errors.Trace(rewriter.err)
	}
	if depth := len(rewriter.ctxStack); depth != 1 {
		return nil, errors.Errorf("context len %v is invalid", depth)
	}
	return rewriter.ctxStack[0], nil
}
// expressionRewriter is the visitor Rewrite passes to an ast expression's
// Accept method; its Enter and Leave methods build the resulting Expression.
type expressionRewriter struct {
// ctxStack holds the expressions built so far. Leave pushes the result for a
// node after removing the operands that node consumed.
ctxStack []Expression
// schema is used to resolve column names and to supply the column that stands
// in for an aggregate function.
schema Schema
// err records a failure found while visiting; Rewrite returns it.
err error
// aggrMap gives, for an aggregate function node, the index into schema of the
// column that replaces it.
aggrMap map[*ast.AggregateFuncExpr]int
}
// Expression represents a scalar expression in SQL. In this file it is
// implemented by *Column, *ScalarFunction and *Constant.
type Expression interface {
// Eval evaluates the expression against a row and returns the resulting datum.
Eval(row []types.Datum, ctx context.Context) (types.Datum, error)
// GetType returns the expression's return type.
GetType() *types.FieldType
// DeepCopy returns a copy of the expression.
DeepCopy() Expression
// ToString converts the expression into a string.
ToString() string
}
// EvalBool evaluates expr against row and reports its truth value.
// A NULL result is reported as false. Any other result is converted with
// ToBool and is true when the converted value is non-zero. Errors from the
// evaluation or the conversion are returned traced.
func EvalBool(expr Expression, row []types.Datum, ctx context.Context) (bool, error) {
	value, err := expr.Eval(row, ctx)
	switch {
	case err != nil:
		return false, errors.Trace(err)
	case value.Kind() == types.KindNull:
		return false, nil
	}
	truth, err := value.ToBool()
	if err != nil {
		return false, errors.Trace(err)
	}
	return truth != 0, nil
}
// Column represents a column.
type Column struct {
// FromID and ColName together identify the column when Schema.RetrieveColumn
// and Schema.GetIndex look it up.
FromID string
ColName model.CIStr
// DbName and TblName qualify the column; Schema.FindColumn compares them when
// the name being resolved specifies them.
DbName model.CIStr
TblName model.CIStr
// RetType is the type reported by GetType.
RetType *types.FieldType
// Index is the position in the row that Eval reads. It is only used during
// execution; Schema.InitIndices assigns it.
Index int
}
// ToString implements Expression interface. It renders the L form of the
// column name, prefixed by the table name and then the database name
// (dot-separated) whenever those are non-empty.
func (col *Column) ToString() string {
	name := col.ColName.L
	// Qualifiers are applied innermost first, so the database ends up leftmost.
	for _, qualifier := range []string{col.TblName.L, col.DbName.L} {
		if qualifier != "" {
			name = qualifier + "." + name
		}
	}
	return name
}
// GetType implements Expression interface. It returns the column's RetType.
func (col *Column) GetType() *types.FieldType {
return col.RetType
}
// Eval implements Expression interface. It returns the datum stored at
// position col.Index of row; the context argument is ignored and the returned
// error is always nil.
// The row is indexed directly, so an Index outside the bounds of row panics
// instead of producing an error.
func (col *Column) Eval(row []types.Datum, _ context.Context) (d types.Datum, err error) {
return row[col.Index], nil
}
// DeepCopy implements Expression interface. It returns a pointer to a new
// Column holding a copy of the receiver's fields; pointer fields such as
// RetType are copied as pointers and therefore still shared.
func (col *Column) DeepCopy() Expression {
	dup := new(Column)
	*dup = *col
	return dup
}
// Schema stands for the row schema obtained from the input. It is an ordered
// list of columns.
type Schema []*Column
// DeepCopy returns a new schema of the same length in which every entry points
// to a fresh Column value copied from the corresponding entry of s. Pointer
// fields inside a Column are copied as pointers.
func (s Schema) DeepCopy() Schema {
	cloned := make(Schema, len(s))
	for i := range s {
		dup := *s[i]
		cloned[i] = &dup
	}
	return cloned
}
// FindColumn resolves an ast column name to the schema column it refers to.
// A schema column matches when its column name equals the requested one and,
// for each of table and database that the request specifies (non-empty), that
// qualifier equals the column's. Comparison uses the L form of each name.
// It returns an error when more than one column matches (ambiguous) or when
// none does (unknown column).
func (s Schema) FindColumn(astCol *ast.ColumnName) (*Column, error) {
	wantDB, wantTbl, wantCol := astCol.Schema.L, astCol.Table.L, astCol.Name.L
	var match *Column
	for _, candidate := range s {
		if wantCol != candidate.ColName.L {
			continue
		}
		if wantTbl != "" && wantTbl != candidate.TblName.L {
			continue
		}
		if wantDB != "" && wantDB != candidate.DbName.L {
			continue
		}
		if match != nil {
			return nil, errors.Errorf("Column '%s' is ambiguous", wantCol)
		}
		match = candidate
	}
	if match == nil {
		return nil, errors.Errorf("Unknown column %s %s %s.", wantDB, wantTbl, wantCol)
	}
	return match, nil
}
// InitIndices stores, in every column of the schema, that column's position
// within the schema as its Index.
func (s Schema) InitIndices() {
	for pos := 0; pos < len(s); pos++ {
		s[pos].Index = pos
	}
}
// RetrieveColumn returns the first column in the schema whose FromID and
// column name (L form) both equal those of col, or nil when there is none.
func (s Schema) RetrieveColumn(col *Column) *Column {
	for i := range s {
		candidate := s[i]
		if candidate.FromID != col.FromID || candidate.ColName.L != col.ColName.L {
			continue
		}
		return candidate
	}
	return nil
}
// GetIndex returns the position of the first column in the schema whose FromID
// and column name (L form) both equal those of col, or -1 when none matches.
func (s Schema) GetIndex(col *Column) int {
	for pos := range s {
		if c := s[pos]; c.FromID == col.FromID && c.ColName.L == col.ColName.L {
			return pos
		}
	}
	return -1
}
// ScalarFunction is the function that returns a value.
type ScalarFunction struct {
// Args are the argument expressions. Eval evaluates them in order and passes
// the results to function.
Args []Expression
// FuncName is the function's name; ToString prints its L form.
FuncName model.CIStr
// TODO: Implement type inference here, now we use ast's return type temporarily.
retType *types.FieldType
// function is the builtin implementation that Eval invokes.
function evaluator.BuiltinFunc
}
// ToString implements Expression interface. It renders the function as
// name(arg1,arg2,...), using the L form of the function name and each
// argument's own ToString. The separator is written only between arguments,
// so the output has no trailing comma and a function without arguments
// renders as name().
func (sf *ScalarFunction) ToString() string {
	result := sf.FuncName.L + "("
	for i, arg := range sf.Args {
		if i > 0 {
			result += ","
		}
		result += arg.ToString()
	}
	return result + ")"
}
// NewFunction creates a scalar function named funcName over args. Its
// implementation is the F field of the evaluator.Funcs entry stored under
// funcName.L; the return type is left unset.
// NOTE(review): the lookup is not checked — presumably an unregistered name
// leaves the implementation nil; confirm callers only pass known names.
func NewFunction(funcName model.CIStr, args []Expression) *ScalarFunction {
	builtin := evaluator.Funcs[funcName.L]
	return &ScalarFunction{
		FuncName: funcName,
		Args:     args,
		function: builtin.F,
	}
}
// Schema2Exprs converts a schema into a slice of expressions holding the same
// column pointers in the same order.
func Schema2Exprs(schema Schema) []Expression {
	exprs := make([]Expression, len(schema))
	for i := range schema {
		exprs[i] = schema[i]
	}
	return exprs
}
// ScalarFuncs2Exprs converts a slice of scalar functions into a slice of
// expressions holding the same pointers in the same order.
func ScalarFuncs2Exprs(funcs []*ScalarFunction) []Expression {
	exprs := make([]Expression, len(funcs))
	for i, fn := range funcs {
		exprs[i] = fn
	}
	return exprs
}
// DeepCopy implements Expression interface. The copy shares the receiver's
// name, implementation and return type, while every argument is itself
// deep-copied. A function without arguments yields a copy with nil Args.
func (sf *ScalarFunction) DeepCopy() Expression {
	clone := &ScalarFunction{
		FuncName: sf.FuncName,
		function: sf.function,
		retType:  sf.retType,
	}
	for i := range sf.Args {
		clone.Args = append(clone.Args, sf.Args[i].DeepCopy())
	}
	return clone
}
// GetType implements Expression interface. It returns the function's retType.
func (sf *ScalarFunction) GetType() *types.FieldType {
return sf.retType
}
// Eval implements Expression interface. It evaluates every argument against
// row in order, stopping with a traced error at the first argument that
// fails, and then returns the result of calling the function's implementation
// with the collected argument values and ctx.
func (sf *ScalarFunction) Eval(row []types.Datum, ctx context.Context) (types.Datum, error) {
	values := make([]types.Datum, len(sf.Args))
	for i, arg := range sf.Args {
		value, err := arg.Eval(row, ctx)
		if err != nil {
			return types.Datum{}, errors.Trace(err)
		}
		values[i] = value
	}
	return sf.function(values, ctx)
}
// Constant stands for a constant value.
type Constant struct {
// Value is the datum returned by Eval.
Value types.Datum
// RetType is the type reported by GetType.
RetType *types.FieldType
}
// ToString implements Expression interface. It formats the value obtained from
// the constant's datum using fmt's default formatting.
func (c *Constant) ToString() string {
	raw := c.Value.GetValue()
	return fmt.Sprint(raw)
}
// DeepCopy implements Expression interface. It returns a pointer to a new
// Constant holding a copy of the receiver's fields.
func (c *Constant) DeepCopy() Expression {
	dup := new(Constant)
	*dup = *c
	return dup
}
// GetType implements Expression interface. It returns the constant's RetType.
func (c *Constant) GetType() *types.FieldType {
return c.RetType
}
// Eval implements Expression interface. It returns the constant's Value; the
// row and context arguments are ignored and the error is always nil.
func (c *Constant) Eval(_ []types.Datum, _ context.Context) (types.Datum, error) {
return c.Value, nil
}
// Enter implements Visitor interface. Only aggregate function nodes are
// handled here; every other node is returned unchanged with its children
// still to be visited.
// For an aggregate function the children are always skipped. When aggrMap
// holds an index for the node, the schema column at that index is pushed onto
// the stack in its place; otherwise an error is recorded in er.err.
func (er *expressionRewriter) Enter(inNode ast.Node) (retNode ast.Node, skipChildren bool) {
	aggr, isAggr := inNode.(*ast.AggregateFuncExpr)
	if !isAggr {
		return inNode, false
	}
	// Looking up a key in a nil map is legal and reports "not found", so a nil
	// aggrMap needs no separate check.
	position, known := er.aggrMap[aggr]
	if !known {
		er.err = errors.New("Can't appear aggrFunctions")
		return inNode, true
	}
	er.ctxStack = append(er.ctxStack, er.schema[position])
	return inNode, true
}
// Leave implements Visitor interface. It is called after a node's children
// have been visited and turns the node into an Expression on er.ctxStack:
//
//   - column names are resolved against the schema and pushed;
//   - values are pushed as Constants;
//   - function calls, IS NULL, binary and unary operators remove their
//     operands from the top of the stack and push a ScalarFunction over them;
//   - aggregate functions (already pushed by Enter), column name expressions,
//     parentheses and WHEN clauses leave the stack untouched;
//   - any other node type is rejected.
//
// On failure the error is recorded in er.err and ok is false so the traversal
// stops. Leave also stops immediately when er.err is already set: the node
// that failed did not push a result, so continuing would make a parent node
// index the stack as if that result existed.
func (er *expressionRewriter) Leave(inNode ast.Node) (retNode ast.Node, ok bool) {
	if er.err != nil {
		return retNode, false
	}
	length := len(er.ctxStack)

	// hasOperands reports whether the stack holds at least n expressions and
	// records an error when it does not.
	hasOperands := func(n int) bool {
		if n <= length {
			return true
		}
		er.err = errors.Errorf("expression stack has %d entries, need %d", length, n)
		return false
	}
	// replaceOperands resolves function's implementation by name, then swaps
	// the operands on top of the stack (len(function.Args) of them) for
	// function itself. function.Args must not alias the stack.
	replaceOperands := func(function *ScalarFunction) bool {
		f, found := evaluator.Funcs[function.FuncName.L]
		if !found {
			er.err = errors.New("Can't find function!")
			return false
		}
		function.function = f.F
		er.ctxStack = append(er.ctxStack[:length-len(function.Args)], function)
		return true
	}

	switch v := inNode.(type) {
	case *ast.AggregateFuncExpr:
		// Enter already pushed the schema column standing in for the aggregate.
	case *ast.FuncCallExpr:
		f, found := evaluator.Funcs[v.FnName.L]
		if !found {
			er.err = errors.New("Can't find function!")
			return retNode, false
		}
		argc := len(v.Args)
		if argc < f.MinArgs || (f.MaxArgs != -1 && argc > f.MaxArgs) {
			er.err = evaluator.ErrInvalidOperation.Gen("number of function arguments must in [%d, %d].", f.MinArgs, f.MaxArgs)
			return retNode, false
		}
		if !hasOperands(argc) {
			return retNode, false
		}
		function := &ScalarFunction{FuncName: v.FnName, retType: v.Type, function: f.F}
		// Appending to the nil Args copies the operands, so the push below
		// cannot overwrite them through a shared backing array.
		function.Args = append(function.Args, er.ctxStack[length-argc:]...)
		er.ctxStack = append(er.ctxStack[:length-argc], function)
	case *ast.ColumnName:
		column, err := er.schema.FindColumn(v)
		if err != nil {
			er.err = errors.Trace(err)
			return retNode, false
		}
		er.ctxStack = append(er.ctxStack, column)
	case *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.WhenClause:
		// Pass-through nodes: the child's result stays on the stack as is.
	case *ast.ValueExpr:
		er.ctxStack = append(er.ctxStack, &Constant{Value: v.Datum, RetType: v.Type})
	case *ast.IsNullExpr:
		// NOTE(review): this always builds "isnull"; presumably IS NOT NULL
		// reaches this case too and would need a negation — verify against ast.
		if !hasOperands(1) {
			return retNode, false
		}
		function := &ScalarFunction{
			Args:     []Expression{er.ctxStack[length-1]},
			FuncName: model.NewCIStr("isnull"),
			retType:  v.Type,
		}
		if !replaceOperands(function) {
			return retNode, false
		}
	case *ast.BinaryOperationExpr:
		funcName, known := opcode.Ops[v.Op]
		if !known {
			er.err = errors.Errorf("Unknown opcode %v", v.Op)
			return retNode, false
		}
		if !hasOperands(2) {
			return retNode, false
		}
		function := &ScalarFunction{
			Args:     []Expression{er.ctxStack[length-2], er.ctxStack[length-1]},
			FuncName: model.NewCIStr(funcName),
			retType:  v.Type,
		}
		if !replaceOperands(function) {
			return retNode, false
		}
	case *ast.UnaryOperationExpr:
		if !hasOperands(1) {
			return retNode, false
		}
		function := &ScalarFunction{Args: []Expression{er.ctxStack[length-1]}, retType: v.Type}
		// Any other operator leaves FuncName empty, which the lookup in
		// replaceOperands then rejects.
		switch v.Op {
		case opcode.Not:
			function.FuncName = model.NewCIStr("not")
		case opcode.BitNeg:
			function.FuncName = model.NewCIStr("bitneg")
		case opcode.Plus:
			function.FuncName = model.NewCIStr("unaryplus")
		case opcode.Minus:
			function.FuncName = model.NewCIStr("unaryminus")
		}
		if !replaceOperands(function) {
			return retNode, false
		}
	default:
		// Nothing was pushed for this node, so the traversal must stop here.
		er.err = errors.Errorf("UnknownType: %T", v)
		return retNode, false
	}
	return inNode, true
}