From 5a81a2f5d4c1398899b5fb70caa670d2325c544b Mon Sep 17 00:00:00 2001 From: zhaoxingyu Date: Fri, 11 Mar 2016 13:37:58 +0800 Subject: [PATCH 1/3] evaluator: change Get/SetValue to Get/SetDatum. Chnage Get/SetValue to Get/SetDatum in evaluator.go . --- ast/ast.go | 4 + ast/base.go | 10 ++ evaluator/evaluator.go | 241 ++++++++++++++++------------------------- executor/subquery.go | 10 ++ util/types/datum.go | 13 +++ 5 files changed, 133 insertions(+), 145 deletions(-) diff --git a/ast/ast.go b/ast/ast.go index e96031f732..f52da910cd 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -64,6 +64,10 @@ type ExprNode interface { SetValue(val interface{}) // GetValue gets value of the expression. GetValue() interface{} + // SetDatum set datum to the expression. + SetDatum(datum *types.Datum) + // GetDatum get datum of the expression. + GetDatum() *types.Datum // SetFlag sets flag to the expression. // Flag indicates whether the expression contains // parameter marker, reference, aggregate function... diff --git a/ast/base.go b/ast/base.go index e98a7a0512..a749254c2d 100644 --- a/ast/base.go +++ b/ast/base.go @@ -67,6 +67,16 @@ type exprNode struct { flag uint64 } +// SetDatum implements Expression interface. +func (en *exprNode) SetDatum(datum *types.Datum) { + en.Datum = *datum +} + +// GetDatum implements Expression interface. +func (en *exprNode) GetDatum() *types.Datum { + return &en.Datum +} + // SetType implements Expression interface. func (en *exprNode) SetType(tp *types.FieldType) { en.Type = tp diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go index bfe9494891..707d405211 100644 --- a/evaluator/evaluator.go +++ b/evaluator/evaluator.go @@ -178,38 +178,39 @@ func (e *Evaluator) between(v *ast.BetweenExpr) bool { if e.err != nil { return false } - v.SetValue(ret.GetValue()) + v.SetDatum(ret.GetDatum()) return true } func (e *Evaluator) caseExpr(v *ast.CaseExpr) bool { - var target interface{} = true + tmp := types.NewDatum(boolToInt64(true)) + target := &tmp if v.Value != nil { - target = v.Value.GetValue() + target = v.Value.GetDatum() } - if target != nil { + if target.Kind() != types.KindNull { for _, val := range v.WhenClauses { - cmp, err := types.Compare(target, val.Expr.GetValue()) + cmp, err := target.CompareDatum(*val.Expr.GetDatum()) if err != nil { e.err = errors.Trace(err) return false } if cmp == 0 { - v.SetValue(val.Result.GetValue()) + v.SetDatum(val.Result.GetDatum()) return true } } } if v.ElseClause != nil { - v.SetValue(v.ElseClause.GetValue()) + v.SetDatum(v.ElseClause.GetDatum()) } else { - v.SetValue(nil) + v.SetNull() } return true } func (e *Evaluator) columnName(v *ast.ColumnNameExpr) bool { - v.SetValue(v.Refer.Expr.GetValue()) + v.SetDatum(v.Refer.Expr.GetDatum()) return true } @@ -218,11 +219,12 @@ func (e *Evaluator) defaultExpr(v *ast.DefaultExpr) bool { } func (e *Evaluator) compareSubquery(cs *ast.CompareSubqueryExpr) bool { - lv := cs.L.GetValue() - if lv == nil { - cs.SetValue(nil) + lvDatum := cs.L.GetDatum() + if lvDatum.Kind() == types.KindNull { + cs.SetNull() return true } + lv := lvDatum.GetValue() x, err := e.checkResult(cs, lv, cs.R.GetValue().([]interface{})) if err != nil { e.err = errors.Trace(err) @@ -302,17 +304,17 @@ func (e *Evaluator) checkAnyResult(cs *ast.CompareSubqueryExpr, lv interface{}, } func (e *Evaluator) existsSubquery(v *ast.ExistsSubqueryExpr) bool { - r := v.Sel.GetValue() - if r == nil { - v.SetValue(0) + datum := v.Sel.GetDatum() + if datum.Kind() == types.KindNull { + v.SetInt64(0) return true } + r := datum.GetValue() rows, _ := r.([]interface{}) if len(rows) > 0 { - v.SetValue(1) + v.SetInt64(1) } else { - v.SetValue(0) - + v.SetInt64(0) } return true } @@ -321,7 +323,7 @@ func (e *Evaluator) existsSubquery(v *ast.ExistsSubqueryExpr) bool { // Get the value from v.SubQuery and set it to v. func (e *Evaluator) subqueryExpr(v *ast.SubqueryExpr) bool { if v.SubqueryExec != nil { - v.SetValue(v.SubqueryExec.GetValue()) + v.SetDatum(v.SubqueryExec.GetDatum()) } v.Evaluated = true return true @@ -346,7 +348,7 @@ func (e *Evaluator) subqueryExec(v ast.SubqueryExec) bool { } switch len(rows) { case 0: - v.SetValue(nil) + v.GetDatum().SetNull() case 1: v.SetValue(rows[0]) default: @@ -389,9 +391,9 @@ func (e *Evaluator) checkInList(not bool, in interface{}, list []interface{}) in } func (e *Evaluator) patternIn(n *ast.PatternInExpr) bool { - lhs := n.Expr.GetValue() - if lhs == nil { - n.SetValue(nil) + lhs := n.Expr.GetDatum() + if lhs.Kind() == types.KindNull { + n.SetNull() return true } if n.Sel == nil { @@ -399,7 +401,7 @@ func (e *Evaluator) patternIn(n *ast.PatternInExpr) bool { for _, ei := range n.List { values = append(values, ei.GetValue()) } - x := e.checkInList(n.Not, lhs, values) + x := e.checkInList(n.Not, lhs.GetValue(), values) if e.err != nil { return false } @@ -410,7 +412,7 @@ func (e *Evaluator) patternIn(n *ast.PatternInExpr) bool { sel := se.SubqueryExec res := sel.GetValue().([]interface{}) - x := e.checkInList(n.Not, lhs, res) + x := e.checkInList(n.Not, lhs.GetValue(), res) if e.err != nil { return false } @@ -420,21 +422,21 @@ func (e *Evaluator) patternIn(n *ast.PatternInExpr) bool { func (e *Evaluator) isNull(v *ast.IsNullExpr) bool { var boolVal bool - if v.Expr.GetValue() == nil { + if v.Expr.GetDatum().Kind() == types.KindNull { boolVal = true } if v.Not { boolVal = !boolVal } - v.SetValue(boolToInt64(boolVal)) + v.SetInt64(boolToInt64(boolVal)) return true } func (e *Evaluator) isTruth(v *ast.IsTruthExpr) bool { var boolVal bool - val := v.Expr.GetValue() - if val != nil { - ival, err := types.ToBool(val) + datum := v.Expr.GetDatum() + if datum.Kind() != types.KindNull { + ival, err := datum.ToBool() if err != nil { e.err = errors.Trace(err) return false @@ -446,7 +448,7 @@ func (e *Evaluator) isTruth(v *ast.IsTruthExpr) bool { if v.Not { boolVal = !boolVal } - v.SetValue(boolToInt64(boolVal)) + v.GetDatum().SetInt64(boolToInt64(boolVal)) return true } @@ -455,12 +457,12 @@ func (e *Evaluator) paramMarker(v *ast.ParamMarkerExpr) bool { } func (e *Evaluator) parentheses(v *ast.ParenthesesExpr) bool { - v.SetValue(v.Expr.GetValue()) + v.SetDatum(v.Expr.GetDatum()) return true } func (e *Evaluator) position(v *ast.PositionExpr) bool { - v.SetValue(v.Refer.Expr.GetValue()) + v.SetDatum(v.Refer.Expr.GetDatum()) return true } @@ -479,135 +481,82 @@ func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool { e.err = errors.Errorf("%v", er) } }() - a := u.V.GetValue() - if a == nil { - u.SetValue(nil) + aDatum := u.V.GetDatum() + if aDatum.Kind() == types.KindNull { + u.SetNull() return true } switch op := u.Op; op { case opcode.Not: - n, err := types.ToBool(a) + n, err := aDatum.ToBool() if err != nil { e.err = errors.Trace(err) } else if n == 0 { - u.SetValue(int64(1)) + u.SetInt64(1) } else { - u.SetValue(int64(0)) + u.SetInt64(0) } case opcode.BitNeg: // for bit operation, we will use int64 first, then return uint64 - n, err := types.ToInt64(a) + n, err := aDatum.ToInt64() if err != nil { e.err = errors.Trace(err) return false } - u.SetValue(uint64(^n)) + u.SetUint64(uint64(^n)) case opcode.Plus: - switch x := a.(type) { - case bool: - u.SetValue(boolToInt64(x)) - case float32: - u.SetValue(+x) - case float64: - u.SetValue(+x) - case int: - u.SetValue(+x) - case int8: - u.SetValue(+x) - case int16: - u.SetValue(+x) - case int32: - u.SetValue(+x) - case int64: - u.SetValue(+x) - case uint: - u.SetValue(+x) - case uint8: - u.SetValue(+x) - case uint16: - u.SetValue(+x) - case uint32: - u.SetValue(+x) - case uint64: - u.SetValue(+x) - case mysql.Duration: - u.SetValue(x) - case mysql.Time: - u.SetValue(x) - case string: - u.SetValue(x) - case mysql.Decimal: - u.SetValue(x) - case []byte: - u.SetValue(x) - case mysql.Hex: - u.SetValue(x) - case mysql.Bit: - u.SetValue(x) - case mysql.Enum: - u.SetValue(x) - case mysql.Set: - u.SetValue(x) + switch aDatum.Kind() { + case types.KindInt64, + types.KindUint64, + types.KindFloat64, + types.KindFloat32, + types.KindMysqlDuration, + types.KindMysqlTime, + types.KindString, + types.KindMysqlDecimal, + types.KindBytes, + types.KindMysqlHex, + types.KindMysqlBit, + types.KindMysqlEnum, + types.KindMysqlSet: + u.SetDatum(aDatum) default: e.err = ErrInvalidOperation return false } case opcode.Minus: - switch x := a.(type) { - case bool: - if x { - u.SetValue(int64(-1)) - } else { - u.SetValue(int64(0)) - } - case float32: - u.SetValue(-x) - case float64: - u.SetValue(-x) - case int: - u.SetValue(-x) - case int8: - u.SetValue(-x) - case int16: - u.SetValue(-x) - case int32: - u.SetValue(-x) - case int64: - u.SetValue(-x) - case uint: - u.SetValue(-int64(x)) - case uint8: - u.SetValue(-int64(x)) - case uint16: - u.SetValue(-int64(x)) - case uint32: - u.SetValue(-int64(x)) - case uint64: - // TODO: check overflow and do more test for unsigned type - u.SetValue(-int64(x)) - case mysql.Duration: - u.SetValue(mysql.ZeroDecimal.Sub(x.ToNumber())) - case mysql.Time: - u.SetValue(mysql.ZeroDecimal.Sub(x.ToNumber())) - case string: - f, err := types.StrToFloat(x) + switch aDatum.Kind() { + case types.KindInt64: + u.SetInt64(-aDatum.GetInt64()) + case types.KindUint64: + u.SetInt64(-int64(aDatum.GetUint64())) + case types.KindFloat64: + u.SetFloat64(-aDatum.GetFloat64()) + case types.KindFloat32: + u.SetFloat32(-aDatum.GetFloat32()) + case types.KindMysqlDuration: + u.SetValue(mysql.ZeroDecimal.Sub(aDatum.GetMysqlDuration().ToNumber())) + case types.KindMysqlTime: + u.SetValue(mysql.ZeroDecimal.Sub(aDatum.GetMysqlTime().ToNumber())) + case types.KindString: + f, err := types.StrToFloat(aDatum.GetString()) e.err = errors.Trace(err) - u.SetValue(-f) - case mysql.Decimal: - f, _ := x.Float64() + u.SetFloat64(-f) + case types.KindMysqlDecimal: + f, _ := aDatum.GetMysqlDecimal().Float64() u.SetValue(mysql.NewDecimalFromFloat(-f)) - case []byte: - f, err := types.StrToFloat(string(x)) + case types.KindBytes: + f, err := types.StrToFloat(string(aDatum.GetBytes())) e.err = errors.Trace(err) - u.SetValue(-f) - case mysql.Hex: - u.SetValue(-x.ToNumber()) - case mysql.Bit: - u.SetValue(-x.ToNumber()) - case mysql.Enum: - u.SetValue(-x.ToNumber()) - case mysql.Set: - u.SetValue(-x.ToNumber()) + u.SetFloat64(-f) + case types.KindMysqlHex: + u.SetFloat64(-aDatum.GetMysqlHex().ToNumber()) + case types.KindMysqlBit: + u.SetFloat64(-aDatum.GetMysqlBit().ToNumber()) + case types.KindMysqlEnum: + u.SetFloat64(-aDatum.GetMysqlEnum().ToNumber()) + case types.KindMysqlSet: + u.SetFloat64(-aDatum.GetMysqlSet().ToNumber()) default: e.err = ErrInvalidOperation return false @@ -621,7 +570,7 @@ func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool { } func (e *Evaluator) values(v *ast.ValuesExpr) bool { - v.SetValue(v.Column.GetValue()) + v.SetDatum(v.Column.GetDatum()) return true } @@ -632,11 +581,11 @@ func (e *Evaluator) variable(v *ast.VariableExpr) bool { if !v.IsSystem { // user vars if value, ok := sessionVars.Users[name]; ok { - v.SetValue(value) + v.SetString(value) return true } // select null user vars is permitted. - v.SetValue(nil) + v.SetNull() return true } @@ -649,16 +598,18 @@ func (e *Evaluator) variable(v *ast.VariableExpr) bool { if !v.IsGlobal { if value, ok := sessionVars.Systems[name]; ok { - v.SetValue(value) + v.SetString(value) return true } } + value, err := globalVars.GetGlobalSysVar(e.ctx, name) if err != nil { e.err = errors.Trace(err) return false } - v.SetValue(value) + + v.SetString(value) return true } @@ -689,7 +640,7 @@ func (e *Evaluator) funcCast(v *ast.FuncCastExpr) bool { value := v.Expr.GetValue() // Casting nil to any type returns null if value == nil { - v.SetValue(nil) + v.SetNull() return true } var err error @@ -719,7 +670,7 @@ func (e *Evaluator) aggregateFunc(v *ast.AggregateFuncExpr) bool { func (e *Evaluator) evalAggCount(v *ast.AggregateFuncExpr) { ctx := v.GetContext() - v.SetValue(ctx.Count) + v.SetInt64(ctx.Count) } func (e *Evaluator) evalAggSetValue(v *ast.AggregateFuncExpr) { diff --git a/executor/subquery.go b/executor/subquery.go index fc5c8bdf39..4ec9b6610c 100644 --- a/executor/subquery.go +++ b/executor/subquery.go @@ -34,6 +34,16 @@ type subquery struct { is infoschema.InfoSchema } +// SetDatum implements Expression interface. +func (sq *subquery) SetDatum(datum *types.Datum) { + sq.Datum = *datum +} + +// GetDatum implements Expression interface. +func (sq *subquery) GetDatum() *types.Datum { + return &sq.Datum +} + // SetFlag implements Expression interface. func (sq *subquery) SetFlag(flag uint64) { sq.flag = flag diff --git a/util/types/datum.go b/util/types/datum.go index 8a8754ea07..62b255fa76 100644 --- a/util/types/datum.go +++ b/util/types/datum.go @@ -973,6 +973,19 @@ func (d *Datum) convertToMysqlSet(target *FieldType) (Datum, error) { return ret, nil } +// ToBool converts to a bool. +// We will use 1 for true, and 0 for false. +func (d *Datum) ToBool() (int64, error) { + // TODO: Make it do not use GatValue + return ToBool(d.GetValue()) +} + +// ToInt64 converts to a bool. +func (d *Datum) ToInt64() (int64, error) { + // TODO: Make it do not use GatValue + return ToInt64(d.GetValue()) +} + func invalidConv(d *Datum, tp byte) (Datum, error) { return Datum{}, errors.Errorf("cannot convert %v to type %s", d, TypeStr(tp)) } From 0fa8df585a1d0f4fa0e96673cd5d11328dc69e74 Mon Sep 17 00:00:00 2001 From: zhaoxingyu Date: Fri, 11 Mar 2016 19:58:40 +0800 Subject: [PATCH 2/3] evaluator: change SetDatum and implement ToBool, ToInt Change SetDatum and implement ToBool, ToInt. --- ast/ast.go | 6 +-- ast/base.go | 4 +- evaluator/evaluator.go | 18 +++---- executor/subquery.go | 4 +- util/types/datum.go | 109 +++++++++++++++++++++++++++++++++++++-- util/types/datum_test.go | 78 ++++++++++++++++++++++++++++ 6 files changed, 199 insertions(+), 20 deletions(-) diff --git a/ast/ast.go b/ast/ast.go index f52da910cd..c726e4b6ba 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -64,9 +64,9 @@ type ExprNode interface { SetValue(val interface{}) // GetValue gets value of the expression. GetValue() interface{} - // SetDatum set datum to the expression. - SetDatum(datum *types.Datum) - // GetDatum get datum of the expression. + // SetDatum sets datum to the expression. + SetDatum(datum types.Datum) + // GetDatum gets datum of the expression. GetDatum() *types.Datum // SetFlag sets flag to the expression. // Flag indicates whether the expression contains diff --git a/ast/base.go b/ast/base.go index a749254c2d..1415f64df2 100644 --- a/ast/base.go +++ b/ast/base.go @@ -68,8 +68,8 @@ type exprNode struct { } // SetDatum implements Expression interface. -func (en *exprNode) SetDatum(datum *types.Datum) { - en.Datum = *datum +func (en *exprNode) SetDatum(datum types.Datum) { + en.Datum = datum } // GetDatum implements Expression interface. diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go index 707d405211..67243703fb 100644 --- a/evaluator/evaluator.go +++ b/evaluator/evaluator.go @@ -178,7 +178,7 @@ func (e *Evaluator) between(v *ast.BetweenExpr) bool { if e.err != nil { return false } - v.SetDatum(ret.GetDatum()) + v.SetDatum(*ret.GetDatum()) return true } @@ -196,13 +196,13 @@ func (e *Evaluator) caseExpr(v *ast.CaseExpr) bool { return false } if cmp == 0 { - v.SetDatum(val.Result.GetDatum()) + v.SetDatum(*val.Result.GetDatum()) return true } } } if v.ElseClause != nil { - v.SetDatum(v.ElseClause.GetDatum()) + v.SetDatum(*v.ElseClause.GetDatum()) } else { v.SetNull() } @@ -210,7 +210,7 @@ func (e *Evaluator) caseExpr(v *ast.CaseExpr) bool { } func (e *Evaluator) columnName(v *ast.ColumnNameExpr) bool { - v.SetDatum(v.Refer.Expr.GetDatum()) + v.SetDatum(*v.Refer.Expr.GetDatum()) return true } @@ -323,7 +323,7 @@ func (e *Evaluator) existsSubquery(v *ast.ExistsSubqueryExpr) bool { // Get the value from v.SubQuery and set it to v. func (e *Evaluator) subqueryExpr(v *ast.SubqueryExpr) bool { if v.SubqueryExec != nil { - v.SetDatum(v.SubqueryExec.GetDatum()) + v.SetDatum(*v.SubqueryExec.GetDatum()) } v.Evaluated = true return true @@ -457,12 +457,12 @@ func (e *Evaluator) paramMarker(v *ast.ParamMarkerExpr) bool { } func (e *Evaluator) parentheses(v *ast.ParenthesesExpr) bool { - v.SetDatum(v.Expr.GetDatum()) + v.SetDatum(*v.Expr.GetDatum()) return true } func (e *Evaluator) position(v *ast.PositionExpr) bool { - v.SetDatum(v.Refer.Expr.GetDatum()) + v.SetDatum(*v.Refer.Expr.GetDatum()) return true } @@ -519,7 +519,7 @@ func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool { types.KindMysqlBit, types.KindMysqlEnum, types.KindMysqlSet: - u.SetDatum(aDatum) + u.SetDatum(*aDatum) default: e.err = ErrInvalidOperation return false @@ -570,7 +570,7 @@ func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool { } func (e *Evaluator) values(v *ast.ValuesExpr) bool { - v.SetDatum(v.Column.GetDatum()) + v.SetDatum(*v.Column.GetDatum()) return true } diff --git a/executor/subquery.go b/executor/subquery.go index 4ec9b6610c..d120011122 100644 --- a/executor/subquery.go +++ b/executor/subquery.go @@ -35,8 +35,8 @@ type subquery struct { } // SetDatum implements Expression interface. -func (sq *subquery) SetDatum(datum *types.Datum) { - sq.Datum = *datum +func (sq *subquery) SetDatum(datum types.Datum) { + sq.Datum = datum } // GetDatum implements Expression interface. diff --git a/util/types/datum.go b/util/types/datum.go index 62b255fa76..ef7d807848 100644 --- a/util/types/datum.go +++ b/util/types/datum.go @@ -976,14 +976,115 @@ func (d *Datum) convertToMysqlSet(target *FieldType) (Datum, error) { // ToBool converts to a bool. // We will use 1 for true, and 0 for false. func (d *Datum) ToBool() (int64, error) { - // TODO: Make it do not use GatValue - return ToBool(d.GetValue()) + isZero := false + switch d.Kind() { + case KindInt64: + isZero = (d.GetInt64() == 0) + case KindUint64: + isZero = (d.GetUint64() == 0) + case KindFloat32: + isZero = (d.GetFloat32() == 0) + case KindFloat64: + isZero = (d.GetFloat64() == 0) + case KindString: + s := d.GetString() + if len(s) == 0 { + isZero = true + } + n, err := StrToInt(s) + if err != nil { + return 0, err + } + isZero = (n == 0) + case KindBytes: + bs := d.GetBytes() + if len(bs) == 0 { + isZero = true + } else { + n, err := StrToInt(string(bs)) + if err != nil { + return 0, err + } + isZero = (n == 0) + } + case KindMysqlTime: + isZero = d.GetMysqlTime().IsZero() + case KindMysqlDuration: + isZero = (d.GetMysqlDuration().Duration == 0) + case KindMysqlDecimal: + v, _ := d.GetMysqlDecimal().Float64() + isZero = (v == 0) + case KindMysqlHex: + isZero = (d.GetMysqlHex().ToNumber() == 0) + case KindMysqlBit: + isZero = (d.GetMysqlBit().ToNumber() == 0) + case KindMysqlEnum: + isZero = (d.GetMysqlEnum().ToNumber() == 0) + case KindMysqlSet: + isZero = (d.GetMysqlSet().ToNumber() == 0) + default: + return 0, errors.Errorf("cannot convert %v(type %T) to bool", d.GetValue(), d.GetValue()) + } + if isZero { + return 0, nil + } + return 1, nil } // ToInt64 converts to a bool. func (d *Datum) ToInt64() (int64, error) { - // TODO: Make it do not use GatValue - return ToInt64(d.GetValue()) + tp := mysql.TypeLonglong + lowerBound := signedLowerBound[tp] + upperBound := signedUpperBound[tp] + switch d.Kind() { + case KindInt64: + return convertIntToInt(d.GetInt64(), lowerBound, upperBound, tp) + case KindUint64: + return convertUintToInt(d.GetUint64(), upperBound, tp) + case KindFloat32: + return convertFloatToInt(float64(d.GetFloat32()), lowerBound, upperBound, tp) + case KindFloat64: + return convertFloatToInt(d.GetFloat64(), lowerBound, upperBound, tp) + case KindString: + s := d.GetString() + fval, err := StrToFloat(s) + if err != nil { + return 0, errors.Trace(err) + } + return convertFloatToInt(fval, lowerBound, upperBound, tp) + case KindBytes: + s := string(d.GetBytes()) + fval, err := StrToFloat(s) + if err != nil { + return 0, errors.Trace(err) + } + return convertFloatToInt(fval, lowerBound, upperBound, tp) + case KindMysqlTime: + // 2011-11-10 11:11:11.999999 -> 20111110111112 + ival := d.GetMysqlTime().ToNumber().Round(0).IntPart() + return convertIntToInt(ival, lowerBound, upperBound, tp) + case KindMysqlDuration: + // 11:11:11.999999 -> 111112 + ival := d.GetMysqlDuration().ToNumber().Round(0).IntPart() + return convertIntToInt(ival, lowerBound, upperBound, tp) + case KindMysqlDecimal: + fval, _ := d.GetMysqlDecimal().Float64() + return convertFloatToInt(fval, lowerBound, upperBound, tp) + case KindMysqlHex: + fval := d.GetMysqlHex().ToNumber() + return convertFloatToInt(fval, lowerBound, upperBound, tp) + case KindMysqlBit: + fval := d.GetMysqlBit().ToNumber() + return convertFloatToInt(fval, lowerBound, upperBound, tp) + case KindMysqlEnum: + fval := d.GetMysqlEnum().ToNumber() + return convertFloatToInt(fval, lowerBound, upperBound, tp) + case KindMysqlSet: + fval := d.GetMysqlSet().ToNumber() + return convertFloatToInt(fval, lowerBound, upperBound, tp) + default: + return 0, errors.Errorf("cannot convert %v(type %T) to int64", d.GetValue(), d.GetValue()) + } } func invalidConv(d *Datum, tp byte) (Datum, error) { diff --git a/util/types/datum_test.go b/util/types/datum_test.go index 3d01cf9da2..2da774854b 100644 --- a/util/types/datum_test.go +++ b/util/types/datum_test.go @@ -15,6 +15,7 @@ package types import ( . "github.com/pingcap/check" + "github.com/pingcap/tidb/mysql" ) var _ = Suite(&testDatumSuite{}) @@ -38,3 +39,80 @@ func (ts *testDatumSuite) TestDatum(c *C) { c.Assert(x, DeepEquals, val) } } + +func testDatumToBool(c *C, in interface{}, res int) { + datum := NewDatum(in) + res64 := int64(res) + b, err := datum.ToBool() + c.Assert(err, IsNil) + c.Assert(b, Equals, res64) +} +func (ts *testDatumSuite) TestToBool(c *C) { + testDatumToBool(c, int(0), 0) + testDatumToBool(c, int64(0), 0) + testDatumToBool(c, uint64(0), 0) + testDatumToBool(c, float32(0), 0) + testDatumToBool(c, float64(0), 0) + testDatumToBool(c, "", 0) + testDatumToBool(c, "0", 0) + testDatumToBool(c, []byte{}, 0) + testDatumToBool(c, []byte("0"), 0) + testDatumToBool(c, mysql.Hex{Value: 0}, 0) + testDatumToBool(c, mysql.Bit{Value: 0, Width: 8}, 0) + testDatumToBool(c, mysql.Enum{Name: "a", Value: 1}, 1) + testDatumToBool(c, mysql.Set{Name: "a", Value: 1}, 1) + + t, err := mysql.ParseTime("2011-11-10 11:11:11.999999", mysql.TypeTimestamp, 6) + c.Assert(err, IsNil) + testDatumToBool(c, t, 1) + + td, err := mysql.ParseDuration("11:11:11.999999", 6) + c.Assert(err, IsNil) + testDatumToBool(c, td, 1) + + ft := NewFieldType(mysql.TypeNewDecimal) + ft.Decimal = 5 + v, err := Convert(3.1415926, ft) + c.Assert(err, IsNil) + testDatumToBool(c, v, 1) + d := NewDatum(&invalidMockType{}) + _, err = d.ToBool() + c.Assert(err, NotNil) +} + +func testDatumToInt64(c *C, val interface{}, expect int64) { + d := NewDatum(val) + b, err := d.ToInt64() + c.Assert(err, IsNil) + c.Assert(b, Equals, expect) +} + +func (ts *testTypeConvertSuite) TestToInt64(c *C) { + testDatumToInt64(c, "0", int64(0)) + testDatumToInt64(c, int(0), int64(0)) + testDatumToInt64(c, int64(0), int64(0)) + testDatumToInt64(c, uint64(0), int64(0)) + testDatumToInt64(c, float32(3.1), int64(3)) + testDatumToInt64(c, float64(3.1), int64(3)) + testDatumToInt64(c, mysql.Hex{Value: 100}, int64(100)) + testDatumToInt64(c, mysql.Bit{Value: 100, Width: 8}, int64(100)) + testDatumToInt64(c, mysql.Enum{Name: "a", Value: 1}, int64(1)) + testDatumToInt64(c, mysql.Set{Name: "a", Value: 1}, int64(1)) + + t, err := mysql.ParseTime("2011-11-10 11:11:11.999999", mysql.TypeTimestamp, 0) + c.Assert(err, IsNil) + testDatumToInt64(c, t, int64(20111110111112)) + + td, err := mysql.ParseDuration("11:11:11.999999", 6) + c.Assert(err, IsNil) + testDatumToInt64(c, td, int64(111112)) + + ft := NewFieldType(mysql.TypeNewDecimal) + ft.Decimal = 5 + v, err := Convert(3.1415926, ft) + c.Assert(err, IsNil) + testDatumToInt64(c, v, int64(3)) + + _, err = ToInt64(&invalidMockType{}) + c.Assert(err, NotNil) +} From 65fd9a77c67e1c2c19be29c2b5091bcd3b888aa7 Mon Sep 17 00:00:00 2001 From: zhaoxingyu Date: Fri, 11 Mar 2016 20:27:01 +0800 Subject: [PATCH 3/3] Change bool to int64. --- util/types/datum.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/util/types/datum.go b/util/types/datum.go index ef7d807848..e9290795c0 100644 --- a/util/types/datum.go +++ b/util/types/datum.go @@ -1031,7 +1031,7 @@ func (d *Datum) ToBool() (int64, error) { return 1, nil } -// ToInt64 converts to a bool. +// ToInt64 converts to a int64. func (d *Datum) ToInt64() (int64, error) { tp := mysql.TypeLonglong lowerBound := signedLowerBound[tp]