[refactor](Nereids): Eager Aggregation unify pushdown agg function (#30142)

This commit is contained in:
jakevin
2024-01-22 12:23:50 +08:00
committed by yiguolei
parent 06f8266ca2
commit ad1c19bd65
3 changed files with 19 additions and 119 deletions

View File

@ -81,7 +81,7 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory {
return null;
}
LogicalAggregate<LogicalJoin<Plan, Plan>> agg = ctx.root;
return pushMinMax(agg, agg.child(), ImmutableList.of());
return pushMinMaxSum(agg, agg.child(), ImmutableList.of());
})
.toRule(RuleType.PUSH_DOWN_MIN_MAX_THROUGH_JOIN),
logicalAggregate(logicalProject(innerLogicalJoin()))
@ -102,13 +102,16 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory {
return null;
}
LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> agg = ctx.root;
return pushMinMax(agg, agg.child().child(), agg.child().getProjects());
return pushMinMaxSum(agg, agg.child().child(), agg.child().getProjects());
})
.toRule(RuleType.PUSH_DOWN_MIN_MAX_THROUGH_JOIN)
);
}
private LogicalAggregate<Plan> pushMinMax(LogicalAggregate<? extends Plan> agg,
/**
* Push down Min/Max/Sum through join.
*/
public static LogicalAggregate<Plan> pushMinMaxSum(LogicalAggregate<? extends Plan> agg,
LogicalJoin<Plan, Plan> join, List<NamedExpression> projects) {
List<Slot> leftOutput = join.left().getOutput();
List<Slot> rightOutput = join.right().getOutput();
@ -125,6 +128,9 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory {
throw new IllegalStateException("Slot " + slot + " not found in join output");
}
}
if (leftFuncs.isEmpty() && rightFuncs.isEmpty()) {
return null;
}
Set<Slot> leftGroupBy = new HashSet<>();
Set<Slot> rightGroupBy = new HashSet<>();
@ -177,6 +183,11 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory {
Preconditions.checkState(left != join.left() || right != join.right());
Plan newJoin = join.withChildren(left, right);
// top agg
// replace
// min(x) -> min(min#)
// max(x) -> max(max#)
// sum(x) -> sum(sum#)
List<NamedExpression> newOutputExprs = new ArrayList<>();
for (NamedExpression ne : agg.getOutputExpressions()) {
if (ne instanceof Alias && ((Alias) ne).child() instanceof AggregateFunction) {

View File

@ -53,12 +53,12 @@ import java.util.Set;
* | *
* (x)
* ->
* aggregate: Sum(min1)
* aggregate: Sum(sum1)
* |
* join
* | \
* | *
* aggregate: Sum(x) as min1
* aggregate: Sum(x) as sum1
* </pre>
*/
public class PushDownSumThroughJoin implements RewriteRuleFactory {

View File

@ -19,9 +19,6 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
@ -30,15 +27,9 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
@ -79,7 +70,7 @@ public class PushDownSumThroughJoinOneSide implements RewriteRuleFactory {
return null;
}
LogicalAggregate<LogicalJoin<Plan, Plan>> agg = ctx.root;
return pushSum(agg, agg.child(), ImmutableList.of());
return PushDownMinMaxThroughJoin.pushMinMaxSum(agg, agg.child(), ImmutableList.of());
})
.toRule(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN),
logicalAggregate(logicalProject(innerLogicalJoin()))
@ -98,112 +89,10 @@ public class PushDownSumThroughJoinOneSide implements RewriteRuleFactory {
return null;
}
LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> agg = ctx.root;
return pushSum(agg, agg.child().child(), agg.child().getProjects());
return PushDownMinMaxThroughJoin.pushMinMaxSum(agg, agg.child().child(),
agg.child().getProjects());
})
.toRule(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN)
);
}
private LogicalAggregate<Plan> pushSum(LogicalAggregate<? extends Plan> agg,
LogicalJoin<Plan, Plan> join, List<NamedExpression> projects) {
List<Slot> leftOutput = join.left().getOutput();
List<Slot> rightOutput = join.right().getOutput();
List<Sum> leftSums = new ArrayList<>();
List<Sum> rightSums = new ArrayList<>();
for (AggregateFunction f : agg.getAggregateFunctions()) {
Sum sum = (Sum) f;
Slot slot = (Slot) sum.child();
if (leftOutput.contains(slot)) {
leftSums.add(sum);
} else if (rightOutput.contains(slot)) {
rightSums.add(sum);
} else {
throw new IllegalStateException("Slot " + slot + " not found in join output");
}
}
if (leftSums.isEmpty() && rightSums.isEmpty()) {
return null;
}
Set<Slot> leftGroupBy = new HashSet<>();
Set<Slot> rightGroupBy = new HashSet<>();
for (Expression e : agg.getGroupByExpressions()) {
Slot slot = (Slot) e;
if (leftOutput.contains(slot)) {
leftGroupBy.add(slot);
} else if (rightOutput.contains(slot)) {
rightGroupBy.add(slot);
} else {
return null;
}
}
join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> {
if (leftOutput.contains(slot)) {
leftGroupBy.add(slot);
} else if (rightOutput.contains(slot)) {
rightGroupBy.add(slot);
} else {
throw new IllegalStateException("Slot " + slot + " not found in join output");
}
}));
Plan left = join.left();
Plan right = join.right();
Map<Slot, NamedExpression> leftSumSlotToOutput = new HashMap<>();
Map<Slot, NamedExpression> rightSumSlotToOutput = new HashMap<>();
// left Sum agg
if (!leftSums.isEmpty()) {
Builder<NamedExpression> leftSumAggOutputBuilder = ImmutableList.<NamedExpression>builder()
.addAll(leftGroupBy);
leftSums.forEach(func -> {
Alias alias = func.alias(func.getName());
leftSumSlotToOutput.put((Slot) func.child(0), alias);
leftSumAggOutputBuilder.add(alias);
});
left = new LogicalAggregate<>(ImmutableList.copyOf(leftGroupBy), leftSumAggOutputBuilder.build(),
join.left());
}
// right Sum agg
if (!rightSums.isEmpty()) {
Builder<NamedExpression> rightSumAggOutputBuilder = ImmutableList.<NamedExpression>builder()
.addAll(rightGroupBy);
rightSums.forEach(func -> {
Alias alias = func.alias(func.getName());
rightSumSlotToOutput.put((Slot) func.child(0), alias);
rightSumAggOutputBuilder.add(alias);
});
right = new LogicalAggregate<>(ImmutableList.copyOf(rightGroupBy), rightSumAggOutputBuilder.build(),
join.right());
}
Preconditions.checkState(left != join.left() || right != join.right());
Plan newJoin = join.withChildren(left, right);
// top Sum agg
// replace sum(x) -> sum(sum#)
List<NamedExpression> newOutputExprs = new ArrayList<>();
for (NamedExpression ne : agg.getOutputExpressions()) {
if (ne instanceof Alias && ((Alias) ne).child() instanceof Sum) {
Sum oldTopSum = (Sum) ((Alias) ne).child();
Slot slot = (Slot) oldTopSum.child(0);
if (leftSumSlotToOutput.containsKey(slot)) {
Expression expr = new Sum(leftSumSlotToOutput.get(slot).toSlot());
newOutputExprs.add((NamedExpression) ne.withChildren(expr));
} else if (rightSumSlotToOutput.containsKey(slot)) {
Expression expr = new Sum(rightSumSlotToOutput.get(slot).toSlot());
newOutputExprs.add((NamedExpression) ne.withChildren(expr));
} else {
throw new IllegalStateException("Slot " + slot + " not found in join output");
}
} else {
newOutputExprs.add(ne);
}
}
return agg.withAggOutputChild(newOutputExprs, newJoin);
}
}