From eaa75a81d56471b602efa3f4ee19bc5345f2cdff Mon Sep 17 00:00:00 2001
From: xzhangxian1008
Date: Thu, 29 Aug 2024 22:43:27 +0800
Subject: [PATCH] expression: fix incorrect result when the result of casting const value to duration type is null (#55454)

close pingcap/tidb#51842
---
 pkg/expression/builtin_compare.go             | 58 ++++++++++++++++++-
 pkg/expression/integration_test/BUILD.bazel   |  2 +-
 .../integration_test/integration_test.go      | 24 ++++++++
 3 files changed, 80 insertions(+), 4 deletions(-)

diff --git a/pkg/expression/builtin_compare.go b/pkg/expression/builtin_compare.go
index 2ab2bcd2ae..44f9a1db5b 100644
--- a/pkg/expression/builtin_compare.go
+++ b/pkg/expression/builtin_compare.go
@@ -1603,6 +1603,48 @@ func matchRefineRule3Pattern(conEvalType types.EvalType, exprType *types.FieldTy
 		(conEvalType == types.ETReal || conEvalType == types.ETDecimal || conEvalType == types.ETInt)
 }
 
+// handleDurationTypeComparison refines `durationExpr <=> constant` when the constant cannot be cast to a duration type.
+// It returns the replacement arguments `0` and `1`, turning the comparison into the always-false `0 <=> 1`, or nil when
+// no rewrite applies. It does not check the operator itself; the caller is expected to invoke it only for `<=>`.
+//
+// When a duration type column is compared with a non-duration type constant, MySQL casts the duration column to the
+// non-duration type, which prevents the use of indexes on that column, so TiDB casts the constant to duration instead.
+// A constant that cannot be cast to a duration type becomes NULL, so for a NULL column value TiDB would evaluate
+// `NULL <=> NULL` (true) where MySQL evaluates `NULL <=> non-NULL constant` (false). The rewrite restores the MySQL
+// result for that case.
+//
+// A constant that is NULL to begin with is left untouched: `durationExpr <=> NULL` must stay true for NULL values, and
+// rewriting it to `0 <=> 1` would wrongly drop those rows. Constants carrying a DeferredExpr are skipped as well.
+// If evaluating the cast fails, the error is returned together with nil arguments.
+func (c *compareFunctionClass) handleDurationTypeComparison(ctx BuildContext, arg0, arg1 Expression) (_ []Expression, err error) {
+	// castToDurationIsNull reports whether a usable, non-NULL constant becomes NULL once cast to duration.
+	castToDurationIsNull := func(con *Constant) (bool, error) {
+		if con.DeferredExpr != nil || con.Value.IsNull() {
+			return false, nil
+		}
+		_, isNull, err := WrapWithCastAsDuration(ctx, con).EvalDuration(ctx.GetEvalCtx(), chunk.Row{})
+		return isNull, err
+	}
+
+	arg0Const, arg0IsCon := arg0.(*Constant)
+	arg1Const, arg1IsCon := arg1.(*Constant)
+
+	var isNull bool
+	switch {
+	case arg0IsCon && !arg1IsCon && arg1.GetType(ctx.GetEvalCtx()).GetType() == mysql.TypeDuration:
+		isNull, err = castToDurationIsNull(arg0Const)
+	case arg1IsCon && !arg0IsCon && arg0.GetType(ctx.GetEvalCtx()).GetType() == mysql.TypeDuration:
+		isNull, err = castToDurationIsNull(arg1Const)
+	}
+	if err != nil {
+		return nil, err
+	}
+	if isNull {
+		return []Expression{NewZero(), NewOne()}, nil
+	}
+	return nil, nil
+}
+
 // Since the argument refining of cmp functions can bring some risks to the plan-cache, the optimizer
 // needs to decide to whether to skip the refining or skip plan-cache for safety.
 // For example, `unsigned_int_col > ?(-1)` can be refined to `True`, but the validation of this result
@@ -1654,9 +1696,12 @@ func allowCmpArgsRefining4PlanCache(ctx BuildContext, args []Expression) (allowR
 }
 
 // refineArgs will rewrite the arguments if the compare expression is
-// 1. `int column <cmp> non-int constant` or `non-int constant <cmp> int column`. E.g., `a < 1.1` will be rewritten to `a < 2`.
-// 2. It also handles comparing year type with int constant if the int constant falls into a sensible year representation.
-// 3. It also handles comparing datetime/timestamp column with numeric constant, try to cast numeric constant as timestamp type, do nothing if failed.
+//  1. `int column <cmp> non-int constant` or `non-int constant <cmp> int column`. E.g., `a < 1.1` will be rewritten to `a < 2`.
+//  2. It also handles comparing year type with int constant if the int constant falls into a sensible year representation.
+//  3. It also handles comparing datetime/timestamp column with numeric constant, try to cast numeric constant as timestamp type, do nothing if failed.
+//  4. For `<=>` only: when a duration type column is compared with a non-duration type constant that cannot be cast to a
+//     duration type, the expression is rewritten as `0 <=> 1` to stay compatible with MySQL's behavior.
+//
 // This refining operation depends on the values of these args, but these values can change when using plan-cache.
 // So we have to skip this operation or mark the plan as over-optimized when using plan-cache.
 func (c *compareFunctionClass) refineArgs(ctx BuildContext, args []Expression) ([]Expression, error) {
@@ -1677,6 +1722,13 @@ func (c *compareFunctionClass) refineArgs(ctx BuildContext, args []Expression) (
 		return nil, err
 	}
 
+	// Handle comparison between a duration type column and a non-duration type constant.
+	if c.op == opcode.NullEQ {
+		if result, err := c.handleDurationTypeComparison(ctx, args[0], args[1]); err != nil || result != nil {
+			return result, err
+		}
+	}
+
 	if arg0IsCon && !arg1IsCon && matchRefineRule3Pattern(arg0EvalType, arg1Type) {
 		return c.refineNumericConstantCmpDatetime(ctx, args, arg0, 0), nil
 	}
diff --git a/pkg/expression/integration_test/BUILD.bazel b/pkg/expression/integration_test/BUILD.bazel
index 97fbba2ef2..8006de9b71 100644
--- a/pkg/expression/integration_test/BUILD.bazel
+++ b/pkg/expression/integration_test/BUILD.bazel
@@ -8,7 +8,7 @@ go_test(
         "main_test.go",
     ],
     flaky = True,
-    shard_count = 41,
+    shard_count = 42,
     deps = [
         "//pkg/config",
         "//pkg/domain",
diff --git a/pkg/expression/integration_test/integration_test.go b/pkg/expression/integration_test/integration_test.go
index 603a0ced97..c70449a4c0 100644
--- a/pkg/expression/integration_test/integration_test.go
+++ b/pkg/expression/integration_test/integration_test.go
@@ -3547,3 +3547,27 @@ func TestIssue43527(t *testing.T) {
 		"SELECT @total := @total + d FROM (SELECT d FROM test) AS temp, (SELECT @total := b FROM test) AS T1 where @total >= 100",
 	).Check(testkit.Rows("200", "300", "400", "500"))
 }
+
+func TestIssue51842(t *testing.T) {
+	store := testkit.CreateMockStore(t)
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("use test")
+	tk.MustExec("drop view if exists v0;")
+	tk.MustExec("drop table if exists t0;")
+	tk.MustExec("CREATE TABLE t0(c0 DOUBLE);")
+	tk.MustExec("REPLACE INTO t0(c0) VALUES (0.40194983109852933);")
+	tk.MustExec("CREATE VIEW v0(c0) AS SELECT CAST(')' AS TIME) FROM t0 WHERE '0.030417148673465677';")
+	const queryPrefix = "SELECT f1 FROM (SELECT NULLIF(v0.c0, 1371581446) AS f1 FROM v0, t0) AS t WHERE f1 <=> "
+	// f1 is expected to be NULL, so no row may match an int, real, decimal, string, time, year or datetime constant.
+	for _, constant := range []string{
+		"1292367147",
+		"cast(123988.42132 as real)",
+		"cast(123988.42132 as decimal)",
+		"cast('fdasge' as char)",
+		"cast('10:10:10' as time)",
+		"cast(2024 as year)",
+		"cast('2024-1-1 10:10:10' as datetime)",
+	} {
+		tk.MustQuery(queryPrefix + constant + ";").Check(testkit.Rows())
+	}
+}