diff --git a/executor/executor_test.go b/executor/executor_test.go index e627601a7d..e5361341de 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -1279,3 +1279,28 @@ func (s *testSuite) TestUsignedPKColumn(c *C) { result = tk.MustQuery("select * from t where b=1;") result.Check(testkit.Rows("1 1 2")) } + +func (s *testSuite) TestDatumXAPI(c *C) { + defer testleak.AfterTest(c)() + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a decimal(10,6), b decimal, index idx_b (b))") + tk.MustExec("insert t values (1.1, 1.1)") + tk.MustExec("insert t values (2.2, 2.2)") + tk.MustExec("insert t values (3.3, 3.3)") + result := tk.MustQuery("select * from t where a > 1.5") + result.Check(testkit.Rows("2.200000 2.2", "3.300000 3.3")) + result = tk.MustQuery("select * from t where b > 1.5") + result.Check(testkit.Rows("2.200000 2.2", "3.300000 3.3")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a time(3), b time, index idx_a (a))") + tk.MustExec("insert t values ('11:11:11', '11:11:11')") + tk.MustExec("insert t values ('11:11:12', '11:11:12')") + tk.MustExec("insert t values ('11:11:13', '11:11:13')") + result = tk.MustQuery("select * from t where a > '11:11:11.5'") + result.Check(testkit.Rows("11:11:12 11:11:12", "11:11:13 11:11:13")) + result = tk.MustQuery("select * from t where b > '11:11:11.5'") + result.Check(testkit.Rows("11:11:12 11:11:12", "11:11:13 11:11:13")) +} diff --git a/executor/executor_xapi.go b/executor/executor_xapi.go index 5f653f5448..8c4607a769 100644 --- a/executor/executor_xapi.go +++ b/executor/executor_xapi.go @@ -686,8 +686,8 @@ func (b *executorBuilder) columnNameToPBExpr(client kv.Client, column *ast.Colum return nil } switch column.Refer.Expr.GetType().Tp { - case mysql.TypeBit, mysql.TypeSet, mysql.TypeEnum, mysql.TypeDecimal, mysql.TypeNewDecimal, mysql.TypeGeometry, - mysql.TypeDate, mysql.TypeNewDate, mysql.TypeDatetime, mysql.TypeDuration, mysql.TypeTimestamp, mysql.TypeYear: + case mysql.TypeBit, mysql.TypeSet, mysql.TypeEnum, mysql.TypeDecimal, mysql.TypeGeometry, + mysql.TypeDate, mysql.TypeNewDate, mysql.TypeDatetime, mysql.TypeTimestamp, mysql.TypeYear: return nil } matched := false @@ -732,6 +732,12 @@ func (b *executorBuilder) datumToPBExpr(client kv.Client, d types.Datum) *tipb.E case types.KindFloat64: tp = tipb.ExprType_Float64 val = codec.EncodeFloat(nil, d.GetFloat64()) + case types.KindMysqlDuration: + tp = tipb.ExprType_MysqlDuration + val = codec.EncodeInt(nil, int64(d.GetMysqlDuration().Duration)) + case types.KindMysqlDecimal: + tp = tipb.ExprType_MysqlDecimal + val = codec.EncodeDecimal(nil, d.GetMysqlDecimal()) default: return nil } diff --git a/store/localstore/local_region.go b/store/localstore/local_region.go index e2bc3b3425..698af7e80c 100644 --- a/store/localstore/local_region.go +++ b/store/localstore/local_region.go @@ -353,7 +353,7 @@ func (rs *localRegion) evalWhereForRow(ctx *selectContext, h int64) (bool, error } else if err != nil { return false, errors.Trace(err) } - _, datum, err := codec.DecodeOne(data) + datum, err := tablecodec.DecodeColumnValue(data, col) if err != nil { return false, errors.Trace(err) } diff --git a/store/tikv/mock-tikv/cop_handler.go b/store/tikv/mock-tikv/cop_handler.go index 26001a66f2..9bb738838f 100644 --- a/store/tikv/mock-tikv/cop_handler.go +++ b/store/tikv/mock-tikv/cop_handler.go @@ -368,11 +368,12 @@ func (h *rpcHandler) evalWhereForRow(ctx *selectContext, handle int64) (bool, er } ctx.eval.Row[colID] = types.Datum{} } else { - _, datum, err := codec.DecodeOne(data) + var d types.Datum + d, err = tablecodec.DecodeColumnValue(data, col) if err != nil { return false, errors.Trace(err) } - ctx.eval.Row[colID] = datum + ctx.eval.Row[colID] = d } } } diff --git a/util/types/datum.go b/util/types/datum.go index 609acfd234..d4e55d9cc1 100644 --- a/util/types/datum.go +++ b/util/types/datum.go @@ -1398,6 +1398,18 @@ func NewFloat32Datum(f float32) (d Datum) { return d } +// NewDurationDatum creates a new Datum from a mysql.Duration value. +func NewDurationDatum(dur mysql.Duration) (d Datum) { + d.SetMysqlDuration(dur) + return d +} + +// NewDecimalDatum creates a new Datum form a mysql.Decimal value. +func NewDecimalDatum(dec mysql.Decimal) (d Datum) { + d.SetMysqlDecimal(dec) + return d +} + // MakeDatums creates datum slice from interfaces. func MakeDatums(args ...interface{}) []Datum { datums := make([]Datum, len(args)) diff --git a/xapi/tablecodec/tablecodec.go b/xapi/tablecodec/tablecodec.go index e6a7191376..8ae871fe2d 100644 --- a/xapi/tablecodec/tablecodec.go +++ b/xapi/tablecodec/tablecodec.go @@ -138,6 +138,9 @@ func unflatten(datum types.Datum, ft *types.FieldType) (types.Datum, error) { datum.SetValue(dur) return datum, nil case mysql.TypeNewDecimal: + if datum.Kind() == types.KindMysqlDecimal { + return datum, nil + } dec, err := mysql.ParseDecimal(datum.GetString()) if err != nil { return datum, errors.Trace(err) @@ -166,6 +169,26 @@ func unflatten(datum types.Datum, ft *types.FieldType) (types.Datum, error) { return datum, nil } +// DecodeColumnValue decodes data to a Datum according to the column info. +func DecodeColumnValue(data []byte, col *tipb.ColumnInfo) (types.Datum, error) { + _, d, err := codec.DecodeOne(data) + if err != nil { + return types.Datum{}, errors.Trace(err) + } + ft := &types.FieldType{ + Tp: byte(col.GetTp()), + Flen: int(col.GetColumnLen()), + Decimal: int(col.GetDecimal()), + Elems: col.Elems, + Collate: mysql.Collations[uint8(col.GetCollation())], + } + colDatum, err := unflatten(d, ft) + if err != nil { + return types.Datum{}, errors.Trace(err) + } + return colDatum, nil +} + // EncodeIndexSeekKey encodes an index value to kv.Key. func EncodeIndexSeekKey(tableID int64, idxID int64, encodedValue []byte) kv.Key { key := make([]byte, 0, prefixLen+len(encodedValue)) diff --git a/xapi/xeval/eval.go b/xapi/xeval/eval.go index 678354ee43..37dd42325e 100644 --- a/xapi/xeval/eval.go +++ b/xapi/xeval/eval.go @@ -16,8 +16,10 @@ package xeval import ( "sort" "strings" + "time" "github.com/juju/errors" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/types" @@ -66,6 +68,10 @@ func (e *Evaluator) Eval(expr *tipb.Expr) (types.Datum, error) { return e.evalFloat(expr.Val, true) case tipb.ExprType_Float64: return e.evalFloat(expr.Val, false) + case tipb.ExprType_MysqlDecimal: + return e.evalDecimal(expr.Val) + case tipb.ExprType_MysqlDuration: + return e.evalDuration(expr.Val) case tipb.ExprType_ColumnRef: return e.evalColumnRef(expr.Val) case tipb.ExprType_LT: @@ -149,6 +155,26 @@ func (e *Evaluator) evalFloat(val []byte, f32 bool) (types.Datum, error) { return d, nil } +func (e *Evaluator) evalDecimal(val []byte) (types.Datum, error) { + var d types.Datum + _, dec, err := codec.DecodeDecimal(val) + if err != nil { + return d, ErrInvalid.Gen("invalid decimal % x", val) + } + d.SetMysqlDecimal(dec) + return d, nil +} + +func (e *Evaluator) evalDuration(val []byte) (types.Datum, error) { + var d types.Datum + _, i, err := codec.DecodeInt(val) + if err != nil { + return d, ErrInvalid.Gen("invalid duration %d", i) + } + d.SetMysqlDuration(mysql.Duration{Duration: time.Duration(i), Fsp: mysql.MaxFsp}) + return d, nil +} + func (e *Evaluator) evalLT(expr *tipb.Expr) (types.Datum, error) { cmp, err := e.compareTwoChildren(expr) if err != nil { diff --git a/xapi/xeval/eval_test.go b/xapi/xeval/eval_test.go index 479d3515ac..cc5bd7a4bb 100644 --- a/xapi/xeval/eval_test.go +++ b/xapi/xeval/eval_test.go @@ -15,8 +15,10 @@ package xeval import ( "testing" + "time" . "github.com/pingcap/check" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/types" "github.com/pingcap/tipb/go-tipb" @@ -70,8 +72,12 @@ func (s *testEvalSuite) TestEval(c *C) { types.Datum{}, }, { - datumExpr(types.Datum{}), - types.Datum{}, + datumExpr(types.NewDurationDatum(mysql.Duration{Duration: time.Hour})), + types.NewDurationDatum(mysql.Duration{Duration: time.Hour}), + }, + { + datumExpr(types.NewDecimalDatum(mysql.NewDecimalFromFloat(1.1))), + types.NewDecimalDatum(mysql.NewDecimalFromFloat(1.1)), }, { columnExpr(1), @@ -261,6 +267,12 @@ func datumExpr(d types.Datum) *tipb.Expr { case types.KindFloat64: expr.Tp = tipb.ExprType_Float64.Enum() expr.Val = codec.EncodeFloat(nil, d.GetFloat64()) + case types.KindMysqlDuration: + expr.Tp = tipb.ExprType_MysqlDuration.Enum() + expr.Val = codec.EncodeInt(nil, int64(d.GetMysqlDuration().Duration)) + case types.KindMysqlDecimal: + expr.Tp = tipb.ExprType_MysqlDecimal.Enum() + expr.Val = codec.EncodeDecimal(nil, d.GetMysqlDecimal()) default: expr.Tp = tipb.ExprType_Null.Enum() }