diff --git a/expression/distsql_builtin.go b/expression/distsql_builtin.go index b81719754f..3cc43212cb 100644 --- a/expression/distsql_builtin.go +++ b/expression/distsql_builtin.go @@ -684,10 +684,15 @@ func convertFloat(val []byte, f32 bool) (*Constant, error) { func convertDecimal(val []byte) (*Constant, error) { _, dec, err := codec.DecodeDecimal(val) + var d types.Datum + precision, frac := dec.PrecisionAndFrac() + d.SetMysqlDecimal(dec) + d.SetLength(precision) + d.SetFrac(frac) if err != nil { return nil, errors.Errorf("invalid decimal % x", val) } - return &Constant{Value: dec, RetType: types.NewFieldType(mysql.TypeNewDecimal)}, nil + return &Constant{Value: d, RetType: types.NewFieldType(mysql.TypeNewDecimal)}, nil } func convertDuration(val []byte) (*Constant, error) { diff --git a/types/mydecimal.go b/types/mydecimal.go index aec5395fe8..2219465afc 100644 --- a/types/mydecimal.go +++ b/types/mydecimal.go @@ -1235,7 +1235,8 @@ func (d *MyDecimal) FromBin(bin []byte, precision, frac int) (binSize int, err e mask = 0 } binSize = decimalBinSize(precision, frac) - dCopy := make([]byte, binSize) + dCopy := make([]byte, 40) + dCopy = dCopy[:binSize] copy(dCopy, bin) dCopy[0] ^= 0x80 bin = dCopy diff --git a/util/codec/bench_test.go b/util/codec/bench_test.go index 97e56b199c..c3fda693b6 100644 --- a/util/codec/bench_test.go +++ b/util/codec/bench_test.go @@ -60,3 +60,14 @@ func BenchmarkEncodeIntWithOutSize(b *testing.B) { EncodeInt(nil, 10) } } + +func BenchmarkDecodeDecimal(b *testing.B) { + dec := &types.MyDecimal{} + dec.FromFloat64(1211.1211113) + precision, frac := dec.PrecisionAndFrac() + raw := EncodeDecimal([]byte{}, dec, precision, frac) + b.ResetTimer() + for i := 0; i < b.N; i++ { + DecodeDecimal(raw) + } +} diff --git a/util/codec/codec.go b/util/codec/codec.go index 504d9cca2a..9dc92656ca 100644 --- a/util/codec/codec.go +++ b/util/codec/codec.go @@ -334,7 +334,12 @@ func DecodeOne(b []byte) (remain []byte, d types.Datum, err error) { b, v, err = DecodeCompactBytes(b) d.SetBytes(v) case decimalFlag: - b, d, err = DecodeDecimal(b) + var dec *types.MyDecimal + b, dec, err = DecodeDecimal(b) + precision, frac := dec.PrecisionAndFrac() + d.SetMysqlDecimal(dec) + d.SetLength(precision) + d.SetFrac(frac) case durationFlag: var r int64 b, r, err = DecodeInt(b) @@ -534,12 +539,12 @@ func DecodeOneToChunk(b []byte, chk *chunk.Chunk, colIdx int, ft *types.FieldTyp } chk.AppendBytes(colIdx, v) case decimalFlag: - var d types.Datum - b, d, err = DecodeDecimal(b) + var dec *types.MyDecimal + b, dec, err = DecodeDecimal(b) if err != nil { return nil, errors.Trace(err) } - chk.AppendMyDecimal(colIdx, d.GetMysqlDecimal()) + chk.AppendMyDecimal(colIdx, dec) case durationFlag: var r int64 b, r, err = DecodeInt(b) diff --git a/util/codec/decimal.go b/util/codec/decimal.go index 2368e596a8..1ccdabae1d 100644 --- a/util/codec/decimal.go +++ b/util/codec/decimal.go @@ -35,10 +35,9 @@ func EncodeDecimal(b []byte, dec *types.MyDecimal, precision, frac int) []byte { } // DecodeDecimal decodes bytes to decimal. -func DecodeDecimal(b []byte) ([]byte, types.Datum, error) { - var d types.Datum +func DecodeDecimal(b []byte) ([]byte, *types.MyDecimal, error) { if len(b) < 3 { - return b, d, errors.New("insufficient bytes to decode value") + return b, nil, errors.New("insufficient bytes to decode value") } precision := int(b[0]) frac := int(b[1]) @@ -47,10 +46,7 @@ func DecodeDecimal(b []byte) ([]byte, types.Datum, error) { binSize, err := dec.FromBin(b, precision, frac) b = b[binSize:] if err != nil { - return b, d, errors.Trace(err) + return b, nil, errors.Trace(err) } - d.SetLength(precision) - d.SetFrac(frac) - d.SetMysqlDecimal(dec) - return b, d, nil + return b, dec, nil } diff --git a/util/codec/decimal_test.go b/util/codec/decimal_test.go index 59eebf7650..64d27d4c1f 100644 --- a/util/codec/decimal_test.go +++ b/util/codec/decimal_test.go @@ -15,7 +15,6 @@ package codec import ( . "github.com/pingcap/check" - "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/testleak" ) @@ -51,7 +50,7 @@ func (s *testDecimalSuite) TestDecimalCodec(c *C) { b := EncodeDecimal([]byte{}, datum.GetMysqlDecimal(), datum.Length(), datum.Frac()) _, d, err := DecodeDecimal(b) c.Assert(err, IsNil) - c.Assert(v.Compare(d.GetMysqlDecimal()), Equals, 0) + c.Assert(v.Compare(d), Equals, 0) } } @@ -72,11 +71,7 @@ func testFrac(c *C, v *types.MyDecimal) { var d1 types.Datum d1.SetMysqlDecimal(v) b := EncodeDecimal([]byte{}, d1.GetMysqlDecimal(), d1.Length(), d1.Frac()) - _, d2, err := DecodeDecimal(b) + _, dec, err := DecodeDecimal(b) c.Assert(err, IsNil) - sc := new(stmtctx.StatementContext) - cmp, err := d1.CompareDatum(sc, &d2) - c.Assert(err, IsNil) - c.Assert(cmp, Equals, 0) - c.Assert(d1.GetMysqlDecimal().String(), Equals, d2.GetMysqlDecimal().String()) + c.Assert(dec.String(), Equals, v.String()) }