From 5ee0847cabb3cf80f079e2235f6a218f02ca7aa4 Mon Sep 17 00:00:00 2001 From: Kenan Yao Date: Tue, 20 Aug 2019 23:15:09 +0800 Subject: [PATCH] expression: fix wrong result of `Not`/`IsTrue`/`IsFalse` functions (#10498) --- expression/builtin_op.go | 96 ++++++++++++++++++++++++++++++----- expression/builtin_op_test.go | 18 +++++++ expression/distsql_builtin.go | 8 ++- expression/expr_to_pb_test.go | 2 +- expression/util_test.go | 2 +- go.mod | 2 +- go.sum | 4 +- 7 files changed, 112 insertions(+), 20 deletions(-) diff --git a/expression/builtin_op.go b/expression/builtin_op.go index 3b27a09aa7..c7af218cdb 100644 --- a/expression/builtin_op.go +++ b/expression/builtin_op.go @@ -17,6 +17,7 @@ import ( "fmt" "math" + "github.com/pingcap/errors" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/opcode" "github.com/pingcap/tidb/sessionctx" @@ -52,7 +53,9 @@ var ( _ builtinFunc = &builtinRealIsNullSig{} _ builtinFunc = &builtinStringIsNullSig{} _ builtinFunc = &builtinTimeIsNullSig{} - _ builtinFunc = &builtinUnaryNotSig{} + _ builtinFunc = &builtinUnaryNotRealSig{} + _ builtinFunc = &builtinUnaryNotDecimalSig{} + _ builtinFunc = &builtinUnaryNotIntSig{} ) type logicAndFunctionClass struct { @@ -383,8 +386,10 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args [] } argTp := args[0].GetType().EvalType() - if argTp != types.ETReal && argTp != types.ETDecimal { + if argTp == types.ETTimestamp || argTp == types.ETDatetime || argTp == types.ETDuration { argTp = types.ETInt + } else if argTp == types.ETJson || argTp == types.ETString { + argTp = types.ETReal } bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, argTp) @@ -403,6 +408,8 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args [] case types.ETInt: sig = &builtinIntIsTrueSig{bf} sig.setPbCode(tipb.ScalarFuncSig_IntIsTrue) + default: + return nil, errors.Errorf("unexpected types.EvalType %v", argTp) } case opcode.IsFalsity: switch argTp { @@ -415,6 +422,8 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args [] case types.ETInt: sig = &builtinIntIsFalseSig{bf} sig.setPbCode(tipb.ScalarFuncSig_IntIsFalse) + default: + return nil, errors.Errorf("unexpected types.EvalType %v", argTp) } } return sig, nil @@ -588,33 +597,94 @@ func (c *unaryNotFunctionClass) getFunction(ctx sessionctx.Context, args []Expre return nil, err } - bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt) + argTp := args[0].GetType().EvalType() + if argTp == types.ETTimestamp || argTp == types.ETDatetime || argTp == types.ETDuration { + argTp = types.ETInt + } else if argTp == types.ETJson || argTp == types.ETString { + argTp = types.ETReal + } + + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, argTp) bf.tp.Flen = 1 - sig := &builtinUnaryNotSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_UnaryNot) + var sig builtinFunc + switch argTp { + case types.ETReal: + sig = &builtinUnaryNotRealSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_UnaryNotReal) + case types.ETDecimal: + sig = &builtinUnaryNotDecimalSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_UnaryNotDecimal) + case types.ETInt: + sig = &builtinUnaryNotIntSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_UnaryNotInt) + default: + return nil, errors.Errorf("unexpected types.EvalType %v", argTp) + } return sig, nil } -type builtinUnaryNotSig struct { +type builtinUnaryNotRealSig struct { baseBuiltinFunc } -func (b *builtinUnaryNotSig) Clone() builtinFunc { - newSig := &builtinUnaryNotSig{} +func (b *builtinUnaryNotRealSig) Clone() builtinFunc { + newSig := &builtinUnaryNotRealSig{} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } -func (b *builtinUnaryNotSig) evalInt(row chunk.Row) (int64, bool, error) { +func (b *builtinUnaryNotRealSig) evalInt(row chunk.Row) (int64, bool, error) { + arg, isNull, err := b.args[0].EvalReal(b.ctx, row) + if isNull || err != nil { + return 0, true, err + } + if arg == 0 { + return 1, false, nil + } + return 0, false, nil +} + +type builtinUnaryNotDecimalSig struct { + baseBuiltinFunc +} + +func (b *builtinUnaryNotDecimalSig) Clone() builtinFunc { + newSig := &builtinUnaryNotDecimalSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinUnaryNotDecimalSig) evalInt(row chunk.Row) (int64, bool, error) { + arg, isNull, err := b.args[0].EvalDecimal(b.ctx, row) + if isNull || err != nil { + return 0, true, err + } + if arg.IsZero() { + return 1, false, nil + } + return 0, false, nil +} + +type builtinUnaryNotIntSig struct { + baseBuiltinFunc +} + +func (b *builtinUnaryNotIntSig) Clone() builtinFunc { + newSig := &builtinUnaryNotIntSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinUnaryNotIntSig) evalInt(row chunk.Row) (int64, bool, error) { arg, isNull, err := b.args[0].EvalInt(b.ctx, row) if isNull || err != nil { - return 0, isNull, err + return 0, true, err } - if arg != 0 { - return 0, false, nil + if arg == 0 { + return 1, false, nil } - return 1, false, nil + return 0, false, nil } type unaryMinusFunctionClass struct { diff --git a/expression/builtin_op_test.go b/expression/builtin_op_test.go index b2f700cb7c..99fb13d76b 100644 --- a/expression/builtin_op_test.go +++ b/expression/builtin_op_test.go @@ -441,6 +441,9 @@ func (s *testEvaluatorSuite) TestUnaryNot(c *C) { {[]interface{}{123}, 0, false, false}, {[]interface{}{-123}, 0, false, false}, {[]interface{}{"123"}, 0, false, false}, + {[]interface{}{float64(0.3)}, 0, false, false}, + {[]interface{}{"0.3"}, 0, false, false}, + {[]interface{}{types.NewDecFromFloatForTest(0.3)}, 0, false, false}, {[]interface{}{nil}, 0, true, false}, {[]interface{}{errors.New("must error")}, 0, false, true}, @@ -514,6 +517,21 @@ func (s *testEvaluatorSuite) TestIsTrueOrFalse(c *C) { isTrue: 0, isFalse: 1, }, + { + args: []interface{}{"0.3"}, + isTrue: 1, + isFalse: 0, + }, + { + args: []interface{}{float64(0.3)}, + isTrue: 1, + isFalse: 0, + }, + { + args: []interface{}{types.NewDecFromFloatForTest(0.3)}, + isTrue: 1, + isFalse: 0, + }, { args: []interface{}{nil}, isTrue: 0, diff --git a/expression/distsql_builtin.go b/expression/distsql_builtin.go index 42045b1dde..4c4fdc72bf 100644 --- a/expression/distsql_builtin.go +++ b/expression/distsql_builtin.go @@ -320,8 +320,12 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti case tipb.ScalarFuncSig_BitNegSig: f = &builtinBitNegSig{base} - case tipb.ScalarFuncSig_UnaryNot: - f = &builtinUnaryNotSig{base} + case tipb.ScalarFuncSig_UnaryNotReal: + f = &builtinUnaryNotRealSig{base} + case tipb.ScalarFuncSig_UnaryNotDecimal: + f = &builtinUnaryNotDecimalSig{base} + case tipb.ScalarFuncSig_UnaryNotInt: + f = &builtinUnaryNotIntSig{base} case tipb.ScalarFuncSig_UnaryMinusInt: f = &builtinUnaryMinusIntSig{base} case tipb.ScalarFuncSig_UnaryMinusReal: diff --git a/expression/expr_to_pb_test.go b/expression/expr_to_pb_test.go index 252608bdf1..2331488f60 100644 --- a/expression/expr_to_pb_test.go +++ b/expression/expr_to_pb_test.go @@ -432,7 +432,7 @@ func (s *testEvaluatorSuite) TestPushDownSwitcher(c *C) { }{ {ast.And, tipb.ScalarFuncSig_BitAndSig, true}, {ast.Or, tipb.ScalarFuncSig_BitOrSig, false}, - {ast.UnaryNot, tipb.ScalarFuncSig_UnaryNot, true}, + {ast.UnaryNot, tipb.ScalarFuncSig_UnaryNotInt, true}, } var enabled []string for i, funcName := range cases { diff --git a/expression/util_test.go b/expression/util_test.go index 1311ca13a0..d46fc7500c 100644 --- a/expression/util_test.go +++ b/expression/util_test.go @@ -120,7 +120,7 @@ func (s *testUtilSuite) TestClone(c *check.C) { &builtinNameConstJSONSig{}, &builtinLogicAndSig{}, &builtinLogicOrSig{}, &builtinLogicXorSig{}, &builtinRealIsTrueSig{}, &builtinDecimalIsTrueSig{}, &builtinIntIsTrueSig{}, &builtinRealIsFalseSig{}, &builtinDecimalIsFalseSig{}, &builtinIntIsFalseSig{}, &builtinUnaryMinusIntSig{}, &builtinDecimalIsNullSig{}, &builtinDurationIsNullSig{}, &builtinIntIsNullSig{}, &builtinRealIsNullSig{}, - &builtinStringIsNullSig{}, &builtinTimeIsNullSig{}, &builtinUnaryNotSig{}, &builtinSleepSig{}, &builtinInIntSig{}, + &builtinStringIsNullSig{}, &builtinTimeIsNullSig{}, &builtinUnaryNotRealSig{}, &builtinUnaryNotDecimalSig{}, &builtinUnaryNotIntSig{}, &builtinSleepSig{}, &builtinInIntSig{}, &builtinInStringSig{}, &builtinInDecimalSig{}, &builtinInRealSig{}, &builtinInTimeSig{}, &builtinInDurationSig{}, &builtinInJSONSig{}, &builtinRowSig{}, &builtinSetVarSig{}, &builtinGetVarSig{}, &builtinLockSig{}, &builtinReleaseLockSig{}, &builtinValuesIntSig{}, &builtinValuesRealSig{}, &builtinValuesDecimalSig{}, &builtinValuesStringSig{}, diff --git a/go.mod b/go.mod index 3d84eadc98..c52647a431 100644 --- a/go.mod +++ b/go.mod @@ -44,7 +44,7 @@ require ( github.com/pingcap/parser v0.0.0-20190819021501-c5d6ce829420 github.com/pingcap/pd v0.0.0-20190712044914-75a1f9f3062b github.com/pingcap/tidb-tools v2.1.3-0.20190321065848-1e8b48f5c168+incompatible - github.com/pingcap/tipb v0.0.0-20190428032612-535e1abaa330 + github.com/pingcap/tipb v0.0.0-20190806070524-16909e03435e github.com/prometheus/client_golang v0.9.0 github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 github.com/prometheus/common v0.0.0-20181020173914-7e9e6cabbd39 // indirect diff --git a/go.sum b/go.sum index b1f13edde1..5725706c12 100644 --- a/go.sum +++ b/go.sum @@ -171,8 +171,8 @@ github.com/pingcap/pd v0.0.0-20190712044914-75a1f9f3062b h1:oS9PftxQqgcRouKhhdaB github.com/pingcap/pd v0.0.0-20190712044914-75a1f9f3062b/go.mod h1:3DlDlFT7EF64A1bmb/tulZb6wbPSagm5G4p1AlhaEDs= github.com/pingcap/tidb-tools v2.1.3-0.20190321065848-1e8b48f5c168+incompatible h1:MkWCxgZpJBgY2f4HtwWMMFzSBb3+JPzeJgF3VrXE/bU= github.com/pingcap/tidb-tools v2.1.3-0.20190321065848-1e8b48f5c168+incompatible/go.mod h1:XGdcy9+yqlDSEMTpOXnwf3hiTeqrV6MN/u1se9N8yIM= -github.com/pingcap/tipb v0.0.0-20190428032612-535e1abaa330 h1:rRMLMjIMFulCX9sGKZ1hoov/iROMsKyC8Snc02nSukw= -github.com/pingcap/tipb v0.0.0-20190428032612-535e1abaa330/go.mod h1:RtkHW8WbcNxj8lsbzjaILci01CtYnYbIkQhjyZWrWVI= +github.com/pingcap/tipb v0.0.0-20190806070524-16909e03435e h1:H7meq8QPmWGImOkHTQYAWw82zwIqndJaCDPVUknOHbM= +github.com/pingcap/tipb v0.0.0-20190806070524-16909e03435e/go.mod h1:RtkHW8WbcNxj8lsbzjaILci01CtYnYbIkQhjyZWrWVI= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=