From 10b252856dcb26d87d4ab9fcda37a7870b2156c1 Mon Sep 17 00:00:00 2001 From: jakevin Date: Tue, 18 Apr 2023 09:31:07 +0800 Subject: [PATCH] [feature](Nereids): pullup semiJoin through aggregate. (#18669) --- .../apache/doris/nereids/rules/RuleSet.java | 4 ++ .../apache/doris/nereids/rules/RuleType.java | 2 + .../exploration/AggSemiJoinTranspose.java | 45 ++++++++++++++ .../AggSemiJoinTransposeProject.java | 47 +++++++++++++++ .../PushdownFilterThroughAggregation.java | 60 ++++++++----------- .../rewrite/logical/SemiJoinAggTranspose.java | 32 ++++------ .../logical/SemiJoinAggTransposeProject.java | 24 +------- .../SemiJoinLogicalJoinTransposeProject.java | 1 + .../trees/plans/logical/LogicalJoin.java | 8 +++ .../exploration/AggSemiJoinTransposeTest.java | 60 +++++++++++++++++++ 10 files changed, 207 insertions(+), 76 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/AggSemiJoinTranspose.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/AggSemiJoinTransposeProject.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/AggSemiJoinTransposeTest.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index 2917996e27..aa30417fc2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -17,6 +17,8 @@ package org.apache.doris.nereids.rules; +import org.apache.doris.nereids.rules.exploration.AggSemiJoinTranspose; +import org.apache.doris.nereids.rules.exploration.AggSemiJoinTransposeProject; import org.apache.doris.nereids.rules.exploration.MergeProjectsCBO; import org.apache.doris.nereids.rules.exploration.PushdownFilterThroughProjectCBO; import org.apache.doris.nereids.rules.exploration.join.InnerJoinLAsscom; @@ -98,6 +100,8 @@ public class RuleSet { .add(LogicalJoinSemiJoinTransposeProject.INSTANCE) .add(PushdownProjectThroughInnerJoin.INSTANCE) .add(PushdownProjectThroughSemiJoin.INSTANCE) + .add(AggSemiJoinTranspose.INSTANCE) + .add(AggSemiJoinTransposeProject.INSTANCE) .build(); public static final List PUSH_DOWN_FILTERS = ImmutableList.of( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 9406267fc4..3113a120c2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -238,6 +238,8 @@ public enum RuleType { LOGICAL_INNER_JOIN_LEFT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION), LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE(RuleTypeClass.EXPLORATION), LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION), + LOGICAL_AGG_SEMI_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION), + LOGICAL_AGG_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION), PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN(RuleTypeClass.EXPLORATION), PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN(RuleTypeClass.EXPLORATION), EAGER_COUNT(RuleTypeClass.EXPLORATION), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/AggSemiJoinTranspose.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/AggSemiJoinTranspose.java new file mode 100644 index 0000000000..de995cee36 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/AggSemiJoinTranspose.java @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.rewrite.logical.SemiJoinAggTranspose; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; + +/** + * Pull up SemiJoin through Agg. + */ +public class AggSemiJoinTranspose extends OneExplorationRuleFactory { + public static final AggSemiJoinTranspose INSTANCE = new AggSemiJoinTranspose(); + + @Override + public Rule build() { + return logicalAggregate(logicalJoin()) + .when(agg -> agg.child().getJoinType().isLeftSemiOrAntiJoin()) + .then(agg -> { + LogicalJoin join = agg.child(); + if (!SemiJoinAggTranspose.canTranspose(agg, join)) { + return null; + } + return join.withChildren(agg.withChildren(join.left()), join.right()); + }) + .toRule(RuleType.LOGICAL_AGG_SEMI_JOIN_TRANSPOSE); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/AggSemiJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/AggSemiJoinTransposeProject.java new file mode 100644 index 0000000000..45b687cc60 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/AggSemiJoinTransposeProject.java @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.rewrite.logical.SemiJoinAggTranspose; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; + +/** + * Pull up SemiJoin through Agg. + */ +public class AggSemiJoinTransposeProject extends OneExplorationRuleFactory { + public static final AggSemiJoinTransposeProject INSTANCE = new AggSemiJoinTransposeProject(); + + @Override + public Rule build() { + return logicalAggregate(logicalProject(logicalJoin())) + .when(agg -> agg.child().child().getJoinType().isLeftSemiOrAntiJoin()) + .then(agg -> { + LogicalProject> project = agg.child(); + LogicalJoin join = project.child(); + if (!SemiJoinAggTranspose.canTranspose(agg, join)) { + return null; + } + return join.withChildren(agg.withChildren(project.withChildren(join.left())), join.right()); + }) + .toRule(RuleType.LOGICAL_AGG_SEMI_JOIN_TRANSPOSE_PROJECT); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughAggregation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughAggregation.java index a891eb07ee..e64e659308 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughAggregation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughAggregation.java @@ -27,7 +27,6 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.util.PlanUtils; -import com.google.common.collect.ImmutableList; import com.google.common.collect.Sets; import java.util.HashSet; @@ -36,28 +35,21 @@ import java.util.Set; /** * Push the predicate in the LogicalFilter to the aggregate child. * For example: + *
  * Logical plan tree:
- *                 any_node
- *                   |
  *                filter (a>0 and b>0)
  *                   |
  *                group by(a, c)
- *                   |
- *                 scan
  * transformed to:
- *                 project
- *                   |
  *              upper filter (b>0)
  *                   |
  *                group by(a, c)
  *                   |
  *              bottom filter (a>0)
- *                   |
- *                 scan
+ * 
* Note: - * 'a>0' could be push down, because 'a' is in group by keys; - * but 'b>0' could not push down, because 'b' is not in group by keys. - * + * 'a>0' could be push down, because 'a' is in group by keys; + * but 'b>0' could not push down, because 'b' is not in group by keys. */ public class PushdownFilterThroughAggregation extends OneRewriteRuleFactory { @@ -66,17 +58,7 @@ public class PushdownFilterThroughAggregation extends OneRewriteRuleFactory { public Rule build() { return logicalFilter(logicalAggregate()).then(filter -> { LogicalAggregate aggregate = filter.child(); - Set canPushDownSlots = new HashSet<>(); - if (aggregate.hasRepeat()) { - // When there is a repeat, the push-down condition is consistent with the repeat - canPushDownSlots.addAll(aggregate.getSourceRepeat().get().getCommonGroupingSetExpressions()); - } else { - for (Expression groupByExpression : aggregate.getGroupByExpressions()) { - if (groupByExpression instanceof Slot) { - canPushDownSlots.add((Slot) groupByExpression); - } - } - } + Set canPushDownSlots = getCanPushDownSlots(aggregate); Set pushDownPredicates = Sets.newHashSet(); Set filterPredicates = Sets.newHashSet(); @@ -88,20 +70,30 @@ public class PushdownFilterThroughAggregation extends OneRewriteRuleFactory { filterPredicates.add(conjunct); } }); - - return pushDownPredicate(filter, aggregate, pushDownPredicates, filterPredicates); + if (pushDownPredicates.size() == 0) { + return null; + } + Plan bottomFilter = new LogicalFilter<>(pushDownPredicates, aggregate.child(0)); + aggregate = (LogicalAggregate) aggregate.withChildren(bottomFilter); + return PlanUtils.filterOrSelf(filterPredicates, aggregate); }).toRule(RuleType.PUSHDOWN_PREDICATE_THROUGH_AGGREGATION); } - private Plan pushDownPredicate(LogicalFilter filter, LogicalAggregate aggregate, - Set pushDownPredicates, Set filterPredicates) { - if (pushDownPredicates.size() == 0) { - // nothing pushed down, just return origin plan - return filter; + /** + * get the slots that can be pushed down + */ + public static Set getCanPushDownSlots(LogicalAggregate aggregate) { + Set canPushDownSlots = new HashSet<>(); + if (aggregate.hasRepeat()) { + // When there is a repeat, the push-down condition is consistent with the repeat + canPushDownSlots.addAll(aggregate.getSourceRepeat().get().getCommonGroupingSetExpressions()); + } else { + for (Expression groupByExpression : aggregate.getGroupByExpressions()) { + if (groupByExpression instanceof Slot) { + canPushDownSlots.add((Slot) groupByExpression); + } + } } - LogicalFilter bottomFilter = new LogicalFilter<>(pushDownPredicates, aggregate.child(0)); - - aggregate = aggregate.withChildren(ImmutableList.of(bottomFilter)); - return PlanUtils.filterOrSelf(filterPredicates, aggregate); + return canPushDownSlots; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/SemiJoinAggTranspose.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/SemiJoinAggTranspose.java index f6ecde1f4e..ed8f743629 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/SemiJoinAggTranspose.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/SemiJoinAggTranspose.java @@ -20,14 +20,12 @@ package org.apache.doris.nereids.rules.rewrite.logical; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; -import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import java.util.HashSet; import java.util.Set; -import java.util.stream.Collectors; /** * Pushdown semi-join through agg @@ -39,26 +37,20 @@ public class SemiJoinAggTranspose extends OneRewriteRuleFactory { .when(join -> join.getJoinType().isLeftSemiOrAntiJoin()) .then(join -> { LogicalAggregate aggregate = join.left(); - Set canPushDownSlots = new HashSet<>(); - if (aggregate.hasRepeat()) { - // When there is a repeat, the push-down condition is consistent with the repeat - canPushDownSlots.addAll(aggregate.getSourceRepeat().get().getCommonGroupingSetExpressions()); - } else { - for (Expression groupByExpression : aggregate.getGroupByExpressions()) { - if (groupByExpression instanceof Slot) { - canPushDownSlots.add((Slot) groupByExpression); - } - } - } - Set leftOutputSet = join.left().getOutputSet(); - Set conditionSlot = join.getConditionSlot() - .stream() - .filter(leftOutputSet::contains) - .collect(Collectors.toSet()); - if (!canPushDownSlots.containsAll(conditionSlot)) { + if (!canTranspose(aggregate, join)) { return null; } return aggregate.withChildren(join.withChildren(aggregate.child(), join.right())); }).toRule(RuleType.LOGICAL_SEMI_JOIN_AGG_TRANSPOSE); } + + /** + * check if we can transpose agg and semi join + */ + public static boolean canTranspose(LogicalAggregate aggregate, + LogicalJoin join) { + Set canPushDownSlots = PushdownFilterThroughAggregation.getCanPushDownSlots(aggregate); + Set leftConditionSlot = join.getLeftConditionSlot(); + return canPushDownSlots.containsAll(leftConditionSlot); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/SemiJoinAggTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/SemiJoinAggTransposeProject.java index 1820cfde37..838a41cdc3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/SemiJoinAggTransposeProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/SemiJoinAggTransposeProject.java @@ -20,16 +20,11 @@ package org.apache.doris.nereids.rules.rewrite.logical; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; -import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; -import java.util.HashSet; -import java.util.Set; -import java.util.stream.Collectors; - /** * Pushdown semi-join through agg */ @@ -38,27 +33,12 @@ public class SemiJoinAggTransposeProject extends OneRewriteRuleFactory { public Rule build() { return logicalJoin(logicalProject(logicalAggregate()), any()) .when(join -> join.getJoinType().isLeftSemiOrAntiJoin()) + .when(join -> join.left().isAllSlots()) .when(join -> join.left().getProjects().stream().allMatch(n -> n instanceof Slot)) .then(join -> { LogicalProject> project = join.left(); LogicalAggregate aggregate = project.child(); - Set canPushDownSlots = new HashSet<>(); - if (aggregate.hasRepeat()) { - // When there is a repeat, the push-down condition is consistent with the repeat - canPushDownSlots.addAll(aggregate.getSourceRepeat().get().getCommonGroupingSetExpressions()); - } else { - for (Expression groupByExpression : aggregate.getGroupByExpressions()) { - if (groupByExpression instanceof Slot) { - canPushDownSlots.add((Slot) groupByExpression); - } - } - } - Set leftOutputSet = join.left().getOutputSet(); - Set conditionSlot = join.getConditionSlot() - .stream() - .filter(leftOutputSet::contains) - .collect(Collectors.toSet()); - if (!canPushDownSlots.containsAll(conditionSlot)) { + if (!SemiJoinAggTranspose.canTranspose(aggregate, join)) { return null; } Plan newPlan = aggregate.withChildren(join.withChildren(aggregate.child(), join.right())); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/SemiJoinLogicalJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/SemiJoinLogicalJoinTransposeProject.java index aeabe44000..94e3273f08 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/SemiJoinLogicalJoinTransposeProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/SemiJoinLogicalJoinTransposeProject.java @@ -46,6 +46,7 @@ public class SemiJoinLogicalJoinTransposeProject extends OneRewriteRuleFactory { && (topJoin.left().child().getJoinType().isInnerJoin() || topJoin.left().child().getJoinType().isLeftOuterJoin() || topJoin.left().child().getJoinType().isRightOuterJoin()))) + .when(join -> join.left().isAllSlots()) .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) .whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin()) .when(join -> join.left().getProjects().stream().allMatch(expr -> expr instanceof Slot)) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java index 81f1e08b8e..12c0ab13c2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java @@ -145,6 +145,14 @@ public class LogicalJoin expr.getInputSlotExprIds().stream()).collect(Collectors.toSet()); } + public Set getLeftConditionSlot() { + Set leftOutputSet = this.left().getOutputSet(); + return Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream()) + .flatMap(expr -> expr.getInputSlots().stream()) + .filter(leftOutputSet::contains) + .collect(ImmutableSet.toImmutableSet()); + } + public Optional getOnClauseCondition() { return ExpressionUtils.optionalAnd(hashJoinConjuncts, otherJoinConjuncts); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/AggSemiJoinTransposeTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/AggSemiJoinTransposeTest.java new file mode 100644 index 0000000000..d88e53b16e --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/AggSemiJoinTransposeTest.java @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Test; + +class AggSemiJoinTransposeTest implements MemoPatternMatchSupported { + private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + + @Test + void simple() { + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), + ImmutableList.of( + scan1.getOutput().get(0), + new Alias(new Sum(scan1.getOutput().get(1)), "sum") + ) + ) + .build(); + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyExploration(AggSemiJoinTranspose.INSTANCE.build()) + .printlnExploration() + .matchesExploration( + leftSemiLogicalJoin( + logicalAggregate(), + logicalOlapScan() + ) + ); + } +}