diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java index ffd3c29a18..60d4384207 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java @@ -30,6 +30,7 @@ import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.BooleanType; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.DateV2Type; +import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.DoubleType; import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.types.StringType; @@ -124,6 +125,8 @@ public abstract class ExpressionRewriteTestHelper { return BooleanType.INSTANCE; case 'C': return DateV2Type.INSTANCE; + case 'M': + return DecimalV3Type.SYSTEM_DEFAULT; default: return BigIntType.INSTANCE; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java index 87524c621d..6ecbdf8c72 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java @@ -52,6 +52,45 @@ class SimplifyArithmeticRuleTest extends ExpressionRewriteTestHelper { assertRewriteAfterTypeCoercion("IA * IB / 2 / IC * 2 * ID / 4", "(((cast((IA * IB) as DOUBLE) / cast(IC as DOUBLE)) * cast(ID as DOUBLE)) / 4.0)"); } + @Test + void testSimplifyArithmeticRuleOnly() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + SimplifyArithmeticRule.INSTANCE + )); + + // add and subtract + assertRewriteAfterTypeCoercion("-IA - ((1 + IB) - (3 - IC))", "(((((0 - 1) + 3) - IA) - IB) - IC)"); + assertRewriteAfterTypeCoercion("-2 - IA - ((1 - IB) - (3 + IC))", "(((((-2 - 1) + 3) - IA) + IB) + IC)"); + assertRewriteAfterTypeCoercion("-IA - 2 - ((IB + 1) - (3 - (IC - 4)))", "(((((((0 - 2) - 1) + 3) + 4) - IA) - IB) - IC)"); + assertRewriteAfterTypeCoercion("-IA - 2 - ((IB - 1) - (3 - (IC + 4)))", "(((((((0 - 2) + 1) + 3) - 4) - IA) - IB) - IC)"); + assertRewriteAfterTypeCoercion("-IA - 2 - ((-IB - 1) - (3 + (IC + 4)))", "((((((((0 - 2) - 0) + 1) + 3) + 4) - IA) + IB) + IC)"); + assertRewriteAfterTypeCoercion("IA - 2 - ((-IB - 1) - (3 + (IC + 4)))", "(((IA + IB) + IC) - ((((2 + 0) - 1) - 3) - 4))"); + + // multiply and divide + assertRewriteAfterTypeCoercion("2 / IA / ((1 / IB) / (3 * IC))", "((((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(IA as DOUBLE)) * cast(IB as DOUBLE)) * cast((3 * IC) as DOUBLE))"); + assertRewriteAfterTypeCoercion("IA / 2 / ((IB * 1) / (3 / (IC / 4)))", "(((cast(IA as DOUBLE) / cast((IB * 1) as DOUBLE)) / cast(IC as DOUBLE)) / ((cast(2 as DOUBLE) / cast(3 as DOUBLE)) / cast(4 as DOUBLE)))"); + assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 / (IC * 4)))", "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) / cast((IC * 4) as DOUBLE)) / ((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(3 as DOUBLE)))"); + assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 * (IC * 4)))", "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) * cast((3 * (IC * 4)) as DOUBLE)) / (cast(2 as DOUBLE) / cast(1 as DOUBLE)))"); + + // hybrid + // root is subtract + assertRewriteAfterTypeCoercion("-2 - IA * ((1 - IB) - (3 / IC))", "(cast(-2 as DOUBLE) - (cast(IA as DOUBLE) * (cast((1 - IB) as DOUBLE) - (cast(3 as DOUBLE) / cast(IC as DOUBLE)))))"); + assertRewriteAfterTypeCoercion("-IA - 2 - ((IB * 1) - (3 * (IC / 4)))", "((cast(((0 - IA) - 2) as DOUBLE) - cast((IB * 1) as DOUBLE)) + (cast(3 as DOUBLE) * (cast(IC as DOUBLE) / cast(4 as DOUBLE))))"); + // root is add + assertRewriteAfterTypeCoercion("-IA * 2 + ((IB - 1) / (3 - (IC + 4)))", "(cast(((0 - IA) * 2) as DOUBLE) + (cast((IB - 1) as DOUBLE) / cast((3 - (IC + 4)) as DOUBLE)))"); + assertRewriteAfterTypeCoercion("-IA + 2 + ((IB - 1) - (3 * (IC + 4)))", "(((((0 + 2) - 1) - IA) + IB) - (3 * (IC + 4)))"); + // root is multiply + assertRewriteAfterTypeCoercion("-IA / 2 * ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DOUBLE) * cast((((0 - IB) - 1) - (3 + (IC + 4))) as DOUBLE)) / cast(2 as DOUBLE))"); + assertRewriteAfterTypeCoercion("-IA / 2 * ((-IB - 1) * (3 / (IC + 4)))", "(((cast((0 - IA) as DOUBLE) * cast(((0 - IB) - 1) as DOUBLE)) / cast((IC + 4) as DOUBLE)) / (cast(2 as DOUBLE) / cast(3 as DOUBLE)))"); + // root is divide + assertRewriteAfterTypeCoercion("(-IA / 2) / ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DOUBLE) / cast((((0 - IB) - 1) - (3 + (IC + 4))) as DOUBLE)) / cast(2 as DOUBLE))"); + assertRewriteAfterTypeCoercion("(-IA / 2) / ((-IB - 1) / (3 + (IC * 4)))", "(((cast((0 - IA) as DOUBLE) / cast(((0 - IB) - 1) as DOUBLE)) * cast((3 + (IC * 4)) as DOUBLE)) / cast(2 as DOUBLE))"); + + // unsupported decimal + assertRewriteAfterTypeCoercion("-2 - MA - ((1 - IB) - (3 + IC))", "((cast(-2 as DECIMALV3(38, 0)) - MA) - cast(((1 - IB) - (3 + IC)) as DECIMALV3(38, 0)))"); + assertRewriteAfterTypeCoercion("-IA / 2.0 * ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DECIMALV3(25, 5)) / 2.0) * cast((((0 - IB) - 1) - (3 + (IC + 4))) as DECIMALV3(20, 0)))"); + } + @Test void testSimplifyArithmeticComparison() { executor = new ExpressionRuleExecutor(ImmutableList.of(