Files
tidb/ast/functions.go
2016-12-06 13:34:52 +08:00

540 lines
14 KiB
Go

// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package ast
import (
"bytes"
"fmt"
"strings"
"github.com/juju/errors"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util/distinct"
"github.com/pingcap/tidb/util/types"
)
var (
_ FuncNode = &AggregateFuncExpr{}
_ FuncNode = &FuncCallExpr{}
_ FuncNode = &FuncCastExpr{}
)
// List scalar function names.
const (
AndAnd = "and"
Cast = "cast"
LeftShift = "leftshift"
RightShift = "rightshift"
OrOr = "or"
GE = "ge"
LE = "le"
EQ = "eq"
NE = "ne"
LT = "lt"
GT = "gt"
Plus = "plus"
Minus = "minus"
And = "bitand"
Or = "bitor"
Mod = "mod"
Xor = "bitxor"
Div = "div"
Mul = "mul"
UnaryNot = "not" // Avoid name conflict with Not in github/pingcap/check.
BitNeg = "bitneg"
IntDiv = "intdiv"
LogicXor = "xor"
NullEQ = "nulleq"
UnaryPlus = "unaryplus"
UnaryMinus = "unaryminus"
In = "in"
Like = "like"
Case = "case"
Regexp = "regexp"
IsNull = "isnull"
IsTruth = "istrue" // Avoid name conflict with IsTrue in github/pingcap/check.
IsFalsity = "isfalse" // Avoid name conflict with IsFalse in github/pingcap/check.
RowFunc = "row"
SetVar = "setvar"
GetVar = "getvar"
// common functions
Coalesce = "coalesce"
Greatest = "greatest"
// math functions
Abs = "abs"
Ceil = "ceil"
Ceiling = "ceiling"
Pow = "pow"
Power = "power"
Rand = "rand"
Round = "round"
// time functions
Curdate = "curdate"
CurrentDate = "current_date"
CurrentTime = "current_time"
CurrentTimestamp = "current_timestamp"
Curtime = "curtime"
Date = "date"
DateArith = "date_arith"
DateFormat = "date_format"
Day = "day"
DayName = "dayname"
DayOfMonth = "dayofmonth"
DayOfWeek = "dayofweek"
DayOfYear = "dayofyear"
Extract = "extract"
Hour = "hour"
MicroSecond = "microsecond"
Minute = "minute"
Month = "month"
MonthName = "monthname"
Now = "now"
Second = "second"
StrToDate = "str_to_date"
Sysdate = "sysdate"
Time = "time"
UTCDate = "utc_date"
Week = "week"
Weekday = "weekday"
WeekOfYear = "weekofyear"
Year = "year"
YearWeek = "yearweek"
FromUnixTime = "from_unixtime"
// string functions
ASCII = "ascii"
Concat = "concat"
ConcatWS = "concat_ws"
Convert = "convert"
Lcase = "lcase"
Left = "left"
Length = "length"
Locate = "locate"
Lower = "lower"
Ltrim = "ltrim"
Repeat = "repeat"
Replace = "replace"
Reverse = "reverse"
Rtrim = "rtrim"
Space = "space"
Strcmp = "strcmp"
Substring = "substring"
SubstringIndex = "substring_index"
Trim = "trim"
Upper = "upper"
Ucase = "ucase"
Hex = "hex"
Unhex = "unhex"
// information functions
ConnectionID = "connection_id"
CurrentUser = "current_user"
Database = "database"
Schema = "schema"
FoundRows = "found_rows"
LastInsertId = "last_insert_id"
User = "user"
Version = "version"
// control functions
If = "if"
Ifnull = "ifnull"
Nullif = "nullif"
// miscellaneous functions
Sleep = "sleep"
// get_lock() and release_lock() is parsed but do nothing.
// It is used for preventing error in Ruby's activerecord migrations.
GetLock = "get_lock"
ReleaseLock = "release_lock"
)
// FuncCallExpr is for function expression.
type FuncCallExpr struct {
funcNode
// FnName is the function name.
FnName model.CIStr
// Args is the function args.
Args []ExprNode
}
// Accept implements Node interface.
func (n *FuncCallExpr) Accept(v Visitor) (Node, bool) {
newNode, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNode)
}
n = newNode.(*FuncCallExpr)
for i, val := range n.Args {
node, ok := val.Accept(v)
if !ok {
return n, false
}
n.Args[i] = node.(ExprNode)
}
return v.Leave(n)
}
// CastFunctionType is the type for cast function.
type CastFunctionType int
// CastFunction types
const (
CastFunction CastFunctionType = iota + 1
CastConvertFunction
CastBinaryOperator
)
// FuncCastExpr is the cast function converting value to another type, e.g, cast(expr AS signed).
// See https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html
type FuncCastExpr struct {
funcNode
// Expr is the expression to be converted.
Expr ExprNode
// Tp is the conversion type.
Tp *types.FieldType
// Cast, Convert and Binary share this struct.
FunctionType CastFunctionType
}
// Accept implements Node Accept interface.
func (n *FuncCastExpr) Accept(v Visitor) (Node, bool) {
newNode, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNode)
}
n = newNode.(*FuncCastExpr)
node, ok := n.Expr.Accept(v)
if !ok {
return n, false
}
n.Expr = node.(ExprNode)
return v.Leave(n)
}
// TrimDirectionType is the type for trim direction.
type TrimDirectionType int
const (
// TrimBothDefault trims from both direction by default.
TrimBothDefault TrimDirectionType = iota
// TrimBoth trims from both direction with explicit notation.
TrimBoth
// TrimLeading trims from left.
TrimLeading
// TrimTrailing trims from right.
TrimTrailing
)
// DateArithType is type for DateArith type.
type DateArithType byte
const (
// DateAdd is to run adddate or date_add function option.
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_adddate
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add
DateAdd DateArithType = iota + 1
// DateSub is to run subdate or date_sub function option.
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subdate
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-sub
DateSub
)
// DateArithInterval is the struct of DateArith interval part.
type DateArithInterval struct {
Unit string
Interval ExprNode
}
const (
// AggFuncCount is the name of Count function.
AggFuncCount = "count"
// AggFuncSum is the name of Sum function.
AggFuncSum = "sum"
// AggFuncAvg is the name of Avg function.
AggFuncAvg = "avg"
// AggFuncFirstRow is the name of FirstRowColumn function.
AggFuncFirstRow = "firstrow"
// AggFuncMax is the name of max function.
AggFuncMax = "max"
// AggFuncMin is the name of min function.
AggFuncMin = "min"
// AggFuncGroupConcat is the name of group_concat function.
AggFuncGroupConcat = "group_concat"
)
// AggregateFuncExpr represents aggregate function expression.
type AggregateFuncExpr struct {
funcNode
// F is the function name.
F string
// Args is the function args.
Args []ExprNode
// If distinct is true, the function only aggregate distinct values.
// For example, column c1 values are "1", "2", "2", "sum(c1)" is "5",
// but "sum(distinct c1)" is "3".
Distinct bool
CurrentGroup []byte
// contextPerGroupMap is used to store aggregate evaluation context.
// Each entry for a group.
contextPerGroupMap map[string](*AggEvaluateContext)
}
// Accept implements Node Accept interface.
func (n *AggregateFuncExpr) Accept(v Visitor) (Node, bool) {
newNode, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNode)
}
n = newNode.(*AggregateFuncExpr)
for i, val := range n.Args {
node, ok := val.Accept(v)
if !ok {
return n, false
}
n.Args[i] = node.(ExprNode)
}
return v.Leave(n)
}
// Clear clears aggregate computing context.
func (n *AggregateFuncExpr) Clear() {
n.CurrentGroup = []byte{}
n.contextPerGroupMap = nil
}
// Update is used for update aggregate context.
func (n *AggregateFuncExpr) Update(sc *variable.StatementContext) error {
name := strings.ToLower(n.F)
switch name {
case AggFuncCount:
return n.updateCount(sc)
case AggFuncFirstRow:
return n.updateFirstRow(sc)
case AggFuncGroupConcat:
return n.updateGroupConcat(sc)
case AggFuncMax:
return n.updateMaxMin(sc, true)
case AggFuncMin:
return n.updateMaxMin(sc, false)
case AggFuncSum, AggFuncAvg:
return n.updateSum(sc)
}
return nil
}
// GetContext gets aggregate evaluation context for the current group.
// If it is nil, add a new context into contextPerGroupMap.
func (n *AggregateFuncExpr) GetContext() *AggEvaluateContext {
if n.contextPerGroupMap == nil {
n.contextPerGroupMap = make(map[string](*AggEvaluateContext))
}
if _, ok := n.contextPerGroupMap[string(n.CurrentGroup)]; !ok {
c := &AggEvaluateContext{}
if n.Distinct {
c.DistinctChecker = distinct.CreateDistinctChecker()
}
n.contextPerGroupMap[string(n.CurrentGroup)] = c
}
return n.contextPerGroupMap[string(n.CurrentGroup)]
}
// SetContext sets the aggregate expr evaluation context.
func (n *AggregateFuncExpr) SetContext(ctx map[string](*AggEvaluateContext)) {
n.contextPerGroupMap = ctx
}
func (n *AggregateFuncExpr) updateCount(sc *variable.StatementContext) error {
ctx := n.GetContext()
vals := make([]interface{}, 0, len(n.Args))
for _, a := range n.Args {
value := a.GetValue()
if value == nil {
return nil
}
vals = append(vals, value)
}
if n.Distinct {
d, err := ctx.DistinctChecker.Check(vals)
if err != nil {
return errors.Trace(err)
}
if !d {
return nil
}
}
ctx.Count++
return nil
}
func (n *AggregateFuncExpr) updateFirstRow(sc *variable.StatementContext) error {
ctx := n.GetContext()
if !ctx.Value.IsNull() {
return nil
}
if len(n.Args) != 1 {
return errors.New("Wrong number of args for AggFuncFirstRow")
}
ctx.Value = *n.Args[0].GetDatum()
return nil
}
func (n *AggregateFuncExpr) updateMaxMin(sc *variable.StatementContext, max bool) error {
ctx := n.GetContext()
if len(n.Args) != 1 {
return errors.New("Wrong number of args for AggFuncFirstRow")
}
v := *n.Args[0].GetDatum()
if ctx.Value.IsNull() {
ctx.Value = v
return nil
}
c, err := ctx.Value.CompareDatum(sc, v)
if err != nil {
return errors.Trace(err)
}
if max {
if c == -1 {
ctx.Value = v
}
} else {
if c == 1 {
ctx.Value = v
}
}
return nil
}
func (n *AggregateFuncExpr) updateSum(sc *variable.StatementContext) error {
ctx := n.GetContext()
value := *n.Args[0].GetDatum()
if value.IsNull() {
return nil
}
if n.Distinct {
d, err := ctx.DistinctChecker.Check([]interface{}{value.GetValue()})
if err != nil {
return errors.Trace(err)
}
if !d {
return nil
}
}
var err error
ctx.Value, err = types.CalculateSum(sc, ctx.Value, value)
if err != nil {
return errors.Trace(err)
}
ctx.Count++
return nil
}
func (n *AggregateFuncExpr) updateGroupConcat(sc *variable.StatementContext) error {
ctx := n.GetContext()
vals := make([]interface{}, 0, len(n.Args))
for _, a := range n.Args {
value := a.GetValue()
if value == nil {
return nil
}
vals = append(vals, value)
}
if n.Distinct {
d, err := ctx.DistinctChecker.Check(vals)
if err != nil {
return errors.Trace(err)
}
if !d {
return nil
}
}
if ctx.Buffer == nil {
ctx.Buffer = &bytes.Buffer{}
} else {
// now use comma separator
ctx.Buffer.WriteString(",")
}
for _, val := range vals {
ctx.Buffer.WriteString(fmt.Sprintf("%v", val))
}
// TODO: if total length is greater than global var group_concat_max_len, truncate it.
return nil
}
// AggregateFuncExtractor visits Expr tree.
// It converts ColunmNameExpr to AggregateFuncExpr and collects AggregateFuncExpr.
type AggregateFuncExtractor struct {
inAggregateFuncExpr bool
// AggFuncs is the collected AggregateFuncExprs.
AggFuncs []*AggregateFuncExpr
extracting bool
}
// Enter implements Visitor interface.
func (a *AggregateFuncExtractor) Enter(n Node) (node Node, skipChildren bool) {
switch n.(type) {
case *AggregateFuncExpr:
a.inAggregateFuncExpr = true
case *SelectStmt, *InsertStmt, *DeleteStmt, *UpdateStmt:
// Enter a new context, skip it.
// For example: select sum(c) + c + exists(select c from t) from t;
if a.extracting {
return n, true
}
}
a.extracting = true
return n, false
}
// Leave implements Visitor interface.
func (a *AggregateFuncExtractor) Leave(n Node) (node Node, ok bool) {
switch v := n.(type) {
case *AggregateFuncExpr:
a.inAggregateFuncExpr = false
a.AggFuncs = append(a.AggFuncs, v)
case *ColumnNameExpr:
// compose new AggregateFuncExpr
if !a.inAggregateFuncExpr {
// For example: select sum(c) + c from t;
// The c in sum() should be evaluated for each row.
// The c after plus should be evaluated only once.
agg := &AggregateFuncExpr{
F: AggFuncFirstRow,
Args: []ExprNode{v},
}
agg.SetFlag((v.GetFlag() | FlagHasAggregateFunc))
agg.SetType(v.GetType())
a.AggFuncs = append(a.AggFuncs, agg)
return agg, true
}
}
return n, true
}
// AggEvaluateContext is used to store intermediate result when calculating aggregate functions.
type AggEvaluateContext struct {
DistinctChecker *distinct.Checker
Count int64
Value types.Datum
Buffer *bytes.Buffer // Buffer is used for group_concat.
GotFirstRow bool // It will check if the agg has met the first row key.
}