[feature](Nereids): pullup semiJoin through aggregate. (#18669)

This commit is contained in:
jakevin
2023-04-18 09:31:07 +08:00
committed by GitHub
parent b68857902e
commit 10b252856d
10 changed files with 207 additions and 76 deletions

View File

@ -17,6 +17,8 @@
package org.apache.doris.nereids.rules;
import org.apache.doris.nereids.rules.exploration.AggSemiJoinTranspose;
import org.apache.doris.nereids.rules.exploration.AggSemiJoinTransposeProject;
import org.apache.doris.nereids.rules.exploration.MergeProjectsCBO;
import org.apache.doris.nereids.rules.exploration.PushdownFilterThroughProjectCBO;
import org.apache.doris.nereids.rules.exploration.join.InnerJoinLAsscom;
@ -98,6 +100,8 @@ public class RuleSet {
.add(LogicalJoinSemiJoinTransposeProject.INSTANCE)
.add(PushdownProjectThroughInnerJoin.INSTANCE)
.add(PushdownProjectThroughSemiJoin.INSTANCE)
.add(AggSemiJoinTranspose.INSTANCE)
.add(AggSemiJoinTransposeProject.INSTANCE)
.build();
public static final List<RuleFactory> PUSH_DOWN_FILTERS = ImmutableList.of(

View File

@ -238,6 +238,8 @@ public enum RuleType {
LOGICAL_INNER_JOIN_LEFT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE(RuleTypeClass.EXPLORATION),
LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_AGG_SEMI_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION),
LOGICAL_AGG_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN(RuleTypeClass.EXPLORATION),
PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN(RuleTypeClass.EXPLORATION),
EAGER_COUNT(RuleTypeClass.EXPLORATION),

View File

@ -0,0 +1,45 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.logical.SemiJoinAggTranspose;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
/**
 * Pull up SemiJoin through Agg.
 * <pre>
 *         agg                      semi/anti join
 *          |                         /        \
 *    semi/anti join       =>       agg       right
 *      /        \                   |
 *    left      right               left
 * </pre>
 * The rule only matches left semi / left anti joins and only fires when
 * {@link SemiJoinAggTranspose#canTranspose} accepts the pair and the aggregate
 * outputs every left-side slot that the join condition reads.
 */
public class AggSemiJoinTranspose extends OneExplorationRuleFactory {
    public static final AggSemiJoinTranspose INSTANCE = new AggSemiJoinTranspose();

    @Override
    public Rule build() {
        return logicalAggregate(logicalJoin())
                .when(agg -> agg.child().getJoinType().isLeftSemiOrAntiJoin())
                // NOTE(review): mark joins are excluded the same way the sibling
                // semi-join transpose rules exclude them; presumably the extra mark
                // slot could not survive moving the aggregate below the join - confirm.
                .whenNot(agg -> agg.child().isMarkJoin())
                .then(agg -> {
                    LogicalJoin<GroupPlan, GroupPlan> join = agg.child();
                    if (!SemiJoinAggTranspose.canTranspose(agg, join)) {
                        return null;
                    }
                    // After the rewrite the aggregate is the join's left child, so the
                    // join condition can only read left-side slots that the aggregate
                    // outputs. canTranspose checks the group-by keys only, and a key is
                    // not necessarily part of the aggregate's output.
                    if (!agg.getOutputSet().containsAll(join.getLeftConditionSlot())) {
                        return null;
                    }
                    return join.withChildren(agg.withChildren(join.left()), join.right());
                })
                .toRule(RuleType.LOGICAL_AGG_SEMI_JOIN_TRANSPOSE);
    }
}

View File

@ -0,0 +1,47 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.logical.SemiJoinAggTranspose;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
/**
 * Pull up SemiJoin through Agg when a project sits between them.
 * <pre>
 *         agg                      semi/anti join
 *          |                         /        \
 *       project                    agg       right
 *          |              =>        |
 *    semi/anti join              project
 *      /        \                   |
 *    left      right               left
 * </pre>
 * The rule only matches left semi / left anti joins and only fires when
 * {@link SemiJoinAggTranspose#canTranspose} accepts the pair and the aggregate
 * outputs every left-side slot that the join condition reads.
 */
public class AggSemiJoinTransposeProject extends OneExplorationRuleFactory {
    public static final AggSemiJoinTransposeProject INSTANCE = new AggSemiJoinTransposeProject();

    @Override
    public Rule build() {
        return logicalAggregate(logicalProject(logicalJoin()))
                .when(agg -> agg.child().child().getJoinType().isLeftSemiOrAntiJoin())
                // NOTE(review): mark joins are excluded the same way the sibling
                // semi-join transpose rules exclude them; presumably the extra mark
                // slot could not survive moving the aggregate below the join - confirm.
                .whenNot(agg -> agg.child().child().isMarkJoin())
                .then(agg -> {
                    LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = agg.child();
                    LogicalJoin<GroupPlan, GroupPlan> join = project.child();
                    if (!SemiJoinAggTranspose.canTranspose(agg, join)) {
                        return null;
                    }
                    // After the rewrite the aggregate is the join's left child, so the
                    // join condition can only read left-side slots that the aggregate
                    // outputs. canTranspose checks the group-by keys only, and a key is
                    // not necessarily part of the aggregate's output.
                    if (!agg.getOutputSet().containsAll(join.getLeftConditionSlot())) {
                        return null;
                    }
                    return join.withChildren(agg.withChildren(project.withChildren(join.left())), join.right());
                })
                .toRule(RuleType.LOGICAL_AGG_SEMI_JOIN_TRANSPOSE_PROJECT);
    }
}

View File

@ -27,7 +27,6 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.util.PlanUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import java.util.HashSet;
@ -36,28 +35,21 @@ import java.util.Set;
/**
* Push the predicate in the LogicalFilter to the aggregate child.
* For example:
* <pre>
* Logical plan tree:
* any_node
* |
* filter (a>0 and b>0)
* |
* group by(a, c)
* |
* scan
* transformed to:
* project
* |
* upper filter (b>0)
* |
* group by(a, c)
* |
* bottom filter (a>0)
* |
* scan
* </pre>
* Note:
* 'a>0' could be push down, because 'a' is in group by keys;
* but 'b>0' could not push down, because 'b' is not in group by keys.
*
* 'a>0' could be push down, because 'a' is in group by keys;
* but 'b>0' could not push down, because 'b' is not in group by keys.
*/
public class PushdownFilterThroughAggregation extends OneRewriteRuleFactory {
@ -66,17 +58,7 @@ public class PushdownFilterThroughAggregation extends OneRewriteRuleFactory {
public Rule build() {
return logicalFilter(logicalAggregate()).then(filter -> {
LogicalAggregate<Plan> aggregate = filter.child();
Set<Slot> canPushDownSlots = new HashSet<>();
if (aggregate.hasRepeat()) {
// When there is a repeat, the push-down condition is consistent with the repeat
canPushDownSlots.addAll(aggregate.getSourceRepeat().get().getCommonGroupingSetExpressions());
} else {
for (Expression groupByExpression : aggregate.getGroupByExpressions()) {
if (groupByExpression instanceof Slot) {
canPushDownSlots.add((Slot) groupByExpression);
}
}
}
Set<Slot> canPushDownSlots = getCanPushDownSlots(aggregate);
Set<Expression> pushDownPredicates = Sets.newHashSet();
Set<Expression> filterPredicates = Sets.newHashSet();
@ -88,20 +70,30 @@ public class PushdownFilterThroughAggregation extends OneRewriteRuleFactory {
filterPredicates.add(conjunct);
}
});
return pushDownPredicate(filter, aggregate, pushDownPredicates, filterPredicates);
if (pushDownPredicates.size() == 0) {
return null;
}
Plan bottomFilter = new LogicalFilter<>(pushDownPredicates, aggregate.child(0));
aggregate = (LogicalAggregate<Plan>) aggregate.withChildren(bottomFilter);
return PlanUtils.filterOrSelf(filterPredicates, aggregate);
}).toRule(RuleType.PUSHDOWN_PREDICATE_THROUGH_AGGREGATION);
}
private Plan pushDownPredicate(LogicalFilter filter, LogicalAggregate aggregate,
Set<Expression> pushDownPredicates, Set<Expression> filterPredicates) {
if (pushDownPredicates.size() == 0) {
// nothing pushed down, just return origin plan
return filter;
/**
* get the slots that can be pushed down
*/
public static Set<Slot> getCanPushDownSlots(LogicalAggregate<? extends Plan> aggregate) {
Set<Slot> canPushDownSlots = new HashSet<>();
if (aggregate.hasRepeat()) {
// When there is a repeat, the push-down condition is consistent with the repeat
canPushDownSlots.addAll(aggregate.getSourceRepeat().get().getCommonGroupingSetExpressions());
} else {
for (Expression groupByExpression : aggregate.getGroupByExpressions()) {
if (groupByExpression instanceof Slot) {
canPushDownSlots.add((Slot) groupByExpression);
}
}
}
LogicalFilter bottomFilter = new LogicalFilter<>(pushDownPredicates, aggregate.child(0));
aggregate = aggregate.withChildren(ImmutableList.of(bottomFilter));
return PlanUtils.filterOrSelf(filterPredicates, aggregate);
return canPushDownSlots;
}
}

View File

@ -20,14 +20,12 @@ package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Pushdown semi-join through agg
@ -39,26 +37,20 @@ public class SemiJoinAggTranspose extends OneRewriteRuleFactory {
.when(join -> join.getJoinType().isLeftSemiOrAntiJoin())
.then(join -> {
LogicalAggregate<Plan> aggregate = join.left();
Set<Slot> canPushDownSlots = new HashSet<>();
if (aggregate.hasRepeat()) {
// When there is a repeat, the push-down condition is consistent with the repeat
canPushDownSlots.addAll(aggregate.getSourceRepeat().get().getCommonGroupingSetExpressions());
} else {
for (Expression groupByExpression : aggregate.getGroupByExpressions()) {
if (groupByExpression instanceof Slot) {
canPushDownSlots.add((Slot) groupByExpression);
}
}
}
Set<Slot> leftOutputSet = join.left().getOutputSet();
Set<Slot> conditionSlot = join.getConditionSlot()
.stream()
.filter(leftOutputSet::contains)
.collect(Collectors.toSet());
if (!canPushDownSlots.containsAll(conditionSlot)) {
if (!canTranspose(aggregate, join)) {
return null;
}
return aggregate.withChildren(join.withChildren(aggregate.child(), join.right()));
}).toRule(RuleType.LOGICAL_SEMI_JOIN_AGG_TRANSPOSE);
}
/**
 * Decide whether an aggregate and a semi/anti join may swap places.
 * The swap is allowed only when every left-side slot read by the join
 * condition is one of the aggregate's push-down-able slots, as reported by
 * {@link PushdownFilterThroughAggregation#getCanPushDownSlots}.
 *
 * @param aggregate the aggregate taking part in the transpose
 * @param join the semi/anti join taking part in the transpose
 * @return true if all left condition slots of the join are push-down-able slots of the aggregate
 */
public static boolean canTranspose(LogicalAggregate<? extends Plan> aggregate,
        LogicalJoin<? extends Plan, ? extends Plan> join) {
    Set<Slot> pushableSlots = PushdownFilterThroughAggregation.getCanPushDownSlots(aggregate);
    Set<Slot> slotsReadFromLeft = join.getLeftConditionSlot();
    // An empty condition-slot set is trivially accepted, as with containsAll.
    return slotsReadFromLeft.stream().allMatch(pushableSlots::contains);
}
}

View File

@ -20,16 +20,11 @@ package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Pushdown semi-join through agg
*/
@ -38,27 +33,12 @@ public class SemiJoinAggTransposeProject extends OneRewriteRuleFactory {
public Rule build() {
return logicalJoin(logicalProject(logicalAggregate()), any())
.when(join -> join.getJoinType().isLeftSemiOrAntiJoin())
.when(join -> join.left().isAllSlots())
.when(join -> join.left().getProjects().stream().allMatch(n -> n instanceof Slot))
.then(join -> {
LogicalProject<LogicalAggregate<Plan>> project = join.left();
LogicalAggregate<Plan> aggregate = project.child();
Set<Slot> canPushDownSlots = new HashSet<>();
if (aggregate.hasRepeat()) {
// When there is a repeat, the push-down condition is consistent with the repeat
canPushDownSlots.addAll(aggregate.getSourceRepeat().get().getCommonGroupingSetExpressions());
} else {
for (Expression groupByExpression : aggregate.getGroupByExpressions()) {
if (groupByExpression instanceof Slot) {
canPushDownSlots.add((Slot) groupByExpression);
}
}
}
Set<Slot> leftOutputSet = join.left().getOutputSet();
Set<Slot> conditionSlot = join.getConditionSlot()
.stream()
.filter(leftOutputSet::contains)
.collect(Collectors.toSet());
if (!canPushDownSlots.containsAll(conditionSlot)) {
if (!SemiJoinAggTranspose.canTranspose(aggregate, join)) {
return null;
}
Plan newPlan = aggregate.withChildren(join.withChildren(aggregate.child(), join.right()));

View File

@ -46,6 +46,7 @@ public class SemiJoinLogicalJoinTransposeProject extends OneRewriteRuleFactory {
&& (topJoin.left().child().getJoinType().isInnerJoin()
|| topJoin.left().child().getJoinType().isLeftOuterJoin()
|| topJoin.left().child().getJoinType().isRightOuterJoin())))
.when(join -> join.left().isAllSlots())
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.when(join -> join.left().getProjects().stream().allMatch(expr -> expr instanceof Slot))

View File

@ -145,6 +145,14 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
.flatMap(expr -> expr.getInputSlotExprIds().stream()).collect(Collectors.toSet());
}
/**
 * Collect the slots read by the join conditions (hash conjuncts first, then
 * other conjuncts) that are part of the left child's output.
 *
 * @return an immutable set of the left-side slots referenced by the join conditions
 */
public Set<Slot> getLeftConditionSlot() {
    Set<Slot> leftSlots = left().getOutputSet();
    ImmutableSet.Builder<Slot> matched = ImmutableSet.builder();
    Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream())
            .forEach(conjunct -> {
                for (Slot slot : conjunct.getInputSlots()) {
                    // keep only the slots that the left child produces
                    if (leftSlots.contains(slot)) {
                        matched.add(slot);
                    }
                }
            });
    return matched.build();
}
public Optional<Expression> getOnClauseCondition() {
return ExpressionUtils.optionalAnd(hashJoinConjuncts, otherJoinConjuncts);
}