From 036e17dcb0f73193e8819a56b2219491cb73607d Mon Sep 17 00:00:00 2001 From: starocean999 <40539150+starocean999@users.noreply.github.com> Date: Mon, 29 Jan 2024 15:26:37 +0800 Subject: [PATCH] [test](nereids)add fe ut for SimplifyArithmeticComparisonRule (#27644) --- .../SimplifyArithmeticComparisonRule.java | 20 ++++++----- .../SimplifyArithmeticComparisonRuleTest.java | 34 +++++++++++++++++++ 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java index eda95ba32b..7606d08247 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java @@ -82,20 +82,23 @@ public class SimplifyArithmeticComparisonRule extends AbstractExpressionRewriteR @Override public Expression visitComparisonPredicate(ComparisonPredicate comparison, ExpressionRewriteContext context) { - ComparisonPredicate newComparison = comparison; if (couldRearrange(comparison)) { - newComparison = normalize(comparison); + ComparisonPredicate newComparison = normalize(comparison); if (newComparison == null) { return comparison; } try { - List children = tryRearrangeChildren(newComparison.left(), newComparison.right()); - newComparison = (ComparisonPredicate) newComparison.withChildren(children); + List children = + tryRearrangeChildren(newComparison.left(), newComparison.right(), context); + newComparison = (ComparisonPredicate) visitComparisonPredicate( + (ComparisonPredicate) newComparison.withChildren(children), context); } catch (Exception e) { return comparison; } + return TypeCoercionUtils.processComparisonPredicate(newComparison); + } else { + return comparison; } - return TypeCoercionUtils.processComparisonPredicate(newComparison); } private boolean couldRearrange(ComparisonPredicate cmp) { @@ -104,11 +107,12 @@ public class SimplifyArithmeticComparisonRule extends AbstractExpressionRewriteR && cmp.left().children().stream().anyMatch(Expression::isConstant); } - private List tryRearrangeChildren(Expression left, Expression right) throws Exception { - if (!left.child(1).isLiteral()) { + private List tryRearrangeChildren(Expression left, Expression right, + ExpressionRewriteContext context) throws Exception { + if (!left.child(1).isConstant()) { throw new RuntimeException(String.format("Expected literal when arranging children for Expr %s", left)); } - Literal leftLiteral = (Literal) left.child(1); + Literal leftLiteral = (Literal) FoldConstantRule.INSTANCE.rewrite(left.child(1), context); Expression leftExpr = left.child(0); Class oppositeOperator = rearrangementMap.get(left.getClass()); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRuleTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRuleTest.java index 5a438ded65..fc31daaa94 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRuleTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRuleTest.java @@ -41,6 +41,40 @@ class SimplifyArithmeticComparisonRuleTest extends ExpressionRewriteTestHelper { assertRewriteAfterSimplify("a + 1 > 1", "a > cast((1 - 1) as INT)", nameToSlot); assertRewriteAfterSimplify("a - 1 > 1", "a > cast((1 + 1) as INT)", nameToSlot); assertRewriteAfterSimplify("a / -2 > 1", "cast((1 * -2) as INT) > a", nameToSlot); + + // test integer type + assertRewriteAfterSimplify("1 + a > 2", "a > cast((2 - 1) as INT)", nameToSlot); + assertRewriteAfterSimplify("-1 + a > 2", "a > cast((2 - (-1)) as INT)", nameToSlot); + assertRewriteAfterSimplify("1 - a > 2", "a < cast((1 - 2) as INT)", nameToSlot); + assertRewriteAfterSimplify("-1 - a > 2", "a < cast(((-1) - 2) as INT)", nameToSlot); + assertRewriteAfterSimplify("2 * a > 1", "((2 * a) > 1)", nameToSlot); + assertRewriteAfterSimplify("-2 * a > 1", "((-2 * a) > 1)", nameToSlot); + assertRewriteAfterSimplify("2 / a > 1", "((2 / a) > 1)", nameToSlot); + assertRewriteAfterSimplify("-2 / a > 1", "((-2 / a) > 1)", nameToSlot); + assertRewriteAfterSimplify("a * 2 > 1", "((a * 2) > 1)", nameToSlot); + assertRewriteAfterSimplify("a * (-2) > 1", "((a * (-2)) > 1)", nameToSlot); + assertRewriteAfterSimplify("a / 2 > 1", "(a > cast((1 * 2) as INT))", nameToSlot); + + // test decimal type + assertRewriteAfterSimplify("1.1 + a > 2.22", "(cast(a as DECIMALV3(12, 2)) > cast((2.22 - 1.1) as DECIMALV3(12, 2)))", nameToSlot); + assertRewriteAfterSimplify("-1.1 + a > 2.22", "(cast(a as DECIMALV3(12, 2)) > cast((2.22 - (-1.1)) as DECIMALV3(12, 2)))", nameToSlot); + assertRewriteAfterSimplify("1.1 - a > 2.22", "(cast(a as DECIMALV3(11, 1)) < cast((1.1 - 2.22) as DECIMALV3(11, 1)))", nameToSlot); + assertRewriteAfterSimplify("-1.1 - a > 2.22", "(cast(a as DECIMALV3(11, 1)) < cast((-1.1 - 2.22) as DECIMALV3(11, 1)))", nameToSlot); + assertRewriteAfterSimplify("2.22 * a > 1.1", "((2.22 * a) > 1.1)", nameToSlot); + assertRewriteAfterSimplify("-2.22 * a > 1.1", "-2.22 * a > 1.1", nameToSlot); + assertRewriteAfterSimplify("2.22 / a > 1.1", "((2.22 / a) > 1.1)", nameToSlot); + assertRewriteAfterSimplify("-2.22 / a > 1.1", "((-2.22 / a) > 1.1)", nameToSlot); + assertRewriteAfterSimplify("a * 2.22 > 1.1", "a * 2.22 > 1.1", nameToSlot); + assertRewriteAfterSimplify("a * (-2.22) > 1.1", "a * (-2.22) > 1.1", nameToSlot); + assertRewriteAfterSimplify("a / 2.22 > 1.1", "(cast(a as DECIMALV3(13, 3)) > cast((1.1 * 2.22) as DECIMALV3(13, 3)))", nameToSlot); + assertRewriteAfterSimplify("a / (-2.22) > 1.1", "(cast((1.1 * -2.22) as DECIMALV3(13, 3)) > cast(a as DECIMALV3(13, 3)))", nameToSlot); + + // test (1 + a) can be processed + assertRewriteAfterSimplify("2 - (1 + a) > 3", "(a < ((2 - 3) - 1))", nameToSlot); + assertRewriteAfterSimplify("(1 - a) / 2 > 3", "(a < (1 - 6))", nameToSlot); + assertRewriteAfterSimplify("1 - a / 2 > 3", "(a < ((1 - 3) * 2))", nameToSlot); + assertRewriteAfterSimplify("(1 - (a + 4)) / 2 > 3", "(cast(a as BIGINT) < ((1 - 6) - 4))", nameToSlot); + assertRewriteAfterSimplify("2 * (1 + a) > 1", "(2 * (1 + a)) > 1", nameToSlot); } private void assertRewriteAfterSimplify(String expr, String expected, Map slotNameToSlot) {