diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java index b77c599cc0..5534871b50 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java @@ -27,6 +27,7 @@ import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble; import org.apache.doris.nereids.rules.rewrite.logical.ColumnPruning; import org.apache.doris.nereids.rules.rewrite.logical.EliminateFilter; import org.apache.doris.nereids.rules.rewrite.logical.EliminateLimit; +import org.apache.doris.nereids.rules.rewrite.logical.ExtractSingleTableExpressionFromDisjunction; import org.apache.doris.nereids.rules.rewrite.logical.FindHashConditionForJoin; import org.apache.doris.nereids.rules.rewrite.logical.LimitPushDown; import org.apache.doris.nereids.rules.rewrite.logical.NormalizeAggregate; @@ -59,6 +60,7 @@ public class NereidsRewriteJobExecutor extends BatchRulesJob { .addAll(new ConvertApplyToJoinJob(cascadesContext).rulesJob) .add(topDownBatch(ImmutableList.of(new ExpressionNormalization()))) .add(topDownBatch(ImmutableList.of(new ExpressionOptimization()))) + .add(topDownBatch(ImmutableList.of(new ExtractSingleTableExpressionFromDisjunction()))) .add(topDownBatch(ImmutableList.of(new NormalizeAggregate()))) .add(topDownBatch(RuleSet.PUSH_DOWN_JOIN_CONDITION_RULES, false)) .add(topDownBatch(ImmutableList.of(new ReorderJoin()))) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 1765d543f7..e9968f2fd1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -109,6 +109,7 @@ public enum RuleType { ROLLUP_AGG_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE), ROLLUP_AGG_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE), OLAP_SCAN_PARTITION_PRUNE(RuleTypeClass.REWRITE), + EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION(RuleTypeClass.REWRITE), // Pushdown filter PUSHDOWN_FILTER_THROUGH_PROJET(RuleTypeClass.REWRITE), PUSHDOWN_PROJECT_THROUGHT_LIMIT(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunction.java new file mode 100644 index 0000000000..9317ea6969 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunction.java @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.Lists; + +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Example: + * (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') + * => + * (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') + * and (n1.n_name = 'FRANCE' or n1.n_name='GERMANY') and (n2.n_name='GERMANY' or n2.n_name='FRANCE') + * + * (n1.n_name = 'FRANCE' or n1.n_name='GERMANY') is a logical redundant, but it could be pushed down to scan(n1) to + * reduce the number of scan output tuples. + * For complete sql example, refer to tpch q7. + * ================================================================================================== + *
+ * There are 2 cases, in which the redundant expressions are useless: + * 1. filter(expr)-->XXX out join. + * For example, for left join, the redundant expression for right side is not useful, because we cannot push expression + * down to right child. Refer to PushDownJoinOtherCondition Rule for push-down cases. + * But it is hard to detect this case, if the outer join is a descendant but not child of the filter. + * 2. filter(expr) + * |-->upper-join + * |-->bottom-join + * +-->child + * In current version, we do not extract redundant expression for bottom-join. This redundancy is good for + * upper-join (reduce the number of input tuple from bottom join), but it becomes unuseful if we rotate the join tree. + * ================================================================================================== + *
+ * Implementation note: + * 1. This rule should only be applied ONCE to avoid generate same redundant expression. + * 2. This version only generates redundant expressions, but not push them. + * 3. A redundant expression only contains slots from a single table. + * 4. This rule is applied after rules converting sub-query to join. + * 5. Add a flag 'isRedundant' in Expression. It is true, if it is generated by this rule. + * 6. The useless redundant expression should be removed, if it cannot be pushed down. We need a new rule + * `RemoveRedundantExpression` to fulfill this purpose. + * 7. In old optimizer, there is `InferFilterRule` generates redundancy expressions. Its Nereid counterpart also need + * `RemoveRedundantExpression`. + */ +public class ExtractSingleTableExpressionFromDisjunction extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalFilter().whenNot(LogicalFilter::isSingleTableExpressionExtracted).then(filter -> { + //filter = [(n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') + // or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE')] + // and ... + List conjuncts = ExpressionUtils.extractConjunction(filter.getPredicates()) + .stream().collect(Collectors.toList()); + + List redundants = Lists.newArrayList(); + for (Expression conjunct : conjuncts) { + //conjunct=(n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') + // or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') + List disjuncts = ExpressionUtils.extractDisjunction(conjunct); + //disjuncts={ (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY'), + // (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE')} + if (disjuncts.size() == 1) { + continue; + } + //only check table in first disjunct. + //In our example, qualifiers = { n1, n2 } + Expression first = disjuncts.get(0); + Set qualifiers = first.getInputSlots() + .stream() + .map(SlotReference.class::cast) + .map(this::getSlotQualifierAsString) + .collect(Collectors.toSet()); + //try to extract + for (String qualifier : qualifiers) { + List extractForAll = Lists.newArrayList(); + boolean success = true; + for (Expression expr : ExpressionUtils.extractDisjunction(conjunct)) { + Optional extracted = extractSingleTableExpression(expr, qualifier); + if (!extracted.isPresent()) { + //extract failed + success = false; + break; + } else { + extractForAll.add(extracted.get()); + } + } + if (success) { + redundants.add(ExpressionUtils.or(extractForAll)); + } + } + + } + if (redundants.isEmpty()) { + return new LogicalFilter<>(filter.getPredicates(), true, filter.child()); + } else { + Expression newPredicate = ExpressionUtils.and(filter.getPredicates(), ExpressionUtils.and(redundants)); + return new LogicalFilter<>(newPredicate, + true, filter.child()); + } + }).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION); + } + + private String getSlotQualifierAsString(SlotReference slotReference) { + StringBuilder builder = new StringBuilder(); + for (String q : slotReference.getQualifier()) { + builder.append(q).append('.'); + } + return builder.toString(); + } + + //extract some conjucts from expr, all slots of the extracted conjunct comes from the table referred by qualifier. + //example: expr=(n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY'), qualifier="n1." + //output: n1.n_name = 'FRANCE' + private Optional extractSingleTableExpression(Expression expr, String qualifier) { + List output = Lists.newArrayList(); + List conjuncts = ExpressionUtils.extractConjunction(expr); + for (Expression conjunct : conjuncts) { + if (isSingleTableExpression(conjunct, qualifier)) { + output.add(conjunct); + } + } + if (output.isEmpty()) { + return Optional.empty(); + } else { + return Optional.of(ExpressionUtils.and(output)); + } + } + + private boolean isSingleTableExpression(Expression expr, String qualifier) { + //TODO: cache getSlotQualifierAsString() result. + for (Slot slot : expr.getInputSlots()) { + String slotQualifier = getSlotQualifierAsString((SlotReference) slot); + if (!slotQualifier.equals(qualifier)) { + return false; + } + } + return true; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java index 6225c985fd..9e5f624fec 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java @@ -40,14 +40,30 @@ import java.util.Optional; public class LogicalFilter extends LogicalUnary implements Filter { private final Expression predicates; + private final boolean singleTableExpressionExtracted; + public LogicalFilter(Expression predicates, CHILD_TYPE child) { this(predicates, Optional.empty(), Optional.empty(), child); } + public LogicalFilter(Expression predicates, + boolean singleTableExpressionExtracted, + CHILD_TYPE child) { + this(predicates, Optional.empty(), singleTableExpressionExtracted, + Optional.empty(), child); + } + public LogicalFilter(Expression predicates, Optional groupExpression, Optional logicalProperties, CHILD_TYPE child) { + this(predicates, groupExpression, false, logicalProperties, child); + } + + public LogicalFilter(Expression predicates, Optional groupExpression, + boolean singleTableExpressionExtracted, + Optional logicalProperties, CHILD_TYPE child) { super(PlanType.LOGICAL_FILTER, groupExpression, logicalProperties, child); this.predicates = Objects.requireNonNull(predicates, "predicates can not be null"); + this.singleTableExpressionExtracted = singleTableExpressionExtracted; } public Expression getPredicates() { @@ -75,12 +91,13 @@ public class LogicalFilter extends LogicalUnary extends LogicalUnary withChildren(List children) { Preconditions.checkArgument(children.size() == 1); - return new LogicalFilter<>(predicates, children.get(0)); + return new LogicalFilter<>(predicates, singleTableExpressionExtracted, children.get(0)); } @Override public Plan withGroupExpression(Optional groupExpression) { - return new LogicalFilter<>(predicates, groupExpression, Optional.of(getLogicalProperties()), child()); + return new LogicalFilter<>(predicates, groupExpression, singleTableExpressionExtracted, + Optional.of(getLogicalProperties()), child()); } @Override public Plan withLogicalProperties(Optional logicalProperties) { - return new LogicalFilter<>(predicates, Optional.empty(), logicalProperties, child()); + return new LogicalFilter<>(predicates, Optional.empty(), + singleTableExpressionExtracted, + logicalProperties, child()); } + + public boolean isSingleTableExpressionExtracted() { + return singleTableExpressionExtracted; + } + } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunctionTest.java new file mode 100644 index 0000000000..1a1e75ad6e --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunctionTest.java @@ -0,0 +1,185 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.trees.expressions.And; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; + +import java.util.List; + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public class ExtractSingleTableExpressionFromDisjunctionTest implements PatternMatchSupported { + Plan student; + Plan course; + SlotReference courseCid; + SlotReference courseName; + SlotReference studentAge; + SlotReference studentGender; + + @BeforeAll + public final void beforeAll() { + student = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of("")); + course = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.course, ImmutableList.of("")); + //select * + //from student join course + //where (course.cid=1 and student.age=10) or (student.gender = 0 and course.name='abc') + courseCid = (SlotReference) course.getOutput().get(0); + courseName = (SlotReference) course.getOutput().get(1); + studentAge = (SlotReference) student.getOutput().get(3); + studentGender = (SlotReference) student.getOutput().get(1); + } + /** + *(cid=1 and sage=10) or (sgender=1 and cname='abc') + * => + * (cid=1 or cname='abc') + (sage=10 or sgender=1) + */ + + @Test + public void testExtract1() { + Expression expr = new Or( + new And( + new EqualTo(courseCid, new IntegerLiteral(1)), + new EqualTo(studentAge, new IntegerLiteral(10)) + ), + new And( + new EqualTo(studentGender, new IntegerLiteral(1)), + new EqualTo(courseName, new StringLiteral("abc")) + ) + ); + Plan join = new LogicalJoin<>(JoinType.CROSS_JOIN, student, course); + LogicalFilter root = new LogicalFilter(expr, join); + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyTopDown(new ExtractSingleTableExpressionFromDisjunction()) + .matchesFromRoot( + logicalFilter() + .when(filter -> verifySingleTableExpression1(filter.getPredicates())) + ); + Assertions.assertTrue(studentGender != null); + } + + private boolean verifySingleTableExpression1(Expression expr) { + List conjuncts = ExpressionUtils.extractConjunction(expr); + Expression or1 = new Or( + new EqualTo(courseCid, new IntegerLiteral(1)), + new EqualTo(courseName, new StringLiteral("abc")) + ); + Expression or2 = new Or( + new EqualTo(studentAge, new IntegerLiteral(10)), + new EqualTo(studentGender, new IntegerLiteral(1)) + ); + + return conjuncts.size() == 3 && conjuncts.contains(or1) && conjuncts.contains(or2); + } + + /** + * (cid=1 and sage=10) or (cid=2 and cname='abc') + * => + * cid=1 or (cid=2 and cname='abc') + */ + @Test + public void testExtract2() { + + Expression expr = new Or( + new And( + new EqualTo(courseCid, new IntegerLiteral(1)), + new EqualTo(studentAge, new IntegerLiteral(10)) + ), + new And( + new EqualTo(courseCid, new IntegerLiteral(2)), + new EqualTo(courseName, new StringLiteral("abc")) + ) + ); + Plan join = new LogicalJoin<>(JoinType.CROSS_JOIN, student, course); + LogicalFilter root = new LogicalFilter(expr, join); + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyTopDown(new ExtractSingleTableExpressionFromDisjunction()) + .matchesFromRoot( + logicalFilter() + .when(filter -> verifySingleTableExpression2(filter.getPredicates())) + ); + Assertions.assertTrue(studentGender != null); + } + + private boolean verifySingleTableExpression2(Expression expr) { + List conjuncts = ExpressionUtils.extractConjunction(expr); + Expression or1 = new Or( + new EqualTo(courseCid, new IntegerLiteral(1)), + new And( + new EqualTo(courseCid, new IntegerLiteral(2)), + new EqualTo(courseName, new StringLiteral("abc")))); + + return conjuncts.size() == 2 && conjuncts.contains(or1); + } + + /** + *(cid=1 and sage=10) or sgender=1 + * => + * (sage=10 or sgender=1) + */ + + @Test + public void testExtract3() { + Expression expr = new Or( + new And( + new EqualTo(courseCid, new IntegerLiteral(1)), + new EqualTo(studentAge, new IntegerLiteral(10)) + ), + new EqualTo(studentGender, new IntegerLiteral(1)) + ); + Plan join = new LogicalJoin<>(JoinType.CROSS_JOIN, student, course); + LogicalFilter root = new LogicalFilter(expr, join); + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyTopDown(new ExtractSingleTableExpressionFromDisjunction()) + .matchesFromRoot( + logicalFilter() + .when(filter -> verifySingleTableExpression3(filter.getPredicates())) + ); + Assertions.assertTrue(studentGender != null); + } + + private boolean verifySingleTableExpression3(Expression expr) { + List conjuncts = ExpressionUtils.extractConjunction(expr); + Expression or = new Or( + new EqualTo(studentAge, new IntegerLiteral(10)), + new EqualTo(studentGender, new IntegerLiteral(1)) + ); + + return conjuncts.size() == 2 && conjuncts.contains(or); + } +}