From 1399e55df71a419454e2a3bb014138905f4c8a58 Mon Sep 17 00:00:00 2001 From: Yuanjia Zhang Date: Fri, 20 Dec 2019 01:46:27 -0600 Subject: [PATCH] expression: fix invalid compare-operations in vectorized `builtinInDecimalSig ` (#14156) --- expression/builtin_other_vec_generated.go | 2 +- expression/builtin_other_vec_test.go | 31 +++++++++++++++++++++++ expression/generator/other_vec.go | 2 +- 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/expression/builtin_other_vec_generated.go b/expression/builtin_other_vec_generated.go index b807e19a76..11ffa01437 100644 --- a/expression/builtin_other_vec_generated.go +++ b/expression/builtin_other_vec_generated.go @@ -188,7 +188,7 @@ func (b *builtinInDecimalSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colum arg0 := args0[i] arg1 := args1[i] compareResult = 1 - if arg0 == arg1 { + if arg0.Compare(&arg1) == 0 { compareResult = 0 } if compareResult == 0 { diff --git a/expression/builtin_other_vec_test.go b/expression/builtin_other_vec_test.go index fd13197334..74cb3145d1 100644 --- a/expression/builtin_other_vec_test.go +++ b/expression/builtin_other_vec_test.go @@ -14,11 +14,15 @@ package expression import ( + "fmt" + "math/rand" "testing" . "github.com/pingcap/check" "github.com/pingcap/parser/ast" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/mock" ) var vecBuiltinOtherCases = map[string][]vecExprBenchCase{ @@ -44,3 +48,30 @@ func (s *testEvaluatorSuite) TestVectorizedBuiltinOtherFunc(c *C) { func BenchmarkVectorizedBuiltinOtherFunc(b *testing.B) { benchmarkVectorizedBuiltinFunc(b, vecBuiltinOtherCases) } + +func (s *testEvaluatorSuite) TestInDecimal(c *C) { + ctx := mock.NewContext() + ft := eType2FieldType(types.ETDecimal) + col0 := &Column{RetType: ft, Index: 0} + col1 := &Column{RetType: ft, Index: 1} + inFunc, err := funcs[ast.In].getFunction(ctx, []Expression{col0, col1}) + c.Assert(err, IsNil) + + input := chunk.NewChunkWithCapacity([]*types.FieldType{ft, ft}, 1024) + for i := 0; i < 1024; i++ { + d0 := new(types.MyDecimal) + d1 := new(types.MyDecimal) + v := fmt.Sprintf("%v", float64(rand.Intn(1000))+rand.Float64()) + c.Assert(d0.FromString([]byte(v)), IsNil) + v += "00" + c.Assert(d1.FromString([]byte(v)), IsNil) + input.Column(0).AppendMyDecimal(d0) + input.Column(1).AppendMyDecimal(d1) + c.Assert(input.Column(0).GetDecimal(i).GetDigitsFrac(), Not(Equals), input.Column(1).GetDecimal(i).GetDigitsFrac()) + } + result := chunk.NewColumn(ft, 1024) + c.Assert(inFunc.vecEvalInt(input, result), IsNil) + for i := 0; i < 1024; i++ { + c.Assert(result.GetInt64(0), Equals, int64(1)) + } +} diff --git a/expression/generator/other_vec.go b/expression/generator/other_vec.go index c5719a3b12..40c3cd0d22 100644 --- a/expression/generator/other_vec.go +++ b/expression/generator/other_vec.go @@ -97,7 +97,7 @@ var builtinInTmpl = template.Must(template.New("builtinInTmpl").Parse(` } {{- else if eq .Input.TypeName "Decimal" -}} compareResult = 1 - if arg0 == arg1 { + if arg0.Compare(&arg1) == 0 { compareResult = 0 } {{- else if eq .Input.TypeName "Time" -}}