diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/InPredicateToEqualToRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/InPredicateToEqualToRule.java index e71764bb23..a729165077 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/InPredicateToEqualToRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/InPredicateToEqualToRule.java @@ -22,14 +22,26 @@ import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContex import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.base.Preconditions; import java.util.List; +import java.util.stream.Collectors; /** - * Rewrite InPredicate to an EqualTo Expression, if there exists exactly one element in InPredicate.options + * Paper: Quantifying TPC-H Choke Points and Their Optimizations + * - Figure 14: + *

+ * Rewrite InPredicate to disjunction, if there exists < 3 elements in InPredicate + * Examples: + * where A in (x, y) ==> where A = x or A = y * Examples: * where A in (x) ==> where A = x * where A not in (x) ==> where not A = x (After ExpressionTranslator, "not A = x" will be translated to "A != x") + *

+ * NOTICE: it's related with `SimplifyRange`. + * They are same processes, so must change synchronously. */ public class InPredicateToEqualToRule extends AbstractExpressionRewriteRule { @@ -37,11 +49,16 @@ public class InPredicateToEqualToRule extends AbstractExpressionRewriteRule { @Override public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) { - Expression left = inPredicate.getCompareExpr(); - List right = inPredicate.getOptions(); - if (right.size() != 1) { - return new InPredicate(left.accept(this, context), right); + Expression cmpExpr = inPredicate.getCompareExpr(); + List options = inPredicate.getOptions(); + Preconditions.checkArgument(options.size() > 0, "InPredicate.options should not be empty"); + if (options.size() > 2) { + return new InPredicate(cmpExpr.accept(this, context), options); } - return new EqualTo(left.accept(this, context), right.get(0).accept(this, context)); + Expression newCmpExpr = cmpExpr.accept(this, context); + List disjunction = options.stream() + .map(option -> new EqualTo(newCmpExpr, option.accept(this, context))) + .collect(Collectors.toList()); + return ExpressionUtils.or(disjunction); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/SimplifyRange.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/SimplifyRange.java index 010031486c..da2904fb04 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/SimplifyRange.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/SimplifyRange.java @@ -44,6 +44,7 @@ import com.google.common.collect.Sets; import java.util.Arrays; import java.util.Collection; +import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -59,7 +60,7 @@ import java.util.stream.Collectors; * a > 1 or a > 2 => a > 1 * a in (1,2,3) and a > 1 => a in (2,3) * a in (1,2,3) and a in (3,4,5) => a = 3 - * a in(1,2,3) and a in (4,5,6) => false + * a in (1,2,3) and a in (4,5,6) => false * The logic is as follows: * 1. for `And` expression. * 1. extract conjunctions then build `ValueDesc` for each conjunction @@ -402,8 +403,13 @@ public class SimplifyRange extends AbstractExpressionRewriteRule { @Override public Expression toExpression() { + // NOTICE: it's related with `InPredicateToEqualToRule` + // They are same processes, so must change synchronously. if (values.size() == 1) { return new EqualTo(reference, values.iterator().next()); + } else if (values.size() == 2) { + Iterator iterator = values.iterator(); + return new Or(new EqualTo(reference, iterator.next()), new EqualTo(reference, iterator.next())); } else { return new InPredicate(reference, Lists.newArrayList(values)); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java index 583250afd1..70140e8994 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java @@ -55,7 +55,7 @@ public class SSBJoinReorderTest extends SSBTestBase implements MemoPatternMatchS "(lo_partkey = p_partkey)" ), ImmutableList.of( - "d_year IN (1997, 1998)", + "((d_year = 1997) OR (d_year = 1998))", "(c_region = 'AMERICA')", "(s_region = 'AMERICA')", "((p_mfgr = 'MFGR#1') OR (p_mfgr = 'MFGR#2'))" @@ -74,7 +74,7 @@ public class SSBJoinReorderTest extends SSBTestBase implements MemoPatternMatchS "(lo_partkey = p_partkey)" ), ImmutableList.of( - "d_year IN (1997, 1998)", + "((d_year = 1997) OR (d_year = 1998))", "(s_nation = 'UNITED STATES')", "(p_category = 'MFGR#14')" ) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewriteTest.java index cf8deb6fe5..67425e00a6 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewriteTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewriteTest.java @@ -175,20 +175,20 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper { executor = new ExpressionRuleExecutor(ImmutableList.of(InPredicateToEqualToRule.INSTANCE)); assertRewrite("a in (1)", "a = 1"); - assertRewrite("a in (1, 2)", "a in (1, 2)"); + assertRewrite("a in (1, 2)", "((a = 1) OR (a = 2))"); assertRewrite("a not in (1)", "not a = 1"); - assertRewrite("a not in (1, 2)", "not a in (1, 2)"); + assertRewrite("a not in (1, 2)", "not ((a = 1) OR (a = 2))"); assertRewrite("a in (a in (1))", "a = (a = 1)"); - assertRewrite("a in (a in (1, 2))", "a = (a in (1, 2))"); + assertRewrite("a in (a in (1, 2))", "a = ((a = 1) OR (a = 2))"); assertRewrite("(a in (1)) in (1)", "(a = 1) = 1"); - assertRewrite("(a in (1, 2)) in (1)", "(a in (1, 2)) = 1"); - assertRewrite("(a in (1)) in (1, 2)", "(a = 1) in (1, 2)"); + assertRewrite("(a in (1, 2)) in (1)", "((a = 1) OR (a = 2)) = 1"); + assertRewrite("(a in (1)) in (1, 2)", "((a = 1) = 1) OR ((a = 1) = 2)"); assertRewrite("case a when b in (1) then a else c end in (1)", "case a when b = 1 then a else c end = 1"); assertRewrite("case a when b not in (1) then a else c end not in (1)", "not case a when not b = 1 then a else c end = 1"); assertRewrite("case a when b not in (1) then a else c end in (1, 2)", - "case a when not b = 1 then a else c end in (1, 2)"); + "(CASE WHEN (a = ( not (b = 1))) THEN a ELSE c END = 1) OR (CASE WHEN (a = ( not (b = 1))) THEN a ELSE c END = 2)"); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/SimplifyRangeTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/SimplifyRangeTest.java index 07aea19772..a4a62503b9 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/SimplifyRangeTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/SimplifyRangeTest.java @@ -87,7 +87,7 @@ public class SimplifyRangeTest { assertRewrite("((TA > 10 or TA > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA > 5 and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))"); assertRewrite("TA in (1,2,3) and TA > 10", "FALSE"); assertRewrite("TA in (1,2,3) and TA >= 1", "TA in (1,2,3)"); - assertRewrite("TA in (1,2,3) and TA > 1", "TA in (2,3)"); + assertRewrite("TA in (1,2,3) and TA > 1", "((TA = 2) OR (TA = 3))"); assertRewrite("TA in (1,2,3) or TA >= 1", "TA >= 1"); assertRewrite("TA in (1)", "TA in (1)"); assertRewrite("TA in (1,2,3) and TA < 10", "TA in (1,2,3)");