[feature](Nereids): convert CaseWhen to If (#23040)
Add a rule to optimize CASE WHEN expression. Rewrite rule to convert CASE WHEN to IF. For example: CASE WHEN a > 1 THEN 1 ELSE 0 END -> IF(a > 1, 1, 0)
This commit is contained in:
@ -29,7 +29,6 @@ import org.apache.doris.nereids.rules.expression.rules.SimplifyArithmeticRule;
|
||||
import org.apache.doris.nereids.rules.expression.rules.SimplifyCastRule;
|
||||
import org.apache.doris.nereids.rules.expression.rules.SimplifyNotExprRule;
|
||||
import org.apache.doris.nereids.rules.expression.rules.SupportJavaDateFormatter;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
@ -60,10 +59,5 @@ public class ExpressionNormalization extends ExpressionRewrite {
|
||||
public ExpressionNormalization() {
|
||||
super(new ExpressionRuleExecutor(NORMALIZE_REWRITE_RULES));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression rewrite(Expression expression, ExpressionRewriteContext context) {
|
||||
return super.rewrite(expression, context);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -0,0 +1,49 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.trees.expressions.CaseWhen;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.WhenClause;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
|
||||
|
||||
/**
|
||||
* Rewrite rule to convert CASE WHEN to IF.
|
||||
* For example:
|
||||
* CASE WHEN a > 1 THEN 1 ELSE 0 END -> IF(a > 1, 1, 0)
|
||||
*/
|
||||
public class CaseWhenToIf extends AbstractExpressionRewriteRule {
|
||||
|
||||
public static CaseWhenToIf INSTANCE = new CaseWhenToIf();
|
||||
|
||||
@Override
|
||||
public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) {
|
||||
Expression expr = caseWhen;
|
||||
if (caseWhen.getWhenClauses().size() == 1) {
|
||||
WhenClause whenClause = caseWhen.getWhenClauses().get(0);
|
||||
Expression operand = whenClause.getOperand();
|
||||
Expression result = whenClause.getResult();
|
||||
expr = new If(operand, result, caseWhen.getDefaultValue().orElse(new NullLiteral(result.getDataType())));
|
||||
}
|
||||
// TODO: traverse expr in CASE WHEN / If.
|
||||
return expr;
|
||||
}
|
||||
}
|
||||
@ -59,10 +59,10 @@ import java.math.BigDecimal;
|
||||
/**
|
||||
* all expr rewrite rule test case.
|
||||
*/
|
||||
public class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
|
||||
class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
public void testNotRewrite() {
|
||||
void testNotRewrite() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyNotExprRule.INSTANCE));
|
||||
|
||||
assertRewrite("not x", "not x");
|
||||
@ -87,7 +87,7 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNormalizeExpressionRewrite() {
|
||||
void testNormalizeExpressionRewrite() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(NormalizeBinaryPredicatesRule.INSTANCE));
|
||||
|
||||
assertRewrite("1 = 1", "1 = 1");
|
||||
@ -99,7 +99,7 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDistinctPredicatesRewrite() {
|
||||
void testDistinctPredicatesRewrite() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(DistinctPredicatesRule.INSTANCE));
|
||||
|
||||
assertRewrite("a = 1", "a = 1");
|
||||
@ -111,7 +111,7 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testExtractCommonFactorRewrite() {
|
||||
void testExtractCommonFactorRewrite() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(ExtractCommonFactorRule.INSTANCE));
|
||||
|
||||
assertRewrite("a", "a");
|
||||
@ -164,7 +164,7 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInPredicateToEqualToRule() {
|
||||
void testInPredicateToEqualToRule() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(InPredicateToEqualToRule.INSTANCE));
|
||||
|
||||
assertRewrite("a in (1)", "a = 1");
|
||||
@ -180,14 +180,14 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInPredicateDedup() {
|
||||
void testInPredicateDedup() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(InPredicateDedup.INSTANCE));
|
||||
|
||||
assertRewrite("a in (1, 2, 1, 2)", "a in (1, 2)");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSimplifyCastRule() {
|
||||
void testSimplifyCastRule() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyCastRule.INSTANCE));
|
||||
|
||||
// deduplicate
|
||||
@ -219,7 +219,7 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSimplifyComparisonPredicateRule() {
|
||||
void testSimplifyComparisonPredicateRule() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyCastRule.INSTANCE, SimplifyComparisonPredicate.INSTANCE));
|
||||
|
||||
Expression dtv2 = new DateTimeV2Literal(1, 1, 1, 1, 1, 1, 0);
|
||||
@ -271,7 +271,7 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSimplifyDecimalV3Comparison() {
|
||||
void testSimplifyDecimalV3Comparison() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyDecimalV3Comparison.INSTANCE));
|
||||
|
||||
// do rewrite
|
||||
|
||||
Reference in New Issue
Block a user