diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/InPredicateToEqualToRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/InPredicateToEqualToRule.java
index e71764bb23..a729165077 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/InPredicateToEqualToRule.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/InPredicateToEqualToRule.java
@@ -22,14 +22,26 @@ import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContex
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.base.Preconditions;
import java.util.List;
+import java.util.stream.Collectors;
/**
- * Rewrite InPredicate to an EqualTo Expression, if there exists exactly one element in InPredicate.options
+ * Paper: Quantifying TPC-H Choke Points and Their Optimizations
+ * - Figure 14:
+ *
+ * Rewrite InPredicate to a disjunction of EqualTo, if InPredicate.options has fewer than 3 elements
+ * Examples:
+ * where A in (x, y) ==> where A = x or A = y
* Examples:
* where A in (x) ==> where A = x
* where A not in (x) ==> where not A = x (After ExpressionTranslator, "not A = x" will be translated to "A != x")
+ *
+ * NOTICE: this rule is coupled with `SimplifyRange`.
+ * Both perform the same rewrite, so they must be changed together.
*/
public class InPredicateToEqualToRule extends AbstractExpressionRewriteRule {
@@ -37,11 +49,16 @@ public class InPredicateToEqualToRule extends AbstractExpressionRewriteRule {
@Override
public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) {
- Expression left = inPredicate.getCompareExpr();
- List right = inPredicate.getOptions();
- if (right.size() != 1) {
- return new InPredicate(left.accept(this, context), right);
+ Expression cmpExpr = inPredicate.getCompareExpr();
+ List options = inPredicate.getOptions();
+ Preconditions.checkArgument(options.size() > 0, "InPredicate.options should not be empty");
+ if (options.size() > 2) {
+ return new InPredicate(cmpExpr.accept(this, context), options);
}
- return new EqualTo(left.accept(this, context), right.get(0).accept(this, context));
+ Expression newCmpExpr = cmpExpr.accept(this, context);
+ List disjunction = options.stream()
+ .map(option -> new EqualTo(newCmpExpr, option.accept(this, context)))
+ .collect(Collectors.toList());
+ return ExpressionUtils.or(disjunction);
}
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/SimplifyRange.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/SimplifyRange.java
index 010031486c..da2904fb04 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/SimplifyRange.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/SimplifyRange.java
@@ -44,6 +44,7 @@ import com.google.common.collect.Sets;
import java.util.Arrays;
import java.util.Collection;
+import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@@ -59,7 +60,7 @@ import java.util.stream.Collectors;
* a > 1 or a > 2 => a > 1
* a in (1,2,3) and a > 1 => a in (2,3)
* a in (1,2,3) and a in (3,4,5) => a = 3
- * a in(1,2,3) and a in (4,5,6) => false
+ * a in (1,2,3) and a in (4,5,6) => false
* The logic is as follows:
* 1. for `And` expression.
* 1. extract conjunctions then build `ValueDesc` for each conjunction
@@ -402,8 +403,13 @@ public class SimplifyRange extends AbstractExpressionRewriteRule {
@Override
public Expression toExpression() {
+ // NOTICE: this is coupled with `InPredicateToEqualToRule`.
+ // Both perform the same rewrite, so they must be changed together.
if (values.size() == 1) {
return new EqualTo(reference, values.iterator().next());
+ } else if (values.size() == 2) {
+ Iterator iterator = values.iterator();
+ return new Or(new EqualTo(reference, iterator.next()), new EqualTo(reference, iterator.next()));
} else {
return new InPredicate(reference, Lists.newArrayList(values));
}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java
index 583250afd1..70140e8994 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java
@@ -55,7 +55,7 @@ public class SSBJoinReorderTest extends SSBTestBase implements MemoPatternMatchS
"(lo_partkey = p_partkey)"
),
ImmutableList.of(
- "d_year IN (1997, 1998)",
+ "((d_year = 1997) OR (d_year = 1998))",
"(c_region = 'AMERICA')",
"(s_region = 'AMERICA')",
"((p_mfgr = 'MFGR#1') OR (p_mfgr = 'MFGR#2'))"
@@ -74,7 +74,7 @@ public class SSBJoinReorderTest extends SSBTestBase implements MemoPatternMatchS
"(lo_partkey = p_partkey)"
),
ImmutableList.of(
- "d_year IN (1997, 1998)",
+ "((d_year = 1997) OR (d_year = 1998))",
"(s_nation = 'UNITED STATES')",
"(p_category = 'MFGR#14')"
)
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewriteTest.java
index cf8deb6fe5..67425e00a6 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewriteTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewriteTest.java
@@ -175,20 +175,20 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
executor = new ExpressionRuleExecutor(ImmutableList.of(InPredicateToEqualToRule.INSTANCE));
assertRewrite("a in (1)", "a = 1");
- assertRewrite("a in (1, 2)", "a in (1, 2)");
+ assertRewrite("a in (1, 2)", "((a = 1) OR (a = 2))");
assertRewrite("a not in (1)", "not a = 1");
- assertRewrite("a not in (1, 2)", "not a in (1, 2)");
+ assertRewrite("a not in (1, 2)", "not ((a = 1) OR (a = 2))");
assertRewrite("a in (a in (1))", "a = (a = 1)");
- assertRewrite("a in (a in (1, 2))", "a = (a in (1, 2))");
+ assertRewrite("a in (a in (1, 2))", "a = ((a = 1) OR (a = 2))");
assertRewrite("(a in (1)) in (1)", "(a = 1) = 1");
- assertRewrite("(a in (1, 2)) in (1)", "(a in (1, 2)) = 1");
- assertRewrite("(a in (1)) in (1, 2)", "(a = 1) in (1, 2)");
+ assertRewrite("(a in (1, 2)) in (1)", "((a = 1) OR (a = 2)) = 1");
+ assertRewrite("(a in (1)) in (1, 2)", "((a = 1) = 1) OR ((a = 1) = 2)");
assertRewrite("case a when b in (1) then a else c end in (1)",
"case a when b = 1 then a else c end = 1");
assertRewrite("case a when b not in (1) then a else c end not in (1)",
"not case a when not b = 1 then a else c end = 1");
assertRewrite("case a when b not in (1) then a else c end in (1, 2)",
- "case a when not b = 1 then a else c end in (1, 2)");
+ "(CASE WHEN (a = ( not (b = 1))) THEN a ELSE c END = 1) OR (CASE WHEN (a = ( not (b = 1))) THEN a ELSE c END = 2)");
}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/SimplifyRangeTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/SimplifyRangeTest.java
index 07aea19772..a4a62503b9 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/SimplifyRangeTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/SimplifyRangeTest.java
@@ -87,7 +87,7 @@ public class SimplifyRangeTest {
assertRewrite("((TA > 10 or TA > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA > 5 and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))");
assertRewrite("TA in (1,2,3) and TA > 10", "FALSE");
assertRewrite("TA in (1,2,3) and TA >= 1", "TA in (1,2,3)");
- assertRewrite("TA in (1,2,3) and TA > 1", "TA in (2,3)");
+ assertRewrite("TA in (1,2,3) and TA > 1", "((TA = 2) OR (TA = 3))");
assertRewrite("TA in (1,2,3) or TA >= 1", "TA >= 1");
assertRewrite("TA in (1)", "TA in (1)");
assertRewrite("TA in (1,2,3) and TA < 10", "TA in (1,2,3)");