Merge pull request #968 from pingcap/zxylvlp/change-datum-in-ast

evaluator: change Get/SetValue to Get/SetDatum.
This commit is contained in:
goroutine
2016-03-12 21:44:42 +08:00
6 changed files with 312 additions and 145 deletions

View File

@ -64,6 +64,10 @@ type ExprNode interface {
SetValue(val interface{})
// GetValue gets value of the expression.
GetValue() interface{}
// SetDatum sets datum to the expression.
SetDatum(datum types.Datum)
// GetDatum gets datum of the expression.
GetDatum() *types.Datum
// SetFlag sets flag to the expression.
// Flag indicates whether the expression contains
// parameter marker, reference, aggregate function...

View File

@ -67,6 +67,16 @@ type exprNode struct {
flag uint64
}
// SetDatum implements Expression interface.
func (en *exprNode) SetDatum(datum types.Datum) {
en.Datum = datum
}
// GetDatum implements Expression interface.
func (en *exprNode) GetDatum() *types.Datum {
return &en.Datum
}
// SetType implements Expression interface.
func (en *exprNode) SetType(tp *types.FieldType) {
en.Type = tp

View File

@ -178,38 +178,39 @@ func (e *Evaluator) between(v *ast.BetweenExpr) bool {
if e.err != nil {
return false
}
v.SetValue(ret.GetValue())
v.SetDatum(*ret.GetDatum())
return true
}
func (e *Evaluator) caseExpr(v *ast.CaseExpr) bool {
var target interface{} = true
tmp := types.NewDatum(boolToInt64(true))
target := &tmp
if v.Value != nil {
target = v.Value.GetValue()
target = v.Value.GetDatum()
}
if target != nil {
if target.Kind() != types.KindNull {
for _, val := range v.WhenClauses {
cmp, err := types.Compare(target, val.Expr.GetValue())
cmp, err := target.CompareDatum(*val.Expr.GetDatum())
if err != nil {
e.err = errors.Trace(err)
return false
}
if cmp == 0 {
v.SetValue(val.Result.GetValue())
v.SetDatum(*val.Result.GetDatum())
return true
}
}
}
if v.ElseClause != nil {
v.SetValue(v.ElseClause.GetValue())
v.SetDatum(*v.ElseClause.GetDatum())
} else {
v.SetValue(nil)
v.SetNull()
}
return true
}
func (e *Evaluator) columnName(v *ast.ColumnNameExpr) bool {
v.SetValue(v.Refer.Expr.GetValue())
v.SetDatum(*v.Refer.Expr.GetDatum())
return true
}
@ -218,11 +219,12 @@ func (e *Evaluator) defaultExpr(v *ast.DefaultExpr) bool {
}
func (e *Evaluator) compareSubquery(cs *ast.CompareSubqueryExpr) bool {
lv := cs.L.GetValue()
if lv == nil {
cs.SetValue(nil)
lvDatum := cs.L.GetDatum()
if lvDatum.Kind() == types.KindNull {
cs.SetNull()
return true
}
lv := lvDatum.GetValue()
x, err := e.checkResult(cs, lv, cs.R.GetValue().([]interface{}))
if err != nil {
e.err = errors.Trace(err)
@ -302,17 +304,17 @@ func (e *Evaluator) checkAnyResult(cs *ast.CompareSubqueryExpr, lv interface{},
}
func (e *Evaluator) existsSubquery(v *ast.ExistsSubqueryExpr) bool {
r := v.Sel.GetValue()
if r == nil {
v.SetValue(0)
datum := v.Sel.GetDatum()
if datum.Kind() == types.KindNull {
v.SetInt64(0)
return true
}
r := datum.GetValue()
rows, _ := r.([]interface{})
if len(rows) > 0 {
v.SetValue(1)
v.SetInt64(1)
} else {
v.SetValue(0)
v.SetInt64(0)
}
return true
}
@ -321,7 +323,7 @@ func (e *Evaluator) existsSubquery(v *ast.ExistsSubqueryExpr) bool {
// Get the value from v.SubQuery and set it to v.
func (e *Evaluator) subqueryExpr(v *ast.SubqueryExpr) bool {
if v.SubqueryExec != nil {
v.SetValue(v.SubqueryExec.GetValue())
v.SetDatum(*v.SubqueryExec.GetDatum())
}
v.Evaluated = true
return true
@ -346,7 +348,7 @@ func (e *Evaluator) subqueryExec(v ast.SubqueryExec) bool {
}
switch len(rows) {
case 0:
v.SetValue(nil)
v.GetDatum().SetNull()
case 1:
v.SetValue(rows[0])
default:
@ -389,9 +391,9 @@ func (e *Evaluator) checkInList(not bool, in interface{}, list []interface{}) in
}
func (e *Evaluator) patternIn(n *ast.PatternInExpr) bool {
lhs := n.Expr.GetValue()
if lhs == nil {
n.SetValue(nil)
lhs := n.Expr.GetDatum()
if lhs.Kind() == types.KindNull {
n.SetNull()
return true
}
if n.Sel == nil {
@ -399,7 +401,7 @@ func (e *Evaluator) patternIn(n *ast.PatternInExpr) bool {
for _, ei := range n.List {
values = append(values, ei.GetValue())
}
x := e.checkInList(n.Not, lhs, values)
x := e.checkInList(n.Not, lhs.GetValue(), values)
if e.err != nil {
return false
}
@ -410,7 +412,7 @@ func (e *Evaluator) patternIn(n *ast.PatternInExpr) bool {
sel := se.SubqueryExec
res := sel.GetValue().([]interface{})
x := e.checkInList(n.Not, lhs, res)
x := e.checkInList(n.Not, lhs.GetValue(), res)
if e.err != nil {
return false
}
@ -420,21 +422,21 @@ func (e *Evaluator) patternIn(n *ast.PatternInExpr) bool {
func (e *Evaluator) isNull(v *ast.IsNullExpr) bool {
var boolVal bool
if v.Expr.GetValue() == nil {
if v.Expr.GetDatum().Kind() == types.KindNull {
boolVal = true
}
if v.Not {
boolVal = !boolVal
}
v.SetValue(boolToInt64(boolVal))
v.SetInt64(boolToInt64(boolVal))
return true
}
func (e *Evaluator) isTruth(v *ast.IsTruthExpr) bool {
var boolVal bool
val := v.Expr.GetValue()
if val != nil {
ival, err := types.ToBool(val)
datum := v.Expr.GetDatum()
if datum.Kind() != types.KindNull {
ival, err := datum.ToBool()
if err != nil {
e.err = errors.Trace(err)
return false
@ -446,7 +448,7 @@ func (e *Evaluator) isTruth(v *ast.IsTruthExpr) bool {
if v.Not {
boolVal = !boolVal
}
v.SetValue(boolToInt64(boolVal))
v.GetDatum().SetInt64(boolToInt64(boolVal))
return true
}
@ -455,12 +457,12 @@ func (e *Evaluator) paramMarker(v *ast.ParamMarkerExpr) bool {
}
func (e *Evaluator) parentheses(v *ast.ParenthesesExpr) bool {
v.SetValue(v.Expr.GetValue())
v.SetDatum(*v.Expr.GetDatum())
return true
}
func (e *Evaluator) position(v *ast.PositionExpr) bool {
v.SetValue(v.Refer.Expr.GetValue())
v.SetDatum(*v.Refer.Expr.GetDatum())
return true
}
@ -479,135 +481,82 @@ func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool {
e.err = errors.Errorf("%v", er)
}
}()
a := u.V.GetValue()
if a == nil {
u.SetValue(nil)
aDatum := u.V.GetDatum()
if aDatum.Kind() == types.KindNull {
u.SetNull()
return true
}
switch op := u.Op; op {
case opcode.Not:
n, err := types.ToBool(a)
n, err := aDatum.ToBool()
if err != nil {
e.err = errors.Trace(err)
} else if n == 0 {
u.SetValue(int64(1))
u.SetInt64(1)
} else {
u.SetValue(int64(0))
u.SetInt64(0)
}
case opcode.BitNeg:
// for bit operation, we will use int64 first, then return uint64
n, err := types.ToInt64(a)
n, err := aDatum.ToInt64()
if err != nil {
e.err = errors.Trace(err)
return false
}
u.SetValue(uint64(^n))
u.SetUint64(uint64(^n))
case opcode.Plus:
switch x := a.(type) {
case bool:
u.SetValue(boolToInt64(x))
case float32:
u.SetValue(+x)
case float64:
u.SetValue(+x)
case int:
u.SetValue(+x)
case int8:
u.SetValue(+x)
case int16:
u.SetValue(+x)
case int32:
u.SetValue(+x)
case int64:
u.SetValue(+x)
case uint:
u.SetValue(+x)
case uint8:
u.SetValue(+x)
case uint16:
u.SetValue(+x)
case uint32:
u.SetValue(+x)
case uint64:
u.SetValue(+x)
case mysql.Duration:
u.SetValue(x)
case mysql.Time:
u.SetValue(x)
case string:
u.SetValue(x)
case mysql.Decimal:
u.SetValue(x)
case []byte:
u.SetValue(x)
case mysql.Hex:
u.SetValue(x)
case mysql.Bit:
u.SetValue(x)
case mysql.Enum:
u.SetValue(x)
case mysql.Set:
u.SetValue(x)
switch aDatum.Kind() {
case types.KindInt64,
types.KindUint64,
types.KindFloat64,
types.KindFloat32,
types.KindMysqlDuration,
types.KindMysqlTime,
types.KindString,
types.KindMysqlDecimal,
types.KindBytes,
types.KindMysqlHex,
types.KindMysqlBit,
types.KindMysqlEnum,
types.KindMysqlSet:
u.SetDatum(*aDatum)
default:
e.err = ErrInvalidOperation
return false
}
case opcode.Minus:
switch x := a.(type) {
case bool:
if x {
u.SetValue(int64(-1))
} else {
u.SetValue(int64(0))
}
case float32:
u.SetValue(-x)
case float64:
u.SetValue(-x)
case int:
u.SetValue(-x)
case int8:
u.SetValue(-x)
case int16:
u.SetValue(-x)
case int32:
u.SetValue(-x)
case int64:
u.SetValue(-x)
case uint:
u.SetValue(-int64(x))
case uint8:
u.SetValue(-int64(x))
case uint16:
u.SetValue(-int64(x))
case uint32:
u.SetValue(-int64(x))
case uint64:
// TODO: check overflow and do more test for unsigned type
u.SetValue(-int64(x))
case mysql.Duration:
u.SetValue(mysql.ZeroDecimal.Sub(x.ToNumber()))
case mysql.Time:
u.SetValue(mysql.ZeroDecimal.Sub(x.ToNumber()))
case string:
f, err := types.StrToFloat(x)
switch aDatum.Kind() {
case types.KindInt64:
u.SetInt64(-aDatum.GetInt64())
case types.KindUint64:
u.SetInt64(-int64(aDatum.GetUint64()))
case types.KindFloat64:
u.SetFloat64(-aDatum.GetFloat64())
case types.KindFloat32:
u.SetFloat32(-aDatum.GetFloat32())
case types.KindMysqlDuration:
u.SetValue(mysql.ZeroDecimal.Sub(aDatum.GetMysqlDuration().ToNumber()))
case types.KindMysqlTime:
u.SetValue(mysql.ZeroDecimal.Sub(aDatum.GetMysqlTime().ToNumber()))
case types.KindString:
f, err := types.StrToFloat(aDatum.GetString())
e.err = errors.Trace(err)
u.SetValue(-f)
case mysql.Decimal:
f, _ := x.Float64()
u.SetFloat64(-f)
case types.KindMysqlDecimal:
f, _ := aDatum.GetMysqlDecimal().Float64()
u.SetValue(mysql.NewDecimalFromFloat(-f))
case []byte:
f, err := types.StrToFloat(string(x))
case types.KindBytes:
f, err := types.StrToFloat(string(aDatum.GetBytes()))
e.err = errors.Trace(err)
u.SetValue(-f)
case mysql.Hex:
u.SetValue(-x.ToNumber())
case mysql.Bit:
u.SetValue(-x.ToNumber())
case mysql.Enum:
u.SetValue(-x.ToNumber())
case mysql.Set:
u.SetValue(-x.ToNumber())
u.SetFloat64(-f)
case types.KindMysqlHex:
u.SetFloat64(-aDatum.GetMysqlHex().ToNumber())
case types.KindMysqlBit:
u.SetFloat64(-aDatum.GetMysqlBit().ToNumber())
case types.KindMysqlEnum:
u.SetFloat64(-aDatum.GetMysqlEnum().ToNumber())
case types.KindMysqlSet:
u.SetFloat64(-aDatum.GetMysqlSet().ToNumber())
default:
e.err = ErrInvalidOperation
return false
@ -621,7 +570,7 @@ func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool {
}
func (e *Evaluator) values(v *ast.ValuesExpr) bool {
v.SetValue(v.Column.GetValue())
v.SetDatum(*v.Column.GetDatum())
return true
}
@ -632,11 +581,11 @@ func (e *Evaluator) variable(v *ast.VariableExpr) bool {
if !v.IsSystem {
// user vars
if value, ok := sessionVars.Users[name]; ok {
v.SetValue(value)
v.SetString(value)
return true
}
// select null user vars is permitted.
v.SetValue(nil)
v.SetNull()
return true
}
@ -649,16 +598,18 @@ func (e *Evaluator) variable(v *ast.VariableExpr) bool {
if !v.IsGlobal {
if value, ok := sessionVars.Systems[name]; ok {
v.SetValue(value)
v.SetString(value)
return true
}
}
value, err := globalVars.GetGlobalSysVar(e.ctx, name)
if err != nil {
e.err = errors.Trace(err)
return false
}
v.SetValue(value)
v.SetString(value)
return true
}
@ -689,7 +640,7 @@ func (e *Evaluator) funcCast(v *ast.FuncCastExpr) bool {
value := v.Expr.GetValue()
// Casting nil to any type returns null
if value == nil {
v.SetValue(nil)
v.SetNull()
return true
}
var err error
@ -719,7 +670,7 @@ func (e *Evaluator) aggregateFunc(v *ast.AggregateFuncExpr) bool {
func (e *Evaluator) evalAggCount(v *ast.AggregateFuncExpr) {
ctx := v.GetContext()
v.SetValue(ctx.Count)
v.SetInt64(ctx.Count)
}
func (e *Evaluator) evalAggSetValue(v *ast.AggregateFuncExpr) {

View File

@ -34,6 +34,16 @@ type subquery struct {
is infoschema.InfoSchema
}
// SetDatum implements Expression interface.
func (sq *subquery) SetDatum(datum types.Datum) {
sq.Datum = datum
}
// GetDatum implements Expression interface.
func (sq *subquery) GetDatum() *types.Datum {
return &sq.Datum
}
// SetFlag implements Expression interface.
func (sq *subquery) SetFlag(flag uint64) {
sq.flag = flag

View File

@ -973,6 +973,120 @@ func (d *Datum) convertToMysqlSet(target *FieldType) (Datum, error) {
return ret, nil
}
// ToBool converts to a bool.
// We will use 1 for true, and 0 for false.
func (d *Datum) ToBool() (int64, error) {
isZero := false
switch d.Kind() {
case KindInt64:
isZero = (d.GetInt64() == 0)
case KindUint64:
isZero = (d.GetUint64() == 0)
case KindFloat32:
isZero = (d.GetFloat32() == 0)
case KindFloat64:
isZero = (d.GetFloat64() == 0)
case KindString:
s := d.GetString()
if len(s) == 0 {
isZero = true
}
n, err := StrToInt(s)
if err != nil {
return 0, err
}
isZero = (n == 0)
case KindBytes:
bs := d.GetBytes()
if len(bs) == 0 {
isZero = true
} else {
n, err := StrToInt(string(bs))
if err != nil {
return 0, err
}
isZero = (n == 0)
}
case KindMysqlTime:
isZero = d.GetMysqlTime().IsZero()
case KindMysqlDuration:
isZero = (d.GetMysqlDuration().Duration == 0)
case KindMysqlDecimal:
v, _ := d.GetMysqlDecimal().Float64()
isZero = (v == 0)
case KindMysqlHex:
isZero = (d.GetMysqlHex().ToNumber() == 0)
case KindMysqlBit:
isZero = (d.GetMysqlBit().ToNumber() == 0)
case KindMysqlEnum:
isZero = (d.GetMysqlEnum().ToNumber() == 0)
case KindMysqlSet:
isZero = (d.GetMysqlSet().ToNumber() == 0)
default:
return 0, errors.Errorf("cannot convert %v(type %T) to bool", d.GetValue(), d.GetValue())
}
if isZero {
return 0, nil
}
return 1, nil
}
// ToInt64 converts to a int64.
func (d *Datum) ToInt64() (int64, error) {
tp := mysql.TypeLonglong
lowerBound := signedLowerBound[tp]
upperBound := signedUpperBound[tp]
switch d.Kind() {
case KindInt64:
return convertIntToInt(d.GetInt64(), lowerBound, upperBound, tp)
case KindUint64:
return convertUintToInt(d.GetUint64(), upperBound, tp)
case KindFloat32:
return convertFloatToInt(float64(d.GetFloat32()), lowerBound, upperBound, tp)
case KindFloat64:
return convertFloatToInt(d.GetFloat64(), lowerBound, upperBound, tp)
case KindString:
s := d.GetString()
fval, err := StrToFloat(s)
if err != nil {
return 0, errors.Trace(err)
}
return convertFloatToInt(fval, lowerBound, upperBound, tp)
case KindBytes:
s := string(d.GetBytes())
fval, err := StrToFloat(s)
if err != nil {
return 0, errors.Trace(err)
}
return convertFloatToInt(fval, lowerBound, upperBound, tp)
case KindMysqlTime:
// 2011-11-10 11:11:11.999999 -> 20111110111112
ival := d.GetMysqlTime().ToNumber().Round(0).IntPart()
return convertIntToInt(ival, lowerBound, upperBound, tp)
case KindMysqlDuration:
// 11:11:11.999999 -> 111112
ival := d.GetMysqlDuration().ToNumber().Round(0).IntPart()
return convertIntToInt(ival, lowerBound, upperBound, tp)
case KindMysqlDecimal:
fval, _ := d.GetMysqlDecimal().Float64()
return convertFloatToInt(fval, lowerBound, upperBound, tp)
case KindMysqlHex:
fval := d.GetMysqlHex().ToNumber()
return convertFloatToInt(fval, lowerBound, upperBound, tp)
case KindMysqlBit:
fval := d.GetMysqlBit().ToNumber()
return convertFloatToInt(fval, lowerBound, upperBound, tp)
case KindMysqlEnum:
fval := d.GetMysqlEnum().ToNumber()
return convertFloatToInt(fval, lowerBound, upperBound, tp)
case KindMysqlSet:
fval := d.GetMysqlSet().ToNumber()
return convertFloatToInt(fval, lowerBound, upperBound, tp)
default:
return 0, errors.Errorf("cannot convert %v(type %T) to int64", d.GetValue(), d.GetValue())
}
}
func invalidConv(d *Datum, tp byte) (Datum, error) {
return Datum{}, errors.Errorf("cannot convert %v to type %s", d, TypeStr(tp))
}

View File

@ -15,6 +15,7 @@ package types
import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/mysql"
)
var _ = Suite(&testDatumSuite{})
@ -38,3 +39,80 @@ func (ts *testDatumSuite) TestDatum(c *C) {
c.Assert(x, DeepEquals, val)
}
}
func testDatumToBool(c *C, in interface{}, res int) {
datum := NewDatum(in)
res64 := int64(res)
b, err := datum.ToBool()
c.Assert(err, IsNil)
c.Assert(b, Equals, res64)
}
func (ts *testDatumSuite) TestToBool(c *C) {
testDatumToBool(c, int(0), 0)
testDatumToBool(c, int64(0), 0)
testDatumToBool(c, uint64(0), 0)
testDatumToBool(c, float32(0), 0)
testDatumToBool(c, float64(0), 0)
testDatumToBool(c, "", 0)
testDatumToBool(c, "0", 0)
testDatumToBool(c, []byte{}, 0)
testDatumToBool(c, []byte("0"), 0)
testDatumToBool(c, mysql.Hex{Value: 0}, 0)
testDatumToBool(c, mysql.Bit{Value: 0, Width: 8}, 0)
testDatumToBool(c, mysql.Enum{Name: "a", Value: 1}, 1)
testDatumToBool(c, mysql.Set{Name: "a", Value: 1}, 1)
t, err := mysql.ParseTime("2011-11-10 11:11:11.999999", mysql.TypeTimestamp, 6)
c.Assert(err, IsNil)
testDatumToBool(c, t, 1)
td, err := mysql.ParseDuration("11:11:11.999999", 6)
c.Assert(err, IsNil)
testDatumToBool(c, td, 1)
ft := NewFieldType(mysql.TypeNewDecimal)
ft.Decimal = 5
v, err := Convert(3.1415926, ft)
c.Assert(err, IsNil)
testDatumToBool(c, v, 1)
d := NewDatum(&invalidMockType{})
_, err = d.ToBool()
c.Assert(err, NotNil)
}
func testDatumToInt64(c *C, val interface{}, expect int64) {
d := NewDatum(val)
b, err := d.ToInt64()
c.Assert(err, IsNil)
c.Assert(b, Equals, expect)
}
func (ts *testTypeConvertSuite) TestToInt64(c *C) {
testDatumToInt64(c, "0", int64(0))
testDatumToInt64(c, int(0), int64(0))
testDatumToInt64(c, int64(0), int64(0))
testDatumToInt64(c, uint64(0), int64(0))
testDatumToInt64(c, float32(3.1), int64(3))
testDatumToInt64(c, float64(3.1), int64(3))
testDatumToInt64(c, mysql.Hex{Value: 100}, int64(100))
testDatumToInt64(c, mysql.Bit{Value: 100, Width: 8}, int64(100))
testDatumToInt64(c, mysql.Enum{Name: "a", Value: 1}, int64(1))
testDatumToInt64(c, mysql.Set{Name: "a", Value: 1}, int64(1))
t, err := mysql.ParseTime("2011-11-10 11:11:11.999999", mysql.TypeTimestamp, 0)
c.Assert(err, IsNil)
testDatumToInt64(c, t, int64(20111110111112))
td, err := mysql.ParseDuration("11:11:11.999999", 6)
c.Assert(err, IsNil)
testDatumToInt64(c, td, int64(111112))
ft := NewFieldType(mysql.TypeNewDecimal)
ft.Decimal = 5
v, err := Convert(3.1415926, ft)
c.Assert(err, IsNil)
testDatumToInt64(c, v, int64(3))
_, err = ToInt64(&invalidMockType{})
c.Assert(err, NotNil)
}