Files
tidb/plan/typeinferer.go

608 lines
19 KiB
Go

// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package plan
import (
"strings"
"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util/charset"
"github.com/pingcap/tidb/util/types"
)
// InferType walks node with a fresh typeInferrer so that the expression nodes
// it contains get their result type filled in. It returns the error the
// inferrer recorded during the walk, or nil if none was recorded.
func InferType(sc *variable.StatementContext, node ast.Node) error {
	// TODO: get the default charset from ctx
	inferrer := &typeInferrer{
		sc:             sc,
		defaultCharset: "utf8",
	}
	node.Accept(inferrer)
	return inferrer.err
}
// typeInferrer carries the state used by InferType while it walks an AST:
// its Enter/Leave methods are the callbacks handed to node.Accept.
type typeInferrer struct {
	// sc is the statement context passed to datum conversions and comparisons.
	sc *variable.StatementContext
	// err is the error recorded during the walk; later assignments overwrite
	// earlier ones, and InferType returns whatever is left at the end.
	err error
	// defaultCharset is the charset assigned to string-valued results that do
	// not carry one of their own. InferType currently hard-codes it to "utf8".
	defaultCharset string
}
// Enter is the pre-visit callback. It never rewrites the node and never asks
// for children to be skipped; all type inference happens in Leave.
func (v *typeInferrer) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
	out = in
	return out, false
}
// Leave is the post-visit callback: it dispatches on the concrete node type
// and stores the inferred result type on the node. Predicate-like nodes
// (BETWEEN, IN, EXISTS, IS NULL, ...) are typed as a binary-charset Longlong.
// The node is always returned unchanged together with ok == true; failures are
// recorded in v.err rather than reported through the return values.
func (v *typeInferrer) Leave(in ast.Node) (out ast.Node, ok bool) {
	// setBinaryLonglong gives e a Longlong type with binary charset/collation.
	setBinaryLonglong := func(e interface {
		SetType(tp *types.FieldType)
	}) {
		ft := types.NewFieldType(mysql.TypeLonglong)
		ft.Charset = charset.CharsetBin
		ft.Collate = charset.CollationBin
		e.SetType(ft)
	}

	switch x := in.(type) {
	case *ast.AggregateFuncExpr:
		v.aggregateFunc(x)
	case *ast.BetweenExpr:
		setBinaryLonglong(x)
	case *ast.BinaryOperationExpr:
		v.binaryOperation(x)
	case *ast.CaseExpr:
		v.handleCaseExpr(x)
	case *ast.ColumnNameExpr:
		// A column reference takes the type of the column it refers to.
		x.SetType(&x.Refer.Column.FieldType)
	case *ast.CompareSubqueryExpr:
		setBinaryLonglong(x)
	case *ast.ExistsSubqueryExpr:
		setBinaryLonglong(x)
	case *ast.FuncCallExpr:
		v.handleFuncCallExpr(x)
	case *ast.FuncCastExpr:
		// Work on a private copy of the cast target so x.Tp is not modified,
		// and fill in the default charset when the target names none.
		castTp := *x.Tp
		if len(castTp.Charset) == 0 {
			castTp.Charset, castTp.Collate = types.DefaultCharsetForType(castTp.Tp)
		}
		x.SetType(&castTp)
	case *ast.IsNullExpr:
		setBinaryLonglong(x)
	case *ast.IsTruthExpr:
		setBinaryLonglong(x)
	case *ast.ParamMarkerExpr:
		types.DefaultTypeForValue(x.GetValue(), x.GetType())
	case *ast.ParenthesesExpr:
		// Parentheses are transparent: reuse the inner expression's type.
		x.SetType(x.Expr.GetType())
	case *ast.PatternInExpr:
		setBinaryLonglong(x)
		v.convertValueToColumnTypeIfNeeded(x)
	case *ast.PatternLikeExpr:
		v.handleLikeExpr(x)
	case *ast.PatternRegexpExpr:
		v.handleRegexpExpr(x)
	case *ast.SelectStmt:
		v.selectStmt(x)
	case *ast.UnaryOperationExpr:
		v.unaryOperation(x)
	case *ast.ValueExpr:
		v.handleValueExpr(x)
	case *ast.ValuesExpr:
		v.handleValuesExpr(x)
	case *ast.VariableExpr:
		// Variables are typed as strings in the default charset.
		ft := types.NewFieldType(mysql.TypeVarString)
		ft.Charset = v.defaultCharset
		collation, err := charset.GetDefaultCollation(v.defaultCharset)
		if err != nil {
			v.err = err
		}
		ft.Collate = collation
		x.SetType(ft)
		// TODO: handle all expression types.
	}
	return in, true
}
// selectStmt copies the inferred expression type into every result field of x
// whose column has ID 0 and whose expression has a type.
func (v *typeInferrer) selectStmt(x *ast.SelectStmt) {
	for _, field := range x.GetResultFields() {
		// column ID is 0 means it is not a real column from table, but a temporary column,
		// so its type is not pre-defined, we need to set it.
		if field.Column.ID != 0 {
			continue
		}
		exprTp := field.Expr.GetType()
		if exprTp == nil {
			continue
		}
		field.Column.FieldType = *exprTp
	}
}
// aggregateFunc sets the result type of an aggregate call, keyed on the
// lower-cased function name:
//   - COUNT: binary Longlong with Flen 21
//   - MAX/MIN: the type of the first argument
//   - SUM/AVG: binary NewDecimal carrying the first argument's Decimal
//   - GROUP_CONCAT: VarString in the default charset
//
// Any other name leaves x untouched.
func (v *typeInferrer) aggregateFunc(x *ast.AggregateFuncExpr) {
	switch strings.ToLower(x.F) {
	case ast.AggFuncCount:
		countTp := types.NewFieldType(mysql.TypeLonglong)
		countTp.Charset, countTp.Collate = charset.CharsetBin, charset.CollationBin
		countTp.Flen = 21
		x.SetType(countTp)
	case ast.AggFuncMax, ast.AggFuncMin:
		x.SetType(x.Args[0].GetType())
	case ast.AggFuncSum, ast.AggFuncAvg:
		decTp := types.NewFieldType(mysql.TypeNewDecimal)
		decTp.Charset, decTp.Collate = charset.CharsetBin, charset.CollationBin
		decTp.Decimal = x.Args[0].GetType().Decimal
		x.SetType(decTp)
	case ast.AggFuncGroupConcat:
		strTp := types.NewFieldType(mysql.TypeVarString)
		strTp.Charset = v.defaultCharset
		collation, err := charset.GetDefaultCollation(v.defaultCharset)
		if err != nil {
			v.err = err
		}
		strTp.Collate = collation
		x.SetType(strTp)
	}
}
// binaryOperation infers the result type of a binary operator expression.
// Logical, comparison and integer-division operators yield Longlong; bitwise
// and shift operators yield unsigned Longlong; +, -, *, % and / derive their
// type from both operands via mergeArithType (with / promoting a Longlong
// result to NewDecimal). The result always gets binary charset and collation.
func (v *typeInferrer) binaryOperation(x *ast.BinaryOperationExpr) {
	switch x.Op {
	case opcode.AndAnd, opcode.OrOr, opcode.LogicXor,
		opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ,
		opcode.IntDiv:
		x.Type.Init(mysql.TypeLonglong)
	case opcode.RightShift, opcode.LeftShift, opcode.And, opcode.Or, opcode.Xor:
		x.Type.Init(mysql.TypeLonglong)
		x.Type.Flag |= mysql.UnsignedFlag
	case opcode.Plus, opcode.Minus, opcode.Mul, opcode.Mod:
		lhs, rhs := x.L.GetType(), x.R.GetType()
		if lhs != nil && rhs != nil {
			x.Type.Init(mergeArithType(lhs, rhs))
			// If both operands are unsigned, result is unsigned.
			x.Type.Flag |= lhs.Flag & rhs.Flag & mysql.UnsignedFlag
		}
	case opcode.Div:
		lhs, rhs := x.L.GetType(), x.R.GetType()
		if lhs != nil && rhs != nil {
			resTp := mergeArithType(lhs, rhs)
			if resTp == mysql.TypeLonglong {
				resTp = mysql.TypeNewDecimal
			}
			x.Type.Init(resTp)
		}
	}
	x.Type.Charset, x.Type.Collate = charset.CharsetBin, charset.CollationBin
}
// toArithType returns the type code ft should be treated as in arithmetic.
// Types for which types.IsTypeFractionable reports true become NewDecimal when
// ft.Decimal > 0 and Longlong otherwise; every other type is returned as is.
func toArithType(ft *types.FieldType) (tp byte) {
	if !types.IsTypeFractionable(ft.Tp) {
		return ft.Tp
	}
	if ft.Decimal > 0 {
		return mysql.TypeNewDecimal
	}
	return mysql.TypeLonglong
}
// mergeArithType picks the result type code for an arithmetic operation on
// fta and ftb, after normalising both through toArithType: Double if either
// side is a string or floating-point type, else NewDecimal if either side is
// NewDecimal, else Longlong.
func mergeArithType(fta, ftb *types.FieldType) byte {
	lhs, rhs := toArithType(fta), toArithType(ftb)
	for _, tp := range [...]byte{lhs, rhs} {
		switch tp {
		case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat:
			return mysql.TypeDouble
		}
	}
	switch {
	case lhs == mysql.TypeNewDecimal, rhs == mysql.TypeNewDecimal:
		return mysql.TypeNewDecimal
	}
	return mysql.TypeLonglong
}
// mergeCmpType builds a new field type describing the common type of fta and
// ftb. The result is UTF8 with the default UTF8 collation only when both
// inputs are UTF8; otherwise it carries the binary flag. The type code is
// chosen by the first matching rule below, checked in order.
func mergeCmpType(fta, ftb *types.FieldType) (ft *types.FieldType) {
	ft = new(types.FieldType)
	bothUTF8 := fta.Charset == charset.CharsetUTF8 && ftb.Charset == charset.CharsetUTF8
	if !bothUTF8 {
		ft.Flag |= mysql.BinaryFlag
	} else {
		ft.Charset = charset.CharsetUTF8
		ft.Collate = mysql.UTF8DefaultCollation
	}

	// either reports whether at least one side has exactly the type code tp.
	either := func(tp byte) bool {
		return fta.Tp == tp || ftb.Tp == tp
	}
	aIsTime, bIsTime := types.IsTypeFractionable(fta.Tp), types.IsTypeFractionable(ftb.Tp)

	switch {
	case types.IsTypeBlob(fta.Tp) || types.IsTypeBlob(ftb.Tp):
		ft.Tp = mysql.TypeBlob
	case types.IsTypeVarchar(fta.Tp) || types.IsTypeVarchar(ftb.Tp):
		ft.Tp = mysql.TypeVarString
	case types.IsTypeChar(fta.Tp) || types.IsTypeChar(ftb.Tp):
		ft.Tp = mysql.TypeString
	case aIsTime && bIsTime:
		ft.Tp = mysql.TypeDatetime
	case aIsTime || bIsTime:
		ft.Tp = mysql.TypeVarString
	case either(mysql.TypeEnum) || either(mysql.TypeSet):
		ft.Tp = mysql.TypeString
	case either(mysql.TypeDouble):
		ft.Tp = mysql.TypeDouble
	case either(mysql.TypeFloat):
		ft.Tp = mysql.TypeFloat
	case either(mysql.TypeNewDecimal):
		ft.Tp = mysql.TypeNewDecimal
	case either(mysql.TypeLonglong):
		ft.Tp = mysql.TypeLonglong
	default:
		ft.Tp = mysql.TypeLong
	}
	return ft
}
// unaryOperation infers the result type of a unary operator expression:
//   - NOT yields Longlong;
//   - bitwise negation yields unsigned Longlong;
//   - unary plus copies the operand's type;
//   - unary minus yields Longlong, upgraded to Double for string, floating
//     point and time operands, or to NewDecimal for NewDecimal operands.
//
// The result always gets binary charset and collation. An operand whose type
// is nil is tolerated for both plus and minus: the operand type is simply not
// consulted.
func (v *typeInferrer) unaryOperation(x *ast.UnaryOperationExpr) {
	switch x.Op {
	case opcode.Not:
		x.Type.Init(mysql.TypeLonglong)
	case opcode.BitNeg:
		x.Type.Init(mysql.TypeLonglong)
		x.Type.Flag |= mysql.UnsignedFlag
	case opcode.Plus:
		// Guard against a nil operand type the same way the Minus branch
		// does, instead of dereferencing it unconditionally.
		if operandTp := x.V.GetType(); operandTp != nil {
			x.Type = *operandTp
		}
	case opcode.Minus:
		x.Type.Init(mysql.TypeLonglong)
		if operandTp := x.V.GetType(); operandTp != nil {
			switch operandTp.Tp {
			case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat, mysql.TypeDatetime, mysql.TypeDuration, mysql.TypeTimestamp:
				x.Type.Tp = mysql.TypeDouble
			case mysql.TypeNewDecimal:
				x.Type.Tp = mysql.TypeNewDecimal
			}
		}
	}
	x.Type.Charset = charset.CharsetBin
	x.Type.Collate = charset.CollationBin
}
// handleValueExpr lets types.DefaultTypeForValue fill in x's type in place
// based on the value x currently holds.
func (v *typeInferrer) handleValueExpr(x *ast.ValueExpr) {
	value, tp := x.GetValue(), x.GetType()
	types.DefaultTypeForValue(value, tp)
}
// handleValuesExpr gives a VALUES(col) expression the type of its column.
func (v *typeInferrer) handleValuesExpr(x *ast.ValuesExpr) {
	columnTp := x.Column.GetType()
	x.SetType(columnTp)
}
// getFsp returns the value of x's only argument converted to an int, which
// callers store as the result type's Decimal. It returns 0 unless x has
// exactly one argument. A conversion error is recorded in v.err and the
// (possibly partial) converted value is still returned.
func (v *typeInferrer) getFsp(x *ast.FuncCallExpr) int {
	if len(x.Args) != 1 {
		return 0
	}
	fsp, err := x.Args[0].GetDatum().ToInt64(v.sc)
	if err != nil {
		v.err = err
	}
	return int(fsp)
}
// handleFuncCallExpr infers the result type of a scalar function call from the
// lower-cased function name (x.FnName.L) and, for some functions, from the
// argument types. Unknown functions get TypeUnspecified.
//
// After the per-function switch, a result type with an empty Charset is given
// the charset selected by the switch (binary unless a case picked the default
// charset) and the matching collation; a failure to look up that collation is
// recorded in v.err.
//
// NOTE(review): for cases that reuse an argument's *FieldType directly
// ("abs", "ifnull", "nullif", "if", any_value), the charset fix-up at the end
// writes through that pointer and so may modify the argument's own type.
func (v *typeInferrer) handleFuncCallExpr(x *ast.FuncCallExpr) {
	var (
		tp  *types.FieldType
		chs = charset.CharsetBin
	)
	switch x.FnName.L {
	case "abs", "ifnull", "nullif":
		tp = x.Args[0].GetType()
		// TODO: We should cover all types.
		if x.FnName.L == "abs" && tp.Tp == mysql.TypeDatetime {
			tp = types.NewFieldType(mysql.TypeDouble)
		}
	case "round":
		t := x.Args[0].GetType().Tp
		switch t {
		case mysql.TypeBit, mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLonglong:
			tp = types.NewFieldType(mysql.TypeLonglong)
		case mysql.TypeNewDecimal:
			tp = types.NewFieldType(mysql.TypeNewDecimal)
		default:
			tp = types.NewFieldType(mysql.TypeDouble)
		}
	case "greatest", "least":
		for _, arg := range x.Args {
			// Surface inference failures instead of silently dropping them.
			if err := InferType(v.sc, arg); err != nil {
				v.err = errors.Trace(err)
			}
		}
		if len(x.Args) > 0 {
			tp = x.Args[0].GetType()
			for i := 1; i < len(x.Args); i++ {
				tp = mergeCmpType(tp, x.Args[i].GetType())
			}
		}
	case "interval":
		tp = types.NewFieldType(mysql.TypeLonglong)
	case "ceil", "ceiling", "floor":
		t := x.Args[0].GetType().Tp
		if t == mysql.TypeNull || t == mysql.TypeFloat || t == mysql.TypeDouble || t == mysql.TypeVarchar ||
			t == mysql.TypeTinyBlob || t == mysql.TypeMediumBlob || t == mysql.TypeLongBlob ||
			t == mysql.TypeBlob || t == mysql.TypeVarString || t == mysql.TypeString {
			tp = types.NewFieldType(mysql.TypeDouble)
		} else {
			tp = types.NewFieldType(mysql.TypeLonglong)
		}
	case "ln", "log", "log2", "log10", "sqrt", "pi", "exp", "degrees":
		tp = types.NewFieldType(mysql.TypeDouble)
	case "sin":
		tp = types.NewFieldType(mysql.TypeDouble)
	case "acos", "asin", "atan":
		tp = types.NewFieldType(mysql.TypeDouble)
	case "pow", "power", "rand":
		tp = types.NewFieldType(mysql.TypeDouble)
	case "radians":
		tp = types.NewFieldType(mysql.TypeDouble)
	case "curdate", "current_date", "date", "from_days":
		tp = types.NewFieldType(mysql.TypeDate)
	case "curtime", "current_time", "timediff", "maketime":
		tp = types.NewFieldType(mysql.TypeDuration)
		tp.Decimal = v.getFsp(x)
	case "date_add", "date_sub", "adddate", "subdate", "timestamp":
		tp = types.NewFieldType(mysql.TypeDatetime)
	case "microsecond", "second", "minute", "hour", "day", "week", "month", "year",
		"dayofweek", "dayofmonth", "dayofyear", "weekday", "weekofyear", "yearweek", "datediff",
		"found_rows", "length", "extract", "locate", "unix_timestamp", "quarter", "is_ipv4":
		tp = types.NewFieldType(mysql.TypeLonglong)
	case "now", "sysdate", "current_timestamp", "utc_timestamp":
		tp = types.NewFieldType(mysql.TypeDatetime)
		tp.Decimal = v.getFsp(x)
	case "from_unixtime":
		// One argument yields a datetime; otherwise the result is typed as a string.
		if len(x.Args) == 1 {
			tp = types.NewFieldType(mysql.TypeDatetime)
		} else {
			tp = types.NewFieldType(mysql.TypeVarString)
			chs = v.defaultCharset
		}
	case "str_to_date":
		tp = types.NewFieldType(mysql.TypeDatetime)
	case "dayname", "version", "database", "user", "current_user", "schema",
		"concat", "concat_ws", "left", "lcase", "lower", "repeat",
		"replace", "ucase", "upper", "convert", "substring", "elt",
		"substring_index", "trim", "ltrim", "rtrim", "reverse", "hex", "unhex",
		"date_format", "rpad", "lpad", "char_func", "conv", "make_set", "oct", "uuid",
		"insert_func", "bin":
		tp = types.NewFieldType(mysql.TypeVarString)
		chs = v.defaultCharset
	case "strcmp", "isnull", "bit_length", "char_length", "character_length", "crc32", "timestampdiff",
		"sign", "is_ipv6", "ord", "instr":
		tp = types.NewFieldType(mysql.TypeLonglong)
	case "connection_id":
		tp = types.NewFieldType(mysql.TypeLonglong)
		tp.Flag |= mysql.UnsignedFlag
	case "find_in_set", ast.Field:
		tp = types.NewFieldType(mysql.TypeLonglong)
	case "if":
		// TODO: fix this
		// See https://dev.mysql.com/doc/refman/5.5/en/control-flow-functions.html#function_if
		// The default return type of IF() (which may matter when it is stored into a temporary table) is calculated as follows.
		// Expression Return Value
		// expr2 or expr3 returns a string string
		// expr2 or expr3 returns a floating-point value floating-point
		// expr2 or expr3 returns an integer integer
		tp = x.Args[1].GetType()
	case "get_lock", "release_lock":
		tp = types.NewFieldType(mysql.TypeLonglong)
	case ast.AesEncrypt, ast.AesDecrypt, ast.SHA2, ast.InetNtoa:
		tp = types.NewFieldType(mysql.TypeVarString)
		chs = v.defaultCharset
	case ast.MD5:
		tp = types.NewFieldType(mysql.TypeVarString)
		chs = v.defaultCharset
		tp.Flen = 32
	case ast.SHA, ast.SHA1:
		tp = types.NewFieldType(mysql.TypeVarString)
		chs = v.defaultCharset
		tp.Flen = 40
	case ast.RandomBytes:
		// chs is left as binary here, unlike the other string-returning cases.
		tp = types.NewFieldType(mysql.TypeVarString)
	case ast.Coalesce:
		tp = aggFieldType(x.Args)
		if tp.Tp == mysql.TypeVarchar {
			tp.Tp = mysql.TypeVarString
		}
		classType := aggTypeClass(x.Args, &tp.Flag)
		if classType == types.ClassString && !mysql.HasBinaryFlag(tp.Flag) {
			tp.Charset, tp.Collate = types.DefaultCharsetForType(tp.Tp)
		}
	case ast.AnyValue:
		tp = x.Args[0].GetType()
	default:
		tp = types.NewFieldType(mysql.TypeUnspecified)
	}
	// greatest()/least() with an empty argument list leave tp unset; fall back
	// to an unspecified type rather than dereferencing a nil pointer below.
	if tp == nil {
		tp = types.NewFieldType(mysql.TypeUnspecified)
	}
	// If charset is unspecified.
	if len(tp.Charset) == 0 {
		tp.Charset = chs
		cln := charset.CollationBin
		if chs != charset.CharsetBin {
			var err error
			cln, err = charset.GetDefaultCollation(chs)
			if err != nil {
				v.err = err
			}
		}
		tp.Collate = cln
	}
	x.SetType(tp)
}
// handleCaseExpr infers the type of a CASE expression as the aggregate of all
// of its result expressions: every WHEN clause's result plus the ELSE clause
// when present. Varchar is normalised to VarString. A non-binary string result
// gets the default charset for its type; anything else gets binary charset and
// collation.
func (v *typeInferrer) handleCaseExpr(x *ast.CaseExpr) {
	results := make([]ast.ExprNode, 0, len(x.WhenClauses)+1)
	for i := range x.WhenClauses {
		results = append(results, x.WhenClauses[i].Result)
	}
	if x.ElseClause != nil {
		results = append(results, x.ElseClause)
	}

	resultTp := aggFieldType(results)
	if resultTp.Tp == mysql.TypeVarchar {
		resultTp.Tp = mysql.TypeVarString
	}
	// aggTypeClass also updates resultTp.Flag, so it must run before the
	// binary-flag test below.
	class := aggTypeClass(results, &resultTp.Flag)
	isTextString := class == types.ClassString && !mysql.HasBinaryFlag(resultTp.Flag)
	if !isTextString {
		resultTp.Charset = charset.CharsetBin
		resultTp.Collate = charset.CollationBin
	} else {
		resultTp.Charset, resultTp.Collate = types.DefaultCharsetForType(resultTp.Tp)
	}
	x.SetType(resultTp)
}
// handleLikeExpr types a LIKE expression as a binary-charset Longlong and
// passes both the target expression and the pattern through addCastToString,
// which casts them to string when their charset is not UTF8.
func (v *typeInferrer) handleLikeExpr(x *ast.PatternLikeExpr) {
	resultTp := types.NewFieldType(mysql.TypeLonglong)
	resultTp.Charset, resultTp.Collate = charset.CharsetBin, charset.CollationBin
	x.SetType(resultTp)

	x.Expr = v.addCastToString(x.Expr)
	x.Pattern = v.addCastToString(x.Pattern)
}
// handleRegexpExpr types a REGEXP expression as a binary-charset Longlong and
// passes both the target expression and the pattern through addCastToString,
// which casts them to string when their charset is not UTF8.
func (v *typeInferrer) handleRegexpExpr(x *ast.PatternRegexpExpr) {
	resultTp := types.NewFieldType(mysql.TypeLonglong)
	resultTp.Charset, resultTp.Collate = charset.CharsetBin, charset.CollationBin
	x.SetType(resultTp)

	x.Expr = v.addCastToString(x.Expr)
	x.Pattern = v.addCastToString(x.Pattern)
}
// addCastToString returns expr unchanged when its charset is already UTF8.
// Otherwise it coerces expr to a TypeString with the default charset for that
// type: a literal (*ast.ValueExpr) has its datum converted in place, with a
// conversion error recorded in v.err, while any other expression is wrapped
// in a cast function node. The returned expression carries the string type.
func (v *typeInferrer) addCastToString(expr ast.ExprNode) ast.ExprNode {
	if mysql.IsUTF8Charset(expr.GetType().Charset) {
		return expr
	}

	castTp := types.NewFieldType(mysql.TypeString)
	castTp.Charset, castTp.Collate = types.DefaultCharsetForType(mysql.TypeString)

	switch val := expr.(type) {
	case *ast.ValueExpr:
		converted, err := val.Datum.ConvertTo(v.sc, castTp)
		if err != nil {
			v.err = errors.Trace(err)
		}
		expr.SetDatum(converted)
	default:
		expr = &ast.FuncCastExpr{
			Expr:         expr,
			Tp:           castTp,
			FunctionType: ast.CastFunction,
		}
	}
	expr.SetType(castTp)
	return expr
}
// convertValueToColumnTypeIfNeeded handles `col IN (v1, v2, ...)` where the
// left-hand side is a resolved column reference: each literal in the list is
// converted to the column's field type, and the converted datum replaces the
// literal only when it still compares equal to the original value.
//
// Conversion and comparison errors are best-effort: the last one seen is
// logged and then dropped. They are tracked in a local variable so that an
// error recorded in v.err by some other, unrelated node is neither misreported
// as an IN-list error nor cleared by this function.
func (v *typeInferrer) convertValueToColumnTypeIfNeeded(x *ast.PatternInExpr) {
	cn, ok := x.Expr.(*ast.ColumnNameExpr)
	if !ok || cn.Refer == nil {
		return
	}
	ft := cn.Refer.Column.FieldType
	var convErr error
	for _, expr := range x.List {
		valueExpr, ok := expr.(*ast.ValueExpr)
		if !ok {
			continue
		}
		// The converted datum is still compared below even when ConvertTo
		// reports an error, matching the long-standing best-effort behavior.
		newDatum, err := valueExpr.Datum.ConvertTo(v.sc, &ft)
		if err != nil {
			convErr = errors.Trace(err)
		}
		cmp, err := newDatum.CompareDatum(v.sc, valueExpr.Datum)
		if err != nil {
			convErr = errors.Trace(err)
		}
		if cmp != 0 {
			// The value will never match the column, do not set newDatum.
			continue
		}
		valueExpr.SetDatum(newDatum)
	}
	if convErr != nil {
		// TODO: Errors should be handled differently according to query context.
		log.Errorf("inferor type for pattern in error %v", convErr)
	}
}
// aggFieldType folds the types of args into one freshly allocated field type.
// While the accumulated type code is still TypeUnspecified the whole field
// type of the current argument is copied; after that only the type code is
// updated, by merging it with each further argument's type code via
// types.MergeFieldType. With no arguments the zero FieldType is returned.
func aggFieldType(args []ast.ExprNode) *types.FieldType {
	merged := new(types.FieldType)
	for _, arg := range args {
		argTp := arg.GetType()
		if merged.Tp != mysql.TypeUnspecified {
			merged.Tp = types.MergeFieldType(merged.Tp, argTp.Tp)
			continue
		}
		*merged = *argTp
	}
	return merged
}
// aggTypeClass computes the common type class of args, ignoring arguments of
// TypeNull, and updates *flag to match: the unsigned bit is set only if every
// considered argument is unsigned, and the binary bit is set if the resulting
// class is not a string class or if any string argument was itself binary.
// With no considered arguments the class is ClassString.
func aggTypeClass(args []ast.ExprNode, flag *uint) types.TypeClass {
	resultClass := types.ClassString
	allUnsigned, seenAny, seenBinaryString := false, false, false

	for _, arg := range args {
		ft := arg.GetType()
		if ft.Tp == mysql.TypeNull {
			continue
		}
		class := ft.ToClass()
		argUnsigned := mysql.HasUnsignedFlag(ft.Flag)
		if class == types.ClassString && mysql.HasBinaryFlag(ft.Flag) {
			seenBinaryString = true
		}
		if seenAny {
			resultClass = mergeTypeClass(resultClass, class, allUnsigned, argUnsigned)
			allUnsigned = allUnsigned && argUnsigned
			continue
		}
		// The first non-null argument seeds the accumulator.
		seenAny = true
		resultClass = class
		allUnsigned = argUnsigned
	}

	setTypeFlag(flag, uint(mysql.UnsignedFlag), allUnsigned)
	setTypeFlag(flag, uint(mysql.BinaryFlag), resultClass != types.ClassString || seenBinaryString)
	return resultClass
}
// setTypeFlag sets the bits of flagItem in *flag when on is true and clears
// them when on is false; all other bits of *flag are left untouched.
func setTypeFlag(flag *uint, flagItem uint, on bool) {
	if !on {
		*flag &^= flagItem
		return
	}
	*flag |= flagItem
}
// mergeTypeClass combines two type classes. String wins over everything, then
// Real; the result is Decimal if either side is Decimal or if the two sides
// differ in signedness; otherwise it is Int.
func mergeTypeClass(a, b types.TypeClass, aUnsigned, bUnsigned bool) types.TypeClass {
	switch {
	case a == types.ClassString, b == types.ClassString:
		return types.ClassString
	case a == types.ClassReal, b == types.ClassReal:
		return types.ClassReal
	case a == types.ClassDecimal, b == types.ClassDecimal, aUnsigned != bUnsigned:
		return types.ClassDecimal
	default:
		return types.ClassInt
	}
}