From ad1c19bd65644bc19392eef6ebc2e845b342fd37 Mon Sep 17 00:00:00 2001 From: jakevin Date: Mon, 22 Jan 2024 12:23:50 +0800 Subject: [PATCH] [refactor](Nereids): Eager Aggregation unify pushdown agg function (#30142) --- .../rewrite/PushDownMinMaxThroughJoin.java | 17 ++- .../rules/rewrite/PushDownSumThroughJoin.java | 4 +- .../PushDownSumThroughJoinOneSide.java | 117 +----------------- 3 files changed, 19 insertions(+), 119 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoin.java index 48ded00def..3057f1eafc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoin.java @@ -81,7 +81,7 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory { return null; } LogicalAggregate> agg = ctx.root; - return pushMinMax(agg, agg.child(), ImmutableList.of()); + return pushMinMaxSum(agg, agg.child(), ImmutableList.of()); }) .toRule(RuleType.PUSH_DOWN_MIN_MAX_THROUGH_JOIN), logicalAggregate(logicalProject(innerLogicalJoin())) @@ -102,13 +102,16 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory { return null; } LogicalAggregate>> agg = ctx.root; - return pushMinMax(agg, agg.child().child(), agg.child().getProjects()); + return pushMinMaxSum(agg, agg.child().child(), agg.child().getProjects()); }) .toRule(RuleType.PUSH_DOWN_MIN_MAX_THROUGH_JOIN) ); } - private LogicalAggregate pushMinMax(LogicalAggregate agg, + /** + * Push down Min/Max/Sum through join. + */ + public static LogicalAggregate pushMinMaxSum(LogicalAggregate agg, LogicalJoin join, List projects) { List leftOutput = join.left().getOutput(); List rightOutput = join.right().getOutput(); @@ -125,6 +128,9 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory { throw new IllegalStateException("Slot " + slot + " not found in join output"); } } + if (leftFuncs.isEmpty() && rightFuncs.isEmpty()) { + return null; + } Set leftGroupBy = new HashSet<>(); Set rightGroupBy = new HashSet<>(); @@ -177,6 +183,11 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory { Preconditions.checkState(left != join.left() || right != join.right()); Plan newJoin = join.withChildren(left, right); + // top agg + // replace + // min(x) -> min(min#) + // max(x) -> max(max#) + // sum(x) -> sum(sum#) List newOutputExprs = new ArrayList<>(); for (NamedExpression ne : agg.getOutputExpressions()) { if (ne instanceof Alias && ((Alias) ne).child() instanceof AggregateFunction) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java index 91cb2a6050..e8987e670a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java @@ -53,12 +53,12 @@ import java.util.Set; * | * * (x) * -> - * aggregate: Sum(min1) + * aggregate: Sum(sum1) * | * join * | \ * | * - * aggregate: Sum(x) as min1 + * aggregate: Sum(x) as sum1 * */ public class PushDownSumThroughJoin implements RewriteRuleFactory { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSide.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSide.java index 3f4fa09cd7..88b13b383a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSide.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSide.java @@ -19,9 +19,6 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; @@ -30,15 +27,9 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableList.Builder; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; import java.util.List; -import java.util.Map; import java.util.Set; /** @@ -79,7 +70,7 @@ public class PushDownSumThroughJoinOneSide implements RewriteRuleFactory { return null; } LogicalAggregate> agg = ctx.root; - return pushSum(agg, agg.child(), ImmutableList.of()); + return PushDownMinMaxThroughJoin.pushMinMaxSum(agg, agg.child(), ImmutableList.of()); }) .toRule(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN), logicalAggregate(logicalProject(innerLogicalJoin())) @@ -98,112 +89,10 @@ public class PushDownSumThroughJoinOneSide implements RewriteRuleFactory { return null; } LogicalAggregate>> agg = ctx.root; - return pushSum(agg, agg.child().child(), agg.child().getProjects()); + return PushDownMinMaxThroughJoin.pushMinMaxSum(agg, agg.child().child(), + agg.child().getProjects()); }) .toRule(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN) ); } - - private LogicalAggregate pushSum(LogicalAggregate agg, - LogicalJoin join, List projects) { - List leftOutput = join.left().getOutput(); - List rightOutput = join.right().getOutput(); - - List leftSums = new ArrayList<>(); - List rightSums = new ArrayList<>(); - for (AggregateFunction f : agg.getAggregateFunctions()) { - Sum sum = (Sum) f; - Slot slot = (Slot) sum.child(); - if (leftOutput.contains(slot)) { - leftSums.add(sum); - } else if (rightOutput.contains(slot)) { - rightSums.add(sum); - } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); - } - } - if (leftSums.isEmpty() && rightSums.isEmpty()) { - return null; - } - - Set leftGroupBy = new HashSet<>(); - Set rightGroupBy = new HashSet<>(); - for (Expression e : agg.getGroupByExpressions()) { - Slot slot = (Slot) e; - if (leftOutput.contains(slot)) { - leftGroupBy.add(slot); - } else if (rightOutput.contains(slot)) { - rightGroupBy.add(slot); - } else { - return null; - } - } - join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> { - if (leftOutput.contains(slot)) { - leftGroupBy.add(slot); - } else if (rightOutput.contains(slot)) { - rightGroupBy.add(slot); - } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); - } - })); - - Plan left = join.left(); - Plan right = join.right(); - - Map leftSumSlotToOutput = new HashMap<>(); - Map rightSumSlotToOutput = new HashMap<>(); - - // left Sum agg - if (!leftSums.isEmpty()) { - Builder leftSumAggOutputBuilder = ImmutableList.builder() - .addAll(leftGroupBy); - leftSums.forEach(func -> { - Alias alias = func.alias(func.getName()); - leftSumSlotToOutput.put((Slot) func.child(0), alias); - leftSumAggOutputBuilder.add(alias); - }); - left = new LogicalAggregate<>(ImmutableList.copyOf(leftGroupBy), leftSumAggOutputBuilder.build(), - join.left()); - } - - // right Sum agg - if (!rightSums.isEmpty()) { - Builder rightSumAggOutputBuilder = ImmutableList.builder() - .addAll(rightGroupBy); - rightSums.forEach(func -> { - Alias alias = func.alias(func.getName()); - rightSumSlotToOutput.put((Slot) func.child(0), alias); - rightSumAggOutputBuilder.add(alias); - }); - right = new LogicalAggregate<>(ImmutableList.copyOf(rightGroupBy), rightSumAggOutputBuilder.build(), - join.right()); - } - - Preconditions.checkState(left != join.left() || right != join.right()); - Plan newJoin = join.withChildren(left, right); - - // top Sum agg - // replace sum(x) -> sum(sum#) - List newOutputExprs = new ArrayList<>(); - for (NamedExpression ne : agg.getOutputExpressions()) { - if (ne instanceof Alias && ((Alias) ne).child() instanceof Sum) { - Sum oldTopSum = (Sum) ((Alias) ne).child(); - - Slot slot = (Slot) oldTopSum.child(0); - if (leftSumSlotToOutput.containsKey(slot)) { - Expression expr = new Sum(leftSumSlotToOutput.get(slot).toSlot()); - newOutputExprs.add((NamedExpression) ne.withChildren(expr)); - } else if (rightSumSlotToOutput.containsKey(slot)) { - Expression expr = new Sum(rightSumSlotToOutput.get(slot).toSlot()); - newOutputExprs.add((NamedExpression) ne.withChildren(expr)); - } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); - } - } else { - newOutputExprs.add(ne); - } - } - return agg.withAggOutputChild(newOutputExprs, newJoin); - } }