Merge pull request #968 from pingcap/zxylvlp/change-datum-in-ast
evaluator: change Get/SetValue to Get/SetDatum.
This commit is contained in:
@ -64,6 +64,10 @@ type ExprNode interface {
|
||||
SetValue(val interface{})
|
||||
// GetValue gets value of the expression.
|
||||
GetValue() interface{}
|
||||
// SetDatum sets datum to the expression.
|
||||
SetDatum(datum types.Datum)
|
||||
// GetDatum gets datum of the expression.
|
||||
GetDatum() *types.Datum
|
||||
// SetFlag sets flag to the expression.
|
||||
// Flag indicates whether the expression contains
|
||||
// parameter marker, reference, aggregate function...
|
||||
|
||||
10
ast/base.go
10
ast/base.go
@ -67,6 +67,16 @@ type exprNode struct {
|
||||
flag uint64
|
||||
}
|
||||
|
||||
// SetDatum implements Expression interface.
|
||||
func (en *exprNode) SetDatum(datum types.Datum) {
|
||||
en.Datum = datum
|
||||
}
|
||||
|
||||
// GetDatum implements Expression interface.
|
||||
func (en *exprNode) GetDatum() *types.Datum {
|
||||
return &en.Datum
|
||||
}
|
||||
|
||||
// SetType implements Expression interface.
|
||||
func (en *exprNode) SetType(tp *types.FieldType) {
|
||||
en.Type = tp
|
||||
|
||||
@ -178,38 +178,39 @@ func (e *Evaluator) between(v *ast.BetweenExpr) bool {
|
||||
if e.err != nil {
|
||||
return false
|
||||
}
|
||||
v.SetValue(ret.GetValue())
|
||||
v.SetDatum(*ret.GetDatum())
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) caseExpr(v *ast.CaseExpr) bool {
|
||||
var target interface{} = true
|
||||
tmp := types.NewDatum(boolToInt64(true))
|
||||
target := &tmp
|
||||
if v.Value != nil {
|
||||
target = v.Value.GetValue()
|
||||
target = v.Value.GetDatum()
|
||||
}
|
||||
if target != nil {
|
||||
if target.Kind() != types.KindNull {
|
||||
for _, val := range v.WhenClauses {
|
||||
cmp, err := types.Compare(target, val.Expr.GetValue())
|
||||
cmp, err := target.CompareDatum(*val.Expr.GetDatum())
|
||||
if err != nil {
|
||||
e.err = errors.Trace(err)
|
||||
return false
|
||||
}
|
||||
if cmp == 0 {
|
||||
v.SetValue(val.Result.GetValue())
|
||||
v.SetDatum(*val.Result.GetDatum())
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
if v.ElseClause != nil {
|
||||
v.SetValue(v.ElseClause.GetValue())
|
||||
v.SetDatum(*v.ElseClause.GetDatum())
|
||||
} else {
|
||||
v.SetValue(nil)
|
||||
v.SetNull()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) columnName(v *ast.ColumnNameExpr) bool {
|
||||
v.SetValue(v.Refer.Expr.GetValue())
|
||||
v.SetDatum(*v.Refer.Expr.GetDatum())
|
||||
return true
|
||||
}
|
||||
|
||||
@ -218,11 +219,12 @@ func (e *Evaluator) defaultExpr(v *ast.DefaultExpr) bool {
|
||||
}
|
||||
|
||||
func (e *Evaluator) compareSubquery(cs *ast.CompareSubqueryExpr) bool {
|
||||
lv := cs.L.GetValue()
|
||||
if lv == nil {
|
||||
cs.SetValue(nil)
|
||||
lvDatum := cs.L.GetDatum()
|
||||
if lvDatum.Kind() == types.KindNull {
|
||||
cs.SetNull()
|
||||
return true
|
||||
}
|
||||
lv := lvDatum.GetValue()
|
||||
x, err := e.checkResult(cs, lv, cs.R.GetValue().([]interface{}))
|
||||
if err != nil {
|
||||
e.err = errors.Trace(err)
|
||||
@ -302,17 +304,17 @@ func (e *Evaluator) checkAnyResult(cs *ast.CompareSubqueryExpr, lv interface{},
|
||||
}
|
||||
|
||||
func (e *Evaluator) existsSubquery(v *ast.ExistsSubqueryExpr) bool {
|
||||
r := v.Sel.GetValue()
|
||||
if r == nil {
|
||||
v.SetValue(0)
|
||||
datum := v.Sel.GetDatum()
|
||||
if datum.Kind() == types.KindNull {
|
||||
v.SetInt64(0)
|
||||
return true
|
||||
}
|
||||
r := datum.GetValue()
|
||||
rows, _ := r.([]interface{})
|
||||
if len(rows) > 0 {
|
||||
v.SetValue(1)
|
||||
v.SetInt64(1)
|
||||
} else {
|
||||
v.SetValue(0)
|
||||
|
||||
v.SetInt64(0)
|
||||
}
|
||||
return true
|
||||
}
|
||||
@ -321,7 +323,7 @@ func (e *Evaluator) existsSubquery(v *ast.ExistsSubqueryExpr) bool {
|
||||
// Get the value from v.SubQuery and set it to v.
|
||||
func (e *Evaluator) subqueryExpr(v *ast.SubqueryExpr) bool {
|
||||
if v.SubqueryExec != nil {
|
||||
v.SetValue(v.SubqueryExec.GetValue())
|
||||
v.SetDatum(*v.SubqueryExec.GetDatum())
|
||||
}
|
||||
v.Evaluated = true
|
||||
return true
|
||||
@ -346,7 +348,7 @@ func (e *Evaluator) subqueryExec(v ast.SubqueryExec) bool {
|
||||
}
|
||||
switch len(rows) {
|
||||
case 0:
|
||||
v.SetValue(nil)
|
||||
v.GetDatum().SetNull()
|
||||
case 1:
|
||||
v.SetValue(rows[0])
|
||||
default:
|
||||
@ -389,9 +391,9 @@ func (e *Evaluator) checkInList(not bool, in interface{}, list []interface{}) in
|
||||
}
|
||||
|
||||
func (e *Evaluator) patternIn(n *ast.PatternInExpr) bool {
|
||||
lhs := n.Expr.GetValue()
|
||||
if lhs == nil {
|
||||
n.SetValue(nil)
|
||||
lhs := n.Expr.GetDatum()
|
||||
if lhs.Kind() == types.KindNull {
|
||||
n.SetNull()
|
||||
return true
|
||||
}
|
||||
if n.Sel == nil {
|
||||
@ -399,7 +401,7 @@ func (e *Evaluator) patternIn(n *ast.PatternInExpr) bool {
|
||||
for _, ei := range n.List {
|
||||
values = append(values, ei.GetValue())
|
||||
}
|
||||
x := e.checkInList(n.Not, lhs, values)
|
||||
x := e.checkInList(n.Not, lhs.GetValue(), values)
|
||||
if e.err != nil {
|
||||
return false
|
||||
}
|
||||
@ -410,7 +412,7 @@ func (e *Evaluator) patternIn(n *ast.PatternInExpr) bool {
|
||||
sel := se.SubqueryExec
|
||||
|
||||
res := sel.GetValue().([]interface{})
|
||||
x := e.checkInList(n.Not, lhs, res)
|
||||
x := e.checkInList(n.Not, lhs.GetValue(), res)
|
||||
if e.err != nil {
|
||||
return false
|
||||
}
|
||||
@ -420,21 +422,21 @@ func (e *Evaluator) patternIn(n *ast.PatternInExpr) bool {
|
||||
|
||||
func (e *Evaluator) isNull(v *ast.IsNullExpr) bool {
|
||||
var boolVal bool
|
||||
if v.Expr.GetValue() == nil {
|
||||
if v.Expr.GetDatum().Kind() == types.KindNull {
|
||||
boolVal = true
|
||||
}
|
||||
if v.Not {
|
||||
boolVal = !boolVal
|
||||
}
|
||||
v.SetValue(boolToInt64(boolVal))
|
||||
v.SetInt64(boolToInt64(boolVal))
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) isTruth(v *ast.IsTruthExpr) bool {
|
||||
var boolVal bool
|
||||
val := v.Expr.GetValue()
|
||||
if val != nil {
|
||||
ival, err := types.ToBool(val)
|
||||
datum := v.Expr.GetDatum()
|
||||
if datum.Kind() != types.KindNull {
|
||||
ival, err := datum.ToBool()
|
||||
if err != nil {
|
||||
e.err = errors.Trace(err)
|
||||
return false
|
||||
@ -446,7 +448,7 @@ func (e *Evaluator) isTruth(v *ast.IsTruthExpr) bool {
|
||||
if v.Not {
|
||||
boolVal = !boolVal
|
||||
}
|
||||
v.SetValue(boolToInt64(boolVal))
|
||||
v.GetDatum().SetInt64(boolToInt64(boolVal))
|
||||
return true
|
||||
}
|
||||
|
||||
@ -455,12 +457,12 @@ func (e *Evaluator) paramMarker(v *ast.ParamMarkerExpr) bool {
|
||||
}
|
||||
|
||||
func (e *Evaluator) parentheses(v *ast.ParenthesesExpr) bool {
|
||||
v.SetValue(v.Expr.GetValue())
|
||||
v.SetDatum(*v.Expr.GetDatum())
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) position(v *ast.PositionExpr) bool {
|
||||
v.SetValue(v.Refer.Expr.GetValue())
|
||||
v.SetDatum(*v.Refer.Expr.GetDatum())
|
||||
return true
|
||||
}
|
||||
|
||||
@ -479,135 +481,82 @@ func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool {
|
||||
e.err = errors.Errorf("%v", er)
|
||||
}
|
||||
}()
|
||||
a := u.V.GetValue()
|
||||
if a == nil {
|
||||
u.SetValue(nil)
|
||||
aDatum := u.V.GetDatum()
|
||||
if aDatum.Kind() == types.KindNull {
|
||||
u.SetNull()
|
||||
return true
|
||||
}
|
||||
switch op := u.Op; op {
|
||||
case opcode.Not:
|
||||
n, err := types.ToBool(a)
|
||||
n, err := aDatum.ToBool()
|
||||
if err != nil {
|
||||
e.err = errors.Trace(err)
|
||||
} else if n == 0 {
|
||||
u.SetValue(int64(1))
|
||||
u.SetInt64(1)
|
||||
} else {
|
||||
u.SetValue(int64(0))
|
||||
u.SetInt64(0)
|
||||
}
|
||||
case opcode.BitNeg:
|
||||
// for bit operation, we will use int64 first, then return uint64
|
||||
n, err := types.ToInt64(a)
|
||||
n, err := aDatum.ToInt64()
|
||||
if err != nil {
|
||||
e.err = errors.Trace(err)
|
||||
return false
|
||||
}
|
||||
u.SetValue(uint64(^n))
|
||||
u.SetUint64(uint64(^n))
|
||||
case opcode.Plus:
|
||||
switch x := a.(type) {
|
||||
case bool:
|
||||
u.SetValue(boolToInt64(x))
|
||||
case float32:
|
||||
u.SetValue(+x)
|
||||
case float64:
|
||||
u.SetValue(+x)
|
||||
case int:
|
||||
u.SetValue(+x)
|
||||
case int8:
|
||||
u.SetValue(+x)
|
||||
case int16:
|
||||
u.SetValue(+x)
|
||||
case int32:
|
||||
u.SetValue(+x)
|
||||
case int64:
|
||||
u.SetValue(+x)
|
||||
case uint:
|
||||
u.SetValue(+x)
|
||||
case uint8:
|
||||
u.SetValue(+x)
|
||||
case uint16:
|
||||
u.SetValue(+x)
|
||||
case uint32:
|
||||
u.SetValue(+x)
|
||||
case uint64:
|
||||
u.SetValue(+x)
|
||||
case mysql.Duration:
|
||||
u.SetValue(x)
|
||||
case mysql.Time:
|
||||
u.SetValue(x)
|
||||
case string:
|
||||
u.SetValue(x)
|
||||
case mysql.Decimal:
|
||||
u.SetValue(x)
|
||||
case []byte:
|
||||
u.SetValue(x)
|
||||
case mysql.Hex:
|
||||
u.SetValue(x)
|
||||
case mysql.Bit:
|
||||
u.SetValue(x)
|
||||
case mysql.Enum:
|
||||
u.SetValue(x)
|
||||
case mysql.Set:
|
||||
u.SetValue(x)
|
||||
switch aDatum.Kind() {
|
||||
case types.KindInt64,
|
||||
types.KindUint64,
|
||||
types.KindFloat64,
|
||||
types.KindFloat32,
|
||||
types.KindMysqlDuration,
|
||||
types.KindMysqlTime,
|
||||
types.KindString,
|
||||
types.KindMysqlDecimal,
|
||||
types.KindBytes,
|
||||
types.KindMysqlHex,
|
||||
types.KindMysqlBit,
|
||||
types.KindMysqlEnum,
|
||||
types.KindMysqlSet:
|
||||
u.SetDatum(*aDatum)
|
||||
default:
|
||||
e.err = ErrInvalidOperation
|
||||
return false
|
||||
}
|
||||
case opcode.Minus:
|
||||
switch x := a.(type) {
|
||||
case bool:
|
||||
if x {
|
||||
u.SetValue(int64(-1))
|
||||
} else {
|
||||
u.SetValue(int64(0))
|
||||
}
|
||||
case float32:
|
||||
u.SetValue(-x)
|
||||
case float64:
|
||||
u.SetValue(-x)
|
||||
case int:
|
||||
u.SetValue(-x)
|
||||
case int8:
|
||||
u.SetValue(-x)
|
||||
case int16:
|
||||
u.SetValue(-x)
|
||||
case int32:
|
||||
u.SetValue(-x)
|
||||
case int64:
|
||||
u.SetValue(-x)
|
||||
case uint:
|
||||
u.SetValue(-int64(x))
|
||||
case uint8:
|
||||
u.SetValue(-int64(x))
|
||||
case uint16:
|
||||
u.SetValue(-int64(x))
|
||||
case uint32:
|
||||
u.SetValue(-int64(x))
|
||||
case uint64:
|
||||
// TODO: check overflow and do more test for unsigned type
|
||||
u.SetValue(-int64(x))
|
||||
case mysql.Duration:
|
||||
u.SetValue(mysql.ZeroDecimal.Sub(x.ToNumber()))
|
||||
case mysql.Time:
|
||||
u.SetValue(mysql.ZeroDecimal.Sub(x.ToNumber()))
|
||||
case string:
|
||||
f, err := types.StrToFloat(x)
|
||||
switch aDatum.Kind() {
|
||||
case types.KindInt64:
|
||||
u.SetInt64(-aDatum.GetInt64())
|
||||
case types.KindUint64:
|
||||
u.SetInt64(-int64(aDatum.GetUint64()))
|
||||
case types.KindFloat64:
|
||||
u.SetFloat64(-aDatum.GetFloat64())
|
||||
case types.KindFloat32:
|
||||
u.SetFloat32(-aDatum.GetFloat32())
|
||||
case types.KindMysqlDuration:
|
||||
u.SetValue(mysql.ZeroDecimal.Sub(aDatum.GetMysqlDuration().ToNumber()))
|
||||
case types.KindMysqlTime:
|
||||
u.SetValue(mysql.ZeroDecimal.Sub(aDatum.GetMysqlTime().ToNumber()))
|
||||
case types.KindString:
|
||||
f, err := types.StrToFloat(aDatum.GetString())
|
||||
e.err = errors.Trace(err)
|
||||
u.SetValue(-f)
|
||||
case mysql.Decimal:
|
||||
f, _ := x.Float64()
|
||||
u.SetFloat64(-f)
|
||||
case types.KindMysqlDecimal:
|
||||
f, _ := aDatum.GetMysqlDecimal().Float64()
|
||||
u.SetValue(mysql.NewDecimalFromFloat(-f))
|
||||
case []byte:
|
||||
f, err := types.StrToFloat(string(x))
|
||||
case types.KindBytes:
|
||||
f, err := types.StrToFloat(string(aDatum.GetBytes()))
|
||||
e.err = errors.Trace(err)
|
||||
u.SetValue(-f)
|
||||
case mysql.Hex:
|
||||
u.SetValue(-x.ToNumber())
|
||||
case mysql.Bit:
|
||||
u.SetValue(-x.ToNumber())
|
||||
case mysql.Enum:
|
||||
u.SetValue(-x.ToNumber())
|
||||
case mysql.Set:
|
||||
u.SetValue(-x.ToNumber())
|
||||
u.SetFloat64(-f)
|
||||
case types.KindMysqlHex:
|
||||
u.SetFloat64(-aDatum.GetMysqlHex().ToNumber())
|
||||
case types.KindMysqlBit:
|
||||
u.SetFloat64(-aDatum.GetMysqlBit().ToNumber())
|
||||
case types.KindMysqlEnum:
|
||||
u.SetFloat64(-aDatum.GetMysqlEnum().ToNumber())
|
||||
case types.KindMysqlSet:
|
||||
u.SetFloat64(-aDatum.GetMysqlSet().ToNumber())
|
||||
default:
|
||||
e.err = ErrInvalidOperation
|
||||
return false
|
||||
@ -621,7 +570,7 @@ func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool {
|
||||
}
|
||||
|
||||
func (e *Evaluator) values(v *ast.ValuesExpr) bool {
|
||||
v.SetValue(v.Column.GetValue())
|
||||
v.SetDatum(*v.Column.GetDatum())
|
||||
return true
|
||||
}
|
||||
|
||||
@ -632,11 +581,11 @@ func (e *Evaluator) variable(v *ast.VariableExpr) bool {
|
||||
if !v.IsSystem {
|
||||
// user vars
|
||||
if value, ok := sessionVars.Users[name]; ok {
|
||||
v.SetValue(value)
|
||||
v.SetString(value)
|
||||
return true
|
||||
}
|
||||
// select null user vars is permitted.
|
||||
v.SetValue(nil)
|
||||
v.SetNull()
|
||||
return true
|
||||
}
|
||||
|
||||
@ -649,16 +598,18 @@ func (e *Evaluator) variable(v *ast.VariableExpr) bool {
|
||||
|
||||
if !v.IsGlobal {
|
||||
if value, ok := sessionVars.Systems[name]; ok {
|
||||
v.SetValue(value)
|
||||
v.SetString(value)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
value, err := globalVars.GetGlobalSysVar(e.ctx, name)
|
||||
if err != nil {
|
||||
e.err = errors.Trace(err)
|
||||
return false
|
||||
}
|
||||
v.SetValue(value)
|
||||
|
||||
v.SetString(value)
|
||||
return true
|
||||
}
|
||||
|
||||
@ -689,7 +640,7 @@ func (e *Evaluator) funcCast(v *ast.FuncCastExpr) bool {
|
||||
value := v.Expr.GetValue()
|
||||
// Casting nil to any type returns null
|
||||
if value == nil {
|
||||
v.SetValue(nil)
|
||||
v.SetNull()
|
||||
return true
|
||||
}
|
||||
var err error
|
||||
@ -719,7 +670,7 @@ func (e *Evaluator) aggregateFunc(v *ast.AggregateFuncExpr) bool {
|
||||
|
||||
func (e *Evaluator) evalAggCount(v *ast.AggregateFuncExpr) {
|
||||
ctx := v.GetContext()
|
||||
v.SetValue(ctx.Count)
|
||||
v.SetInt64(ctx.Count)
|
||||
}
|
||||
|
||||
func (e *Evaluator) evalAggSetValue(v *ast.AggregateFuncExpr) {
|
||||
|
||||
@ -34,6 +34,16 @@ type subquery struct {
|
||||
is infoschema.InfoSchema
|
||||
}
|
||||
|
||||
// SetDatum implements Expression interface.
|
||||
func (sq *subquery) SetDatum(datum types.Datum) {
|
||||
sq.Datum = datum
|
||||
}
|
||||
|
||||
// GetDatum implements Expression interface.
|
||||
func (sq *subquery) GetDatum() *types.Datum {
|
||||
return &sq.Datum
|
||||
}
|
||||
|
||||
// SetFlag implements Expression interface.
|
||||
func (sq *subquery) SetFlag(flag uint64) {
|
||||
sq.flag = flag
|
||||
|
||||
@ -973,6 +973,120 @@ func (d *Datum) convertToMysqlSet(target *FieldType) (Datum, error) {
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// ToBool converts to a bool.
|
||||
// We will use 1 for true, and 0 for false.
|
||||
func (d *Datum) ToBool() (int64, error) {
|
||||
isZero := false
|
||||
switch d.Kind() {
|
||||
case KindInt64:
|
||||
isZero = (d.GetInt64() == 0)
|
||||
case KindUint64:
|
||||
isZero = (d.GetUint64() == 0)
|
||||
case KindFloat32:
|
||||
isZero = (d.GetFloat32() == 0)
|
||||
case KindFloat64:
|
||||
isZero = (d.GetFloat64() == 0)
|
||||
case KindString:
|
||||
s := d.GetString()
|
||||
if len(s) == 0 {
|
||||
isZero = true
|
||||
}
|
||||
n, err := StrToInt(s)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
isZero = (n == 0)
|
||||
case KindBytes:
|
||||
bs := d.GetBytes()
|
||||
if len(bs) == 0 {
|
||||
isZero = true
|
||||
} else {
|
||||
n, err := StrToInt(string(bs))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
isZero = (n == 0)
|
||||
}
|
||||
case KindMysqlTime:
|
||||
isZero = d.GetMysqlTime().IsZero()
|
||||
case KindMysqlDuration:
|
||||
isZero = (d.GetMysqlDuration().Duration == 0)
|
||||
case KindMysqlDecimal:
|
||||
v, _ := d.GetMysqlDecimal().Float64()
|
||||
isZero = (v == 0)
|
||||
case KindMysqlHex:
|
||||
isZero = (d.GetMysqlHex().ToNumber() == 0)
|
||||
case KindMysqlBit:
|
||||
isZero = (d.GetMysqlBit().ToNumber() == 0)
|
||||
case KindMysqlEnum:
|
||||
isZero = (d.GetMysqlEnum().ToNumber() == 0)
|
||||
case KindMysqlSet:
|
||||
isZero = (d.GetMysqlSet().ToNumber() == 0)
|
||||
default:
|
||||
return 0, errors.Errorf("cannot convert %v(type %T) to bool", d.GetValue(), d.GetValue())
|
||||
}
|
||||
if isZero {
|
||||
return 0, nil
|
||||
}
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
// ToInt64 converts to a int64.
|
||||
func (d *Datum) ToInt64() (int64, error) {
|
||||
tp := mysql.TypeLonglong
|
||||
lowerBound := signedLowerBound[tp]
|
||||
upperBound := signedUpperBound[tp]
|
||||
switch d.Kind() {
|
||||
case KindInt64:
|
||||
return convertIntToInt(d.GetInt64(), lowerBound, upperBound, tp)
|
||||
case KindUint64:
|
||||
return convertUintToInt(d.GetUint64(), upperBound, tp)
|
||||
case KindFloat32:
|
||||
return convertFloatToInt(float64(d.GetFloat32()), lowerBound, upperBound, tp)
|
||||
case KindFloat64:
|
||||
return convertFloatToInt(d.GetFloat64(), lowerBound, upperBound, tp)
|
||||
case KindString:
|
||||
s := d.GetString()
|
||||
fval, err := StrToFloat(s)
|
||||
if err != nil {
|
||||
return 0, errors.Trace(err)
|
||||
}
|
||||
return convertFloatToInt(fval, lowerBound, upperBound, tp)
|
||||
case KindBytes:
|
||||
s := string(d.GetBytes())
|
||||
fval, err := StrToFloat(s)
|
||||
if err != nil {
|
||||
return 0, errors.Trace(err)
|
||||
}
|
||||
return convertFloatToInt(fval, lowerBound, upperBound, tp)
|
||||
case KindMysqlTime:
|
||||
// 2011-11-10 11:11:11.999999 -> 20111110111112
|
||||
ival := d.GetMysqlTime().ToNumber().Round(0).IntPart()
|
||||
return convertIntToInt(ival, lowerBound, upperBound, tp)
|
||||
case KindMysqlDuration:
|
||||
// 11:11:11.999999 -> 111112
|
||||
ival := d.GetMysqlDuration().ToNumber().Round(0).IntPart()
|
||||
return convertIntToInt(ival, lowerBound, upperBound, tp)
|
||||
case KindMysqlDecimal:
|
||||
fval, _ := d.GetMysqlDecimal().Float64()
|
||||
return convertFloatToInt(fval, lowerBound, upperBound, tp)
|
||||
case KindMysqlHex:
|
||||
fval := d.GetMysqlHex().ToNumber()
|
||||
return convertFloatToInt(fval, lowerBound, upperBound, tp)
|
||||
case KindMysqlBit:
|
||||
fval := d.GetMysqlBit().ToNumber()
|
||||
return convertFloatToInt(fval, lowerBound, upperBound, tp)
|
||||
case KindMysqlEnum:
|
||||
fval := d.GetMysqlEnum().ToNumber()
|
||||
return convertFloatToInt(fval, lowerBound, upperBound, tp)
|
||||
case KindMysqlSet:
|
||||
fval := d.GetMysqlSet().ToNumber()
|
||||
return convertFloatToInt(fval, lowerBound, upperBound, tp)
|
||||
default:
|
||||
return 0, errors.Errorf("cannot convert %v(type %T) to int64", d.GetValue(), d.GetValue())
|
||||
}
|
||||
}
|
||||
|
||||
func invalidConv(d *Datum, tp byte) (Datum, error) {
|
||||
return Datum{}, errors.Errorf("cannot convert %v to type %s", d, TypeStr(tp))
|
||||
}
|
||||
|
||||
@ -15,6 +15,7 @@ package types
|
||||
|
||||
import (
|
||||
. "github.com/pingcap/check"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
)
|
||||
|
||||
var _ = Suite(&testDatumSuite{})
|
||||
@ -38,3 +39,80 @@ func (ts *testDatumSuite) TestDatum(c *C) {
|
||||
c.Assert(x, DeepEquals, val)
|
||||
}
|
||||
}
|
||||
|
||||
func testDatumToBool(c *C, in interface{}, res int) {
|
||||
datum := NewDatum(in)
|
||||
res64 := int64(res)
|
||||
b, err := datum.ToBool()
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(b, Equals, res64)
|
||||
}
|
||||
func (ts *testDatumSuite) TestToBool(c *C) {
|
||||
testDatumToBool(c, int(0), 0)
|
||||
testDatumToBool(c, int64(0), 0)
|
||||
testDatumToBool(c, uint64(0), 0)
|
||||
testDatumToBool(c, float32(0), 0)
|
||||
testDatumToBool(c, float64(0), 0)
|
||||
testDatumToBool(c, "", 0)
|
||||
testDatumToBool(c, "0", 0)
|
||||
testDatumToBool(c, []byte{}, 0)
|
||||
testDatumToBool(c, []byte("0"), 0)
|
||||
testDatumToBool(c, mysql.Hex{Value: 0}, 0)
|
||||
testDatumToBool(c, mysql.Bit{Value: 0, Width: 8}, 0)
|
||||
testDatumToBool(c, mysql.Enum{Name: "a", Value: 1}, 1)
|
||||
testDatumToBool(c, mysql.Set{Name: "a", Value: 1}, 1)
|
||||
|
||||
t, err := mysql.ParseTime("2011-11-10 11:11:11.999999", mysql.TypeTimestamp, 6)
|
||||
c.Assert(err, IsNil)
|
||||
testDatumToBool(c, t, 1)
|
||||
|
||||
td, err := mysql.ParseDuration("11:11:11.999999", 6)
|
||||
c.Assert(err, IsNil)
|
||||
testDatumToBool(c, td, 1)
|
||||
|
||||
ft := NewFieldType(mysql.TypeNewDecimal)
|
||||
ft.Decimal = 5
|
||||
v, err := Convert(3.1415926, ft)
|
||||
c.Assert(err, IsNil)
|
||||
testDatumToBool(c, v, 1)
|
||||
d := NewDatum(&invalidMockType{})
|
||||
_, err = d.ToBool()
|
||||
c.Assert(err, NotNil)
|
||||
}
|
||||
|
||||
func testDatumToInt64(c *C, val interface{}, expect int64) {
|
||||
d := NewDatum(val)
|
||||
b, err := d.ToInt64()
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(b, Equals, expect)
|
||||
}
|
||||
|
||||
func (ts *testTypeConvertSuite) TestToInt64(c *C) {
|
||||
testDatumToInt64(c, "0", int64(0))
|
||||
testDatumToInt64(c, int(0), int64(0))
|
||||
testDatumToInt64(c, int64(0), int64(0))
|
||||
testDatumToInt64(c, uint64(0), int64(0))
|
||||
testDatumToInt64(c, float32(3.1), int64(3))
|
||||
testDatumToInt64(c, float64(3.1), int64(3))
|
||||
testDatumToInt64(c, mysql.Hex{Value: 100}, int64(100))
|
||||
testDatumToInt64(c, mysql.Bit{Value: 100, Width: 8}, int64(100))
|
||||
testDatumToInt64(c, mysql.Enum{Name: "a", Value: 1}, int64(1))
|
||||
testDatumToInt64(c, mysql.Set{Name: "a", Value: 1}, int64(1))
|
||||
|
||||
t, err := mysql.ParseTime("2011-11-10 11:11:11.999999", mysql.TypeTimestamp, 0)
|
||||
c.Assert(err, IsNil)
|
||||
testDatumToInt64(c, t, int64(20111110111112))
|
||||
|
||||
td, err := mysql.ParseDuration("11:11:11.999999", 6)
|
||||
c.Assert(err, IsNil)
|
||||
testDatumToInt64(c, td, int64(111112))
|
||||
|
||||
ft := NewFieldType(mysql.TypeNewDecimal)
|
||||
ft.Decimal = 5
|
||||
v, err := Convert(3.1415926, ft)
|
||||
c.Assert(err, IsNil)
|
||||
testDatumToInt64(c, v, int64(3))
|
||||
|
||||
_, err = ToInt64(&invalidMockType{})
|
||||
c.Assert(err, NotNil)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user