diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java index bd234204b2..6cff3553b4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java @@ -29,7 +29,6 @@ import org.apache.doris.nereids.rules.expression.rules.SimplifyArithmeticRule; import org.apache.doris.nereids.rules.expression.rules.SimplifyCastRule; import org.apache.doris.nereids.rules.expression.rules.SimplifyNotExprRule; import org.apache.doris.nereids.rules.expression.rules.SupportJavaDateFormatter; -import org.apache.doris.nereids.trees.expressions.Expression; import com.google.common.collect.ImmutableList; @@ -60,10 +59,5 @@ public class ExpressionNormalization extends ExpressionRewrite { public ExpressionNormalization() { super(new ExpressionRuleExecutor(NORMALIZE_REWRITE_RULES)); } - - @Override - public Expression rewrite(Expression expression, ExpressionRewriteContext context) { - return super.rewrite(expression, context); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToIf.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToIf.java new file mode 100644 index 0000000000..6372338406 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToIf.java @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.WhenClause; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; + +/** + * Rewrite rule to convert CASE WHEN to IF. + * For example: + * CASE WHEN a > 1 THEN 1 ELSE 0 END -> IF(a > 1, 1, 0) + */ +public class CaseWhenToIf extends AbstractExpressionRewriteRule { + + public static CaseWhenToIf INSTANCE = new CaseWhenToIf(); + + @Override + public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) { + Expression expr = caseWhen; + if (caseWhen.getWhenClauses().size() == 1) { + WhenClause whenClause = caseWhen.getWhenClauses().get(0); + Expression operand = whenClause.getOperand(); + Expression result = whenClause.getResult(); + expr = new If(operand, result, caseWhen.getDefaultValue().orElse(new NullLiteral(result.getDataType()))); + } + // TODO: traverse expr in CASE WHEN / If. + return expr; + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java index 5b3cdd7dd1..1010e7df27 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java @@ -59,10 +59,10 @@ import java.math.BigDecimal; /** * all expr rewrite rule test case. */ -public class ExpressionRewriteTest extends ExpressionRewriteTestHelper { +class ExpressionRewriteTest extends ExpressionRewriteTestHelper { @Test - public void testNotRewrite() { + void testNotRewrite() { executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyNotExprRule.INSTANCE)); assertRewrite("not x", "not x"); @@ -87,7 +87,7 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper { } @Test - public void testNormalizeExpressionRewrite() { + void testNormalizeExpressionRewrite() { executor = new ExpressionRuleExecutor(ImmutableList.of(NormalizeBinaryPredicatesRule.INSTANCE)); assertRewrite("1 = 1", "1 = 1"); @@ -99,7 +99,7 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper { } @Test - public void testDistinctPredicatesRewrite() { + void testDistinctPredicatesRewrite() { executor = new ExpressionRuleExecutor(ImmutableList.of(DistinctPredicatesRule.INSTANCE)); assertRewrite("a = 1", "a = 1"); @@ -111,7 +111,7 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper { } @Test - public void testExtractCommonFactorRewrite() { + void testExtractCommonFactorRewrite() { executor = new ExpressionRuleExecutor(ImmutableList.of(ExtractCommonFactorRule.INSTANCE)); assertRewrite("a", "a"); @@ -164,7 +164,7 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper { } @Test - public void testInPredicateToEqualToRule() { + void testInPredicateToEqualToRule() { executor = new ExpressionRuleExecutor(ImmutableList.of(InPredicateToEqualToRule.INSTANCE)); assertRewrite("a in (1)", "a = 1"); @@ -180,14 +180,14 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper { } @Test - public void testInPredicateDedup() { + void testInPredicateDedup() { executor = new ExpressionRuleExecutor(ImmutableList.of(InPredicateDedup.INSTANCE)); assertRewrite("a in (1, 2, 1, 2)", "a in (1, 2)"); } @Test - public void testSimplifyCastRule() { + void testSimplifyCastRule() { executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyCastRule.INSTANCE)); // deduplicate @@ -219,7 +219,7 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper { } @Test - public void testSimplifyComparisonPredicateRule() { + void testSimplifyComparisonPredicateRule() { executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyCastRule.INSTANCE, SimplifyComparisonPredicate.INSTANCE)); Expression dtv2 = new DateTimeV2Literal(1, 1, 1, 1, 1, 1, 0); @@ -271,7 +271,7 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper { } @Test - public void testSimplifyDecimalV3Comparison() { + void testSimplifyDecimalV3Comparison() { executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyDecimalV3Comparison.INSTANCE)); // do rewrite