diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java
index b77c599cc0..5534871b50 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java
@@ -27,6 +27,7 @@ import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
import org.apache.doris.nereids.rules.rewrite.logical.ColumnPruning;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateFilter;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateLimit;
+import org.apache.doris.nereids.rules.rewrite.logical.ExtractSingleTableExpressionFromDisjunction;
import org.apache.doris.nereids.rules.rewrite.logical.FindHashConditionForJoin;
import org.apache.doris.nereids.rules.rewrite.logical.LimitPushDown;
import org.apache.doris.nereids.rules.rewrite.logical.NormalizeAggregate;
@@ -59,6 +60,7 @@ public class NereidsRewriteJobExecutor extends BatchRulesJob {
.addAll(new ConvertApplyToJoinJob(cascadesContext).rulesJob)
.add(topDownBatch(ImmutableList.of(new ExpressionNormalization())))
.add(topDownBatch(ImmutableList.of(new ExpressionOptimization())))
+ .add(topDownBatch(ImmutableList.of(new ExtractSingleTableExpressionFromDisjunction())))
.add(topDownBatch(ImmutableList.of(new NormalizeAggregate())))
.add(topDownBatch(RuleSet.PUSH_DOWN_JOIN_CONDITION_RULES, false))
.add(topDownBatch(ImmutableList.of(new ReorderJoin())))
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 1765d543f7..e9968f2fd1 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -109,6 +109,7 @@ public enum RuleType {
ROLLUP_AGG_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE),
ROLLUP_AGG_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE),
OLAP_SCAN_PARTITION_PRUNE(RuleTypeClass.REWRITE),
+ EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION(RuleTypeClass.REWRITE),
// Pushdown filter
PUSHDOWN_FILTER_THROUGH_PROJET(RuleTypeClass.REWRITE),
PUSHDOWN_PROJECT_THROUGHT_LIMIT(RuleTypeClass.REWRITE),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunction.java
new file mode 100644
index 0000000000..9317ea6969
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunction.java
@@ -0,0 +1,166 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.Lists;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Example:
+ * (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE')
+ * =>
+ * (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE')
+ * and (n1.n_name = 'FRANCE' or n1.n_name='GERMANY') and (n2.n_name='GERMANY' or n2.n_name='FRANCE')
+ *
+ * (n1.n_name = 'FRANCE' or n1.n_name='GERMANY') is logically redundant, but it could be pushed down to scan(n1) to
+ * reduce the number of scan output tuples.
+ * For complete sql example, refer to tpch q7.
+ * ==================================================================================================
+ *
+ * There are 2 cases, in which the redundant expressions are useless:
+ * 1. filter(expr)-->XXX outer join.
+ * For example, for left join, the redundant expression for right side is not useful, because we cannot push expression
+ * down to right child. Refer to PushDownJoinOtherCondition Rule for push-down cases.
+ * But it is hard to detect this case, if the outer join is a descendant but not child of the filter.
+ * 2. filter(expr)
+ * |-->upper-join
+ * |-->bottom-join
+ * +-->child
+ * In current version, we do not extract redundant expression for bottom-join. This redundancy is good for
+ * upper-join (reduce the number of input tuples from bottom join), but it becomes useless if we rotate the join tree.
+ * ==================================================================================================
+ *
+ * Implementation note:
+ * 1. This rule should only be applied ONCE to avoid generating the same redundant expression again.
+ * 2. This version only generates redundant expressions, but not push them.
+ * 3. A redundant expression only contains slots from a single table.
+ * 4. This rule is applied after rules converting sub-query to join.
+ * 5. Add a flag 'isRedundant' in Expression. It is true, if it is generated by this rule.
+ * 6. The useless redundant expression should be removed, if it cannot be pushed down. We need a new rule
+ * `RemoveRedundantExpression` to fulfill this purpose.
+ * 7. In the old optimizer, `InferFilterRule` generates redundant expressions. Its Nereids counterpart also needs
+ * `RemoveRedundantExpression`.
+ */
+public class ExtractSingleTableExpressionFromDisjunction extends OneRewriteRuleFactory {
+    /**
+     * Builds the rewrite rule. It matches every filter not yet flagged as extracted and returns a flagged copy.
+     * For each conjunct that is a disjunction, it tries to derive one single-table disjunction per table;
+     * the derived expressions, if any, are ANDed to the original predicate.
+     */
+    @Override
+    public Rule build() {
+        return logicalFilter().whenNot(LogicalFilter::isSingleTableExpressionExtracted).then(filter -> {
+            //filter = [(n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY')
+            //             or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE')]
+            //         and ...
+            List<Expression> conjuncts = ExpressionUtils.extractConjunction(filter.getPredicates());
+            List<Expression> redundants = Lists.newArrayList();
+            for (Expression conjunct : conjuncts) {
+                //conjunct=(n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY')
+                //          or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE')
+                List<Expression> disjuncts = ExpressionUtils.extractDisjunction(conjunct);
+                //disjuncts={ (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY'),
+                //            (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE')}
+                if (disjuncts.size() == 1) {
+                    continue;
+                }
+                //only tables referenced by the first disjunct are tried: extraction has to succeed for every disjunct
+                //In our example, qualifiers = { n1, n2 }
+                Set<String> qualifiers = disjuncts.get(0).getInputSlots()
+                        .stream()
+                        .map(SlotReference.class::cast)
+                        .map(this::getSlotQualifierAsString)
+                        .collect(Collectors.toSet());
+                //try to extract
+                for (String qualifier : qualifiers) {
+                    List<Expression> extractForAll = Lists.newArrayList();
+                    boolean success = true;
+                    //reuse disjuncts instead of splitting the conjunct again for every qualifier
+                    for (Expression expr : disjuncts) {
+                        Optional<Expression> extracted = extractSingleTableExpression(expr, qualifier);
+                        if (!extracted.isPresent()) {
+                            //one disjunct yields nothing for this table, so no redundant expression is built
+                            success = false;
+                            break;
+                        }
+                        extractForAll.add(extracted.get());
+                    }
+                    if (success) {
+                        redundants.add(ExpressionUtils.or(extractForAll));
+                    }
+                }
+            }
+            if (redundants.isEmpty()) {
+                //nothing extracted: only set the flag so that the rule does not match this filter again
+                return new LogicalFilter<>(filter.getPredicates(), true, filter.child());
+            }
+            Expression newPredicate = ExpressionUtils.and(filter.getPredicates(), ExpressionUtils.and(redundants));
+            return new LogicalFilter<>(newPredicate, true, filter.child());
+        }).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION);
+    }
+
+    private String getSlotQualifierAsString(SlotReference slotReference) {
+        // Every qualifier part is followed by '.', the last one included,
+        // so a qualifier of [db, tbl] is rendered as "db.tbl.".
+        return slotReference.getQualifier().stream()
+                .map(part -> part + '.')
+                .collect(Collectors.joining());
+    }
+
+    /**
+     * Extracts the conjuncts of expr whose slots all come from the table referred by qualifier.
+     * Conjuncts that mix tables are skipped rather than split.
+     * example: expr=(n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY'), qualifier="n1."
+     * output: n1.n_name = 'FRANCE'
+     *
+     * @param expr one disjunct of the filter conjunct under rewrite
+     * @param qualifier table qualifier in the format produced by getSlotQualifierAsString
+     * @return the AND of the matching conjuncts, or empty if no conjunct matches
+     */
+    private Optional<Expression> extractSingleTableExpression(Expression expr, String qualifier) {
+        List<Expression> output = ExpressionUtils.extractConjunction(expr).stream()
+                .filter(conjunct -> isSingleTableExpression(conjunct, qualifier))
+                .collect(Collectors.toList());
+        // an empty result tells the caller that this disjunct has no conjunct for the table
+        return output.isEmpty() ? Optional.empty() : Optional.of(ExpressionUtils.and(output));
+    }
+
+ private boolean isSingleTableExpression(Expression expr, String qualifier) {
+ //TODO: cache getSlotQualifierAsString() result.
+ for (Slot slot : expr.getInputSlots()) {
+ String slotQualifier = getSlotQualifierAsString((SlotReference) slot);
+ if (!slotQualifier.equals(qualifier)) {
+ return false;
+ }
+ }
+ return true;
+ }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java
index 6225c985fd..9e5f624fec 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java
@@ -40,14 +40,30 @@ import java.util.Optional;
public class LogicalFilter extends LogicalUnary implements Filter {
private final Expression predicates;
+ private final boolean singleTableExpressionExtracted;
+
public LogicalFilter(Expression predicates, CHILD_TYPE child) {
this(predicates, Optional.empty(), Optional.empty(), child);
}
+ public LogicalFilter(Expression predicates,
+ boolean singleTableExpressionExtracted,
+ CHILD_TYPE child) {
+ this(predicates, Optional.empty(), singleTableExpressionExtracted,
+ Optional.empty(), child);
+ }
+
public LogicalFilter(Expression predicates, Optional groupExpression,
Optional logicalProperties, CHILD_TYPE child) {
+ this(predicates, groupExpression, false, logicalProperties, child);
+ }
+
+ public LogicalFilter(Expression predicates, Optional groupExpression,
+ boolean singleTableExpressionExtracted,
+ Optional logicalProperties, CHILD_TYPE child) {
super(PlanType.LOGICAL_FILTER, groupExpression, logicalProperties, child);
this.predicates = Objects.requireNonNull(predicates, "predicates can not be null");
+ this.singleTableExpressionExtracted = singleTableExpressionExtracted;
}
public Expression getPredicates() {
@@ -75,12 +91,13 @@ public class LogicalFilter extends LogicalUnary extends LogicalUnary withChildren(List children) {
Preconditions.checkArgument(children.size() == 1);
- return new LogicalFilter<>(predicates, children.get(0));
+ return new LogicalFilter<>(predicates, singleTableExpressionExtracted, children.get(0));
}
@Override
public Plan withGroupExpression(Optional groupExpression) {
- return new LogicalFilter<>(predicates, groupExpression, Optional.of(getLogicalProperties()), child());
+ return new LogicalFilter<>(predicates, groupExpression, singleTableExpressionExtracted,
+ Optional.of(getLogicalProperties()), child());
}
@Override
public Plan withLogicalProperties(Optional logicalProperties) {
- return new LogicalFilter<>(predicates, Optional.empty(), logicalProperties, child());
+ return new LogicalFilter<>(predicates, Optional.empty(),
+ singleTableExpressionExtracted,
+ logicalProperties, child());
}
+
+ public boolean isSingleTableExpressionExtracted() {
+ return singleTableExpressionExtracted;
+ }
+
}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunctionTest.java
new file mode 100644
index 0000000000..1a1e75ad6e
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunctionTest.java
@@ -0,0 +1,185 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.trees.expressions.And;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Or;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInstance;
+
+import java.util.List;
+
+@TestInstance(TestInstance.Lifecycle.PER_CLASS)
+public class ExtractSingleTableExpressionFromDisjunctionTest implements PatternMatchSupported {
+ Plan student;
+ Plan course;
+ SlotReference courseCid;
+ SlotReference courseName;
+ SlotReference studentAge;
+ SlotReference studentGender;
+
+    @BeforeAll
+    public final void beforeAll() {
+        student = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of(""));
+        course = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.course, ImmutableList.of(""));
+        //Shared fixture; every test builds: select * from student join course where <disjunction under test>
+        //e.g. where (course.cid=1 and student.age=10) or (student.gender = 1 and course.name='abc')
+        //The slots below are picked by their position in the output of each scan.
+        courseCid = (SlotReference) course.getOutput().get(0);
+        courseName = (SlotReference) course.getOutput().get(1);
+        studentAge = (SlotReference) student.getOutput().get(3);
+        studentGender = (SlotReference) student.getOutput().get(1);
+    }
+    /**
+     * Both tables occur in both disjuncts, so one redundant expression is extracted per table.
+     * (cid=1 and sage=10) or (sgender=1 and cname='abc')
+     * =>
+     * the original predicate and (cid=1 or cname='abc') and (sage=10 or sgender=1)
+     */
+    @Test
+    public void testExtract1() {
+        // first disjunct: cid=1 and sage=10
+        Expression cidAndAge = new And(
+                new EqualTo(courseCid, new IntegerLiteral(1)),
+                new EqualTo(studentAge, new IntegerLiteral(10)));
+        // second disjunct: sgender=1 and cname='abc'
+        Expression genderAndName = new And(
+                new EqualTo(studentGender, new IntegerLiteral(1)),
+                new EqualTo(courseName, new StringLiteral("abc")));
+        Expression expr = new Or(cidAndAge, genderAndName);
+        Plan join = new LogicalJoin<>(JoinType.CROSS_JOIN, student, course);
+        LogicalFilter root = new LogicalFilter(expr, join);
+        // the rewritten root filter has to satisfy verifySingleTableExpression1
+        PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+                .applyTopDown(new ExtractSingleTableExpressionFromDisjunction())
+                .matchesFromRoot(
+                        logicalFilter()
+                                .when(rewritten -> verifySingleTableExpression1(rewritten.getPredicates()))
+                );
+        Assertions.assertTrue(studentGender != null);
+    }
+
+    private boolean verifySingleTableExpression1(Expression expr) {
+        // expected: the original disjunction plus one extracted disjunction per table
+        List conjuncts = ExpressionUtils.extractConjunction(expr);
+        Expression courseOr = new Or(
+                new EqualTo(courseCid, new IntegerLiteral(1)),
+                new EqualTo(courseName, new StringLiteral("abc")));
+        Expression studentOr = new Or(
+                new EqualTo(studentAge, new IntegerLiteral(10)),
+                new EqualTo(studentGender, new IntegerLiteral(1)));
+        return conjuncts.size() == 3
+                && conjuncts.contains(courseOr)
+                && conjuncts.contains(studentOr);
+    }
+
+    /**
+     * Only course occurs in both disjuncts, so only the course expression is extracted.
+     * (cid=1 and sage=10) or (cid=2 and cname='abc')
+     * =>
+     * the original predicate and (cid=1 or (cid=2 and cname='abc'))
+     */
+    @Test
+    public void testExtract2() {
+        // first disjunct: cid=1 and sage=10
+        Expression cidAndAge = new And(
+                new EqualTo(courseCid, new IntegerLiteral(1)),
+                new EqualTo(studentAge, new IntegerLiteral(10)));
+        // second disjunct: cid=2 and cname='abc'
+        Expression cidAndName = new And(
+                new EqualTo(courseCid, new IntegerLiteral(2)),
+                new EqualTo(courseName, new StringLiteral("abc")));
+        Expression expr = new Or(cidAndAge, cidAndName);
+        Plan join = new LogicalJoin<>(JoinType.CROSS_JOIN, student, course);
+        LogicalFilter root = new LogicalFilter(expr, join);
+        // the rewritten root filter has to satisfy verifySingleTableExpression2
+        PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+                .applyTopDown(new ExtractSingleTableExpressionFromDisjunction())
+                .matchesFromRoot(
+                        logicalFilter()
+                                .when(rewritten -> verifySingleTableExpression2(rewritten.getPredicates()))
+                );
+        Assertions.assertTrue(studentGender != null);
+    }
+
+    private boolean verifySingleTableExpression2(Expression expr) {
+        // expected: the original disjunction plus the course-only disjunction
+        List conjuncts = ExpressionUtils.extractConjunction(expr);
+        Expression cidIsOne = new EqualTo(courseCid, new IntegerLiteral(1));
+        Expression cidIsTwoAndName = new And(
+                new EqualTo(courseCid, new IntegerLiteral(2)),
+                new EqualTo(courseName, new StringLiteral("abc")));
+        Expression courseOr = new Or(cidIsOne, cidIsTwoAndName);
+        return conjuncts.size() == 2 && conjuncts.contains(courseOr);
+    }
+
+    /**
+     * The second disjunct is a single comparison on student, so only the student expression is extracted.
+     * (cid=1 and sage=10) or sgender=1
+     * =>
+     * the original predicate and (sage=10 or sgender=1)
+     */
+    @Test
+    public void testExtract3() {
+        // first disjunct: cid=1 and sage=10
+        Expression cidAndAge = new And(
+                new EqualTo(courseCid, new IntegerLiteral(1)),
+                new EqualTo(studentAge, new IntegerLiteral(10)));
+        // second disjunct: a bare comparison, sgender=1
+        Expression genderIsOne = new EqualTo(studentGender, new IntegerLiteral(1));
+        Expression expr = new Or(cidAndAge, genderIsOne);
+        Plan join = new LogicalJoin<>(JoinType.CROSS_JOIN, student, course);
+        LogicalFilter root = new LogicalFilter(expr, join);
+        PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+                .applyTopDown(new ExtractSingleTableExpressionFromDisjunction())
+                .matchesFromRoot(
+                        logicalFilter()
+                                .when(rewritten -> verifySingleTableExpression3(rewritten.getPredicates()))
+                );
+        Assertions.assertTrue(studentGender != null);
+    }
+
+    private boolean verifySingleTableExpression3(Expression expr) {
+        // expected: the original disjunction plus the student-only disjunction
+        List conjuncts = ExpressionUtils.extractConjunction(expr);
+        Expression ageIsTen = new EqualTo(studentAge, new IntegerLiteral(10));
+        Expression genderIsOne = new EqualTo(studentGender, new IntegerLiteral(1));
+        Expression studentOr = new Or(ageIsTen, genderIsOne);
+        return conjuncts.size() == 2
+                && conjuncts.contains(studentOr);
+    }
+}