From c809a219932e730e8a536f4247524f27bb214561 Mon Sep 17 00:00:00 2001 From: minghong Date: Mon, 26 Sep 2022 11:19:37 +0800 Subject: [PATCH] [feature](nereids) extract single table expression for push down (#12894) TPCH q7, we have expression like (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') this expression implies (n1.n_name='FRANCE' or n1.n_name=''GERMANY) The implied expression is logical redundancy, but it could be used to reduce the output tuple number of scan(n1), if nereids pushes this expression down. This pr introduces a RULE to extract such expressions. NOTE: 1. we only extract expression on a single table. 2. if the extracted expression cannot be pushed down, e.g. it is on right table of left outer join, we need another rule to remove all the useless expressions. --- .../jobs/batch/NereidsRewriteJobExecutor.java | 2 + .../apache/doris/nereids/rules/RuleType.java | 1 + ...tSingleTableExpressionFromDisjunction.java | 166 ++++++++++++++++ .../trees/plans/logical/LogicalFilter.java | 35 +++- ...gleTableExpressionFromDisjunctionTest.java | 185 ++++++++++++++++++ 5 files changed, 384 insertions(+), 5 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunction.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunctionTest.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java index b77c599cc0..5534871b50 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java @@ -27,6 +27,7 @@ import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble; import org.apache.doris.nereids.rules.rewrite.logical.ColumnPruning; import org.apache.doris.nereids.rules.rewrite.logical.EliminateFilter; import org.apache.doris.nereids.rules.rewrite.logical.EliminateLimit; +import org.apache.doris.nereids.rules.rewrite.logical.ExtractSingleTableExpressionFromDisjunction; import org.apache.doris.nereids.rules.rewrite.logical.FindHashConditionForJoin; import org.apache.doris.nereids.rules.rewrite.logical.LimitPushDown; import org.apache.doris.nereids.rules.rewrite.logical.NormalizeAggregate; @@ -59,6 +60,7 @@ public class NereidsRewriteJobExecutor extends BatchRulesJob { .addAll(new ConvertApplyToJoinJob(cascadesContext).rulesJob) .add(topDownBatch(ImmutableList.of(new ExpressionNormalization()))) .add(topDownBatch(ImmutableList.of(new ExpressionOptimization()))) + .add(topDownBatch(ImmutableList.of(new ExtractSingleTableExpressionFromDisjunction()))) .add(topDownBatch(ImmutableList.of(new NormalizeAggregate()))) .add(topDownBatch(RuleSet.PUSH_DOWN_JOIN_CONDITION_RULES, false)) .add(topDownBatch(ImmutableList.of(new ReorderJoin()))) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 1765d543f7..e9968f2fd1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -109,6 +109,7 @@ public enum RuleType { ROLLUP_AGG_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE), ROLLUP_AGG_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE), OLAP_SCAN_PARTITION_PRUNE(RuleTypeClass.REWRITE), + EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION(RuleTypeClass.REWRITE), // Pushdown filter PUSHDOWN_FILTER_THROUGH_PROJET(RuleTypeClass.REWRITE), PUSHDOWN_PROJECT_THROUGHT_LIMIT(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunction.java new file mode 100644 index 0000000000..9317ea6969 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunction.java @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.Lists; + +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Example: + * (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') + * => + * (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') + * and (n1.n_name = 'FRANCE' or n1.n_name='GERMANY') and (n2.n_name='GERMANY' or n2.n_name='FRANCE') + * + * (n1.n_name = 'FRANCE' or n1.n_name='GERMANY') is a logical redundant, but it could be pushed down to scan(n1) to + * reduce the number of scan output tuples. + * For complete sql example, refer to tpch q7. + * ================================================================================================== + *
+ * There are 2 cases, in which the redundant expressions are useless: + * 1. filter(expr)-->XXX out join. + * For example, for left join, the redundant expression for right side is not useful, because we cannot push expression + * down to right child. Refer to PushDownJoinOtherCondition Rule for push-down cases. + * But it is hard to detect this case, if the outer join is a descendant but not child of the filter. + * 2. filter(expr) + * |-->upper-join + * |-->bottom-join + * +-->child + * In current version, we do not extract redundant expression for bottom-join. This redundancy is good for + * upper-join (reduce the number of input tuple from bottom join), but it becomes unuseful if we rotate the join tree. + * ================================================================================================== + *
+ * Implementation note: + * 1. This rule should only be applied ONCE to avoid generate same redundant expression. + * 2. This version only generates redundant expressions, but not push them. + * 3. A redundant expression only contains slots from a single table. + * 4. This rule is applied after rules converting sub-query to join. + * 5. Add a flag 'isRedundant' in Expression. It is true, if it is generated by this rule. + * 6. The useless redundant expression should be removed, if it cannot be pushed down. We need a new rule + * `RemoveRedundantExpression` to fulfill this purpose. + * 7. In old optimizer, there is `InferFilterRule` generates redundancy expressions. Its Nereid counterpart also need + * `RemoveRedundantExpression`. + */ +public class ExtractSingleTableExpressionFromDisjunction extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalFilter().whenNot(LogicalFilter::isSingleTableExpressionExtracted).then(filter -> { + //filter = [(n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') + // or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE')] + // and ... + List conjuncts = ExpressionUtils.extractConjunction(filter.getPredicates()) + .stream().collect(Collectors.toList()); + + List redundants = Lists.newArrayList(); + for (Expression conjunct : conjuncts) { + //conjunct=(n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') + // or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') + List disjuncts = ExpressionUtils.extractDisjunction(conjunct); + //disjuncts={ (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY'), + // (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE')} + if (disjuncts.size() == 1) { + continue; + } + //only check table in first disjunct. + //In our example, qualifiers = { n1, n2 } + Expression first = disjuncts.get(0); + Set qualifiers = first.getInputSlots() + .stream() + .map(SlotReference.class::cast) + .map(this::getSlotQualifierAsString) + .collect(Collectors.toSet()); + //try to extract + for (String qualifier : qualifiers) { + List extractForAll = Lists.newArrayList(); + boolean success = true; + for (Expression expr : ExpressionUtils.extractDisjunction(conjunct)) { + Optional extracted = extractSingleTableExpression(expr, qualifier); + if (!extracted.isPresent()) { + //extract failed + success = false; + break; + } else { + extractForAll.add(extracted.get()); + } + } + if (success) { + redundants.add(ExpressionUtils.or(extractForAll)); + } + } + + } + if (redundants.isEmpty()) { + return new LogicalFilter<>(filter.getPredicates(), true, filter.child()); + } else { + Expression newPredicate = ExpressionUtils.and(filter.getPredicates(), ExpressionUtils.and(redundants)); + return new LogicalFilter<>(newPredicate, + true, filter.child()); + } + }).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION); + } + + private String getSlotQualifierAsString(SlotReference slotReference) { + StringBuilder builder = new StringBuilder(); + for (String q : slotReference.getQualifier()) { + builder.append(q).append('.'); + } + return builder.toString(); + } + + //extract some conjucts from expr, all slots of the extracted conjunct comes from the table referred by qualifier. + //example: expr=(n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY'), qualifier="n1." + //output: n1.n_name = 'FRANCE' + private Optional extractSingleTableExpression(Expression expr, String qualifier) { + List output = Lists.newArrayList(); + List conjuncts = ExpressionUtils.extractConjunction(expr); + for (Expression conjunct : conjuncts) { + if (isSingleTableExpression(conjunct, qualifier)) { + output.add(conjunct); + } + } + if (output.isEmpty()) { + return Optional.empty(); + } else { + return Optional.of(ExpressionUtils.and(output)); + } + } + + private boolean isSingleTableExpression(Expression expr, String qualifier) { + //TODO: cache getSlotQualifierAsString() result. + for (Slot slot : expr.getInputSlots()) { + String slotQualifier = getSlotQualifierAsString((SlotReference) slot); + if (!slotQualifier.equals(qualifier)) { + return false; + } + } + return true; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java index 6225c985fd..9e5f624fec 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java @@ -40,14 +40,30 @@ import java.util.Optional; public class LogicalFilter extends LogicalUnary implements Filter { private final Expression predicates; + private final boolean singleTableExpressionExtracted; + public LogicalFilter(Expression predicates, CHILD_TYPE child) { this(predicates, Optional.empty(), Optional.empty(), child); } + public LogicalFilter(Expression predicates, + boolean singleTableExpressionExtracted, + CHILD_TYPE child) { + this(predicates, Optional.empty(), singleTableExpressionExtracted, + Optional.empty(), child); + } + public LogicalFilter(Expression predicates, Optional groupExpression, Optional logicalProperties, CHILD_TYPE child) { + this(predicates, groupExpression, false, logicalProperties, child); + } + + public LogicalFilter(Expression predicates, Optional groupExpression, + boolean singleTableExpressionExtracted, + Optional logicalProperties, CHILD_TYPE child) { super(PlanType.LOGICAL_FILTER, groupExpression, logicalProperties, child); this.predicates = Objects.requireNonNull(predicates, "predicates can not be null"); + this.singleTableExpressionExtracted = singleTableExpressionExtracted; } public Expression getPredicates() { @@ -75,12 +91,13 @@ public class LogicalFilter extends LogicalUnary extends LogicalUnary withChildren(List children) { Preconditions.checkArgument(children.size() == 1); - return new LogicalFilter<>(predicates, children.get(0)); + return new LogicalFilter<>(predicates, singleTableExpressionExtracted, children.get(0)); } @Override public Plan withGroupExpression(Optional groupExpression) { - return new LogicalFilter<>(predicates, groupExpression, Optional.of(getLogicalProperties()), child()); + return new LogicalFilter<>(predicates, groupExpression, singleTableExpressionExtracted, + Optional.of(getLogicalProperties()), child()); } @Override public Plan withLogicalProperties(Optional logicalProperties) { - return new LogicalFilter<>(predicates, Optional.empty(), logicalProperties, child()); + return new LogicalFilter<>(predicates, Optional.empty(), + singleTableExpressionExtracted, + logicalProperties, child()); } + + public boolean isSingleTableExpressionExtracted() { + return singleTableExpressionExtracted; + } + } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunctionTest.java new file mode 100644 index 0000000000..1a1e75ad6e --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunctionTest.java @@ -0,0 +1,185 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.trees.expressions.And; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; + +import java.util.List; + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public class ExtractSingleTableExpressionFromDisjunctionTest implements PatternMatchSupported { + Plan student; + Plan course; + SlotReference courseCid; + SlotReference courseName; + SlotReference studentAge; + SlotReference studentGender; + + @BeforeAll + public final void beforeAll() { + student = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of("")); + course = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.course, ImmutableList.of("")); + //select * + //from student join course + //where (course.cid=1 and student.age=10) or (student.gender = 0 and course.name='abc') + courseCid = (SlotReference) course.getOutput().get(0); + courseName = (SlotReference) course.getOutput().get(1); + studentAge = (SlotReference) student.getOutput().get(3); + studentGender = (SlotReference) student.getOutput().get(1); + } + /** + *(cid=1 and sage=10) or (sgender=1 and cname='abc') + * => + * (cid=1 or cname='abc') + (sage=10 or sgender=1) + */ + + @Test + public void testExtract1() { + Expression expr = new Or( + new And( + new EqualTo(courseCid, new IntegerLiteral(1)), + new EqualTo(studentAge, new IntegerLiteral(10)) + ), + new And( + new EqualTo(studentGender, new IntegerLiteral(1)), + new EqualTo(courseName, new StringLiteral("abc")) + ) + ); + Plan join = new LogicalJoin<>(JoinType.CROSS_JOIN, student, course); + LogicalFilter root = new LogicalFilter(expr, join); + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyTopDown(new ExtractSingleTableExpressionFromDisjunction()) + .matchesFromRoot( + logicalFilter() + .when(filter -> verifySingleTableExpression1(filter.getPredicates())) + ); + Assertions.assertTrue(studentGender != null); + } + + private boolean verifySingleTableExpression1(Expression expr) { + List conjuncts = ExpressionUtils.extractConjunction(expr); + Expression or1 = new Or( + new EqualTo(courseCid, new IntegerLiteral(1)), + new EqualTo(courseName, new StringLiteral("abc")) + ); + Expression or2 = new Or( + new EqualTo(studentAge, new IntegerLiteral(10)), + new EqualTo(studentGender, new IntegerLiteral(1)) + ); + + return conjuncts.size() == 3 && conjuncts.contains(or1) && conjuncts.contains(or2); + } + + /** + * (cid=1 and sage=10) or (cid=2 and cname='abc') + * => + * cid=1 or (cid=2 and cname='abc') + */ + @Test + public void testExtract2() { + + Expression expr = new Or( + new And( + new EqualTo(courseCid, new IntegerLiteral(1)), + new EqualTo(studentAge, new IntegerLiteral(10)) + ), + new And( + new EqualTo(courseCid, new IntegerLiteral(2)), + new EqualTo(courseName, new StringLiteral("abc")) + ) + ); + Plan join = new LogicalJoin<>(JoinType.CROSS_JOIN, student, course); + LogicalFilter root = new LogicalFilter(expr, join); + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyTopDown(new ExtractSingleTableExpressionFromDisjunction()) + .matchesFromRoot( + logicalFilter() + .when(filter -> verifySingleTableExpression2(filter.getPredicates())) + ); + Assertions.assertTrue(studentGender != null); + } + + private boolean verifySingleTableExpression2(Expression expr) { + List conjuncts = ExpressionUtils.extractConjunction(expr); + Expression or1 = new Or( + new EqualTo(courseCid, new IntegerLiteral(1)), + new And( + new EqualTo(courseCid, new IntegerLiteral(2)), + new EqualTo(courseName, new StringLiteral("abc")))); + + return conjuncts.size() == 2 && conjuncts.contains(or1); + } + + /** + *(cid=1 and sage=10) or sgender=1 + * => + * (sage=10 or sgender=1) + */ + + @Test + public void testExtract3() { + Expression expr = new Or( + new And( + new EqualTo(courseCid, new IntegerLiteral(1)), + new EqualTo(studentAge, new IntegerLiteral(10)) + ), + new EqualTo(studentGender, new IntegerLiteral(1)) + ); + Plan join = new LogicalJoin<>(JoinType.CROSS_JOIN, student, course); + LogicalFilter root = new LogicalFilter(expr, join); + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyTopDown(new ExtractSingleTableExpressionFromDisjunction()) + .matchesFromRoot( + logicalFilter() + .when(filter -> verifySingleTableExpression3(filter.getPredicates())) + ); + Assertions.assertTrue(studentGender != null); + } + + private boolean verifySingleTableExpression3(Expression expr) { + List conjuncts = ExpressionUtils.extractConjunction(expr); + Expression or = new Or( + new EqualTo(studentAge, new IntegerLiteral(10)), + new EqualTo(studentGender, new IntegerLiteral(1)) + ); + + return conjuncts.size() == 2 && conjuncts.contains(or); + } +}