diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java index 6d5bf8b75f..4f0f63e547 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java @@ -43,8 +43,28 @@ import java.util.Map; import java.util.Set; /** - * Count(*) - * Count(col) + * TODO: distinct | just push one level + * Support Pushdown Count(*)/Count(col). + * Count(col) -> Sum( cnt * cntStar ) + * Count(*) -> Sum( leftCntStar * rightCntStar ) + *

+ * Related paper "Eager aggregation and lazy aggregation". + *

+ *  aggregate: count(x)
+ *  |
+ *  join
+ *  |   \
+ *  |    *
+ *  (x)
+ *  ->
+ *  aggregate: Sum( cnt * cntStar )
+ *  |
+ *  join
+ *  |   \
+ *  |    aggregate: count(*) as cntStar
+ *  aggregate: count(x) as cnt
+ *  
+ * Notice: when Count(*) exists, group by mustn't be empty. */ public class PushdownCountThroughJoin implements RewriteRuleFactory { @Override @@ -57,7 +77,8 @@ public class PushdownCountThroughJoin implements RewriteRuleFactory { .when(agg -> { Set funcs = agg.getAggregateFunctions(); return !funcs.isEmpty() && funcs.stream() - .allMatch(f -> f instanceof Count && f.child(0) instanceof Slot); + .allMatch(f -> f instanceof Count && !f.isDistinct() + && (((Count) f).isCountStar() || f.child(0) instanceof Slot)); }) .then(agg -> pushCount(agg, agg.child(), ImmutableList.of())) .toRule(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN), @@ -69,7 +90,8 @@ public class PushdownCountThroughJoin implements RewriteRuleFactory { .when(agg -> { Set funcs = agg.getAggregateFunctions(); return !funcs.isEmpty() && funcs.stream() - .allMatch(f -> f instanceof Count && f.child(0) instanceof Slot); + .allMatch(f -> f instanceof Count && !f.isDistinct() + && (((Count) f).isCountStar() || f.child(0) instanceof Slot)); }) .then(agg -> pushCount(agg, agg.child().child(), agg.child().getProjects())) .toRule(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN) @@ -83,23 +105,23 @@ public class PushdownCountThroughJoin implements RewriteRuleFactory { List leftCounts = new ArrayList<>(); List rightCounts = new ArrayList<>(); + List countStars = new ArrayList<>(); for (AggregateFunction f : agg.getAggregateFunctions()) { Count count = (Count) f; if (count.isCountStar()) { - // TODO: handle Count(*) - return null; - } - Slot slot = (Slot) count.child(0); - if (leftOutput.contains(slot)) { - leftCounts.add(count); - } else if (rightOutput.contains(slot)) { - rightCounts.add(count); + countStars.add(count); } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); + Slot slot = (Slot) count.child(0); + if (leftOutput.contains(slot)) { + leftCounts.add(count); + } else if (rightOutput.contains(slot)) { + rightCounts.add(count); + } else { + throw new IllegalStateException("Slot " + slot + 
" not found in join output"); + } } } - // TODO: empty GroupBy Set leftGroupBy = new HashSet<>(); Set rightGroupBy = new HashSet<>(); for (Expression e : agg.getGroupByExpressions()) { @@ -112,6 +134,11 @@ public class PushdownCountThroughJoin implements RewriteRuleFactory { return null; } } + + if (!countStars.isEmpty() && leftGroupBy.isEmpty() && rightGroupBy.isEmpty()) { + return null; + } + join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> { if (leftOutput.contains(slot)) { leftGroupBy.add(slot); @@ -133,7 +160,7 @@ public class PushdownCountThroughJoin implements RewriteRuleFactory { leftCntSlotToOutput.put((Slot) func.child(0), alias); leftCntAggOutputBuilder.add(alias); }); - if (!rightCounts.isEmpty()) { + if (!rightCounts.isEmpty() || !countStars.isEmpty()) { leftCnt = new Count().alias("leftCntStar"); leftCntAggOutputBuilder.add(leftCnt); } @@ -150,7 +177,7 @@ public class PushdownCountThroughJoin implements RewriteRuleFactory { rightCntAggOutputBuilder.add(alias); }); - if (!leftCounts.isEmpty()) { + if (!leftCounts.isEmpty() || !countStars.isEmpty()) { rightCnt = new Count().alias("rightCntStar"); rightCntAggOutputBuilder.add(rightCnt); } @@ -160,22 +187,31 @@ public class PushdownCountThroughJoin implements RewriteRuleFactory { Plan newJoin = join.withChildren(leftCntAgg, rightCntAgg); // top Sum agg - // count(slot) -> sum( count(slot) * cnt ) + // count(slot) -> sum( count(slot) * cntStar ) + // count(*) -> sum( leftCntStar * rightCntStar ) List newOutputExprs = new ArrayList<>(); for (NamedExpression ne : agg.getOutputExpressions()) { if (ne instanceof Alias && ((Alias) ne).child() instanceof Count) { Count oldTopCnt = (Count) ((Alias) ne).child(); - Slot slot = (Slot) oldTopCnt.child(0); - if (leftCntSlotToOutput.containsKey(slot)) { - Preconditions.checkState(rightCnt != null); - Expression expr = new Sum(new Multiply(leftCntSlotToOutput.get(slot).toSlot(), rightCnt.toSlot())); - newOutputExprs.add((NamedExpression) 
ne.withChildren(expr)); - } else if (rightCntSlotToOutput.containsKey(slot)) { - Preconditions.checkState(leftCnt != null); - Expression expr = new Sum(new Multiply(rightCntSlotToOutput.get(slot).toSlot(), leftCnt.toSlot())); + if (oldTopCnt.isCountStar()) { + Preconditions.checkState(rightCnt != null && leftCnt != null); + Expression expr = new Sum(new Multiply(leftCnt.toSlot(), rightCnt.toSlot())); newOutputExprs.add((NamedExpression) ne.withChildren(expr)); } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); + Slot slot = (Slot) oldTopCnt.child(0); + if (leftCntSlotToOutput.containsKey(slot)) { + Preconditions.checkState(rightCnt != null); + Expression expr = new Sum( + new Multiply(leftCntSlotToOutput.get(slot).toSlot(), rightCnt.toSlot())); + newOutputExprs.add((NamedExpression) ne.withChildren(expr)); + } else if (rightCntSlotToOutput.containsKey(slot)) { + Preconditions.checkState(leftCnt != null); + Expression expr = new Sum( + new Multiply(rightCntSlotToOutput.get(slot).toSlot(), leftCnt.toSlot())); + newOutputExprs.add((NamedExpression) ne.withChildren(expr)); + } else { + throw new IllegalStateException("Slot " + slot + " not found in join output"); + } } } else { newOutputExprs.add(ne); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoin.java index bd61f4f0ac..9b728ad141 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoin.java @@ -42,6 +42,7 @@ import java.util.Map; import java.util.Set; /** + * TODO: distinct * Related paper "Eager aggregation and lazy aggregation". *
  * aggregate: Min/Max(x)
@@ -69,7 +70,8 @@ public class PushdownMinMaxThroughJoin implements RewriteRuleFactory {
                         .when(agg -> {
                             Set funcs = agg.getAggregateFunctions();
                             return !funcs.isEmpty() && funcs.stream()
-                                .allMatch(f -> (f instanceof Min || f instanceof Max) && f.child(0) instanceof Slot);
+                                    .allMatch(f -> (f instanceof Min || f instanceof Max) && !f.isDistinct() && f.child(
+                                            0) instanceof Slot);
                         })
                         .then(agg -> pushMinMax(agg, agg.child(), ImmutableList.of()))
                         .toRule(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN),
@@ -80,7 +82,9 @@ public class PushdownMinMaxThroughJoin implements RewriteRuleFactory {
                         .when(agg -> {
                             Set funcs = agg.getAggregateFunctions();
                             return !funcs.isEmpty() && funcs.stream()
-                                .allMatch(f -> (f instanceof Min || f instanceof Max) && f.child(0) instanceof Slot);
+                                    .allMatch(
+                                            f -> (f instanceof Min || f instanceof Max) && !f.isDistinct() && f.child(
+                                                    0) instanceof Slot);
                         })
                         .then(agg -> pushMinMax(agg, agg.child().child(), agg.child().getProjects()))
                         .toRule(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java
index 1319200220..81e655ab85 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java
@@ -42,6 +42,7 @@ import java.util.Map;
 import java.util.Set;
 
 /**
+ * TODO: support Sum(DISTINCT x); the rule currently skips any aggregate that contains a distinct Sum.
  * Related paper "Eager aggregation and lazy aggregation".
  * <pre>
  * aggregate: Sum(x)
@@ -69,7 +70,7 @@ public class PushdownSumThroughJoin implements RewriteRuleFactory {
                         .when(agg -> {
                             Set funcs = agg.getAggregateFunctions();
                             return !funcs.isEmpty() && funcs.stream()
-                                    .allMatch(f -> f instanceof Sum && f.child(0) instanceof Slot);
+                                    .allMatch(f -> f instanceof Sum && !f.isDistinct() && f.child(0) instanceof Slot);
                         })
                         .then(agg -> pushSum(agg, agg.child(), ImmutableList.of()))
                         .toRule(RuleType.PUSHDOWN_SUM_THROUGH_JOIN),
@@ -80,7 +81,7 @@ public class PushdownSumThroughJoin implements RewriteRuleFactory {
                         .when(agg -> {
                             Set funcs = agg.getAggregateFunctions();
                             return !funcs.isEmpty() && funcs.stream()
-                                    .allMatch(f -> f instanceof Sum && f.child(0) instanceof Slot);
+                                    .allMatch(f -> f instanceof Sum && !f.isDistinct() && f.child(0) instanceof Slot);
                         })
                         .then(agg -> pushSum(agg, agg.child().child(), agg.child().getProjects()))
                         .toRule(RuleType.PUSHDOWN_SUM_THROUGH_JOIN)
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java
index 1a39a7c5ff..10f814e9ba 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java
@@ -65,4 +65,46 @@ class PushdownCountThroughJoinTest implements MemoPatternMatchSupported {
                 .printlnTree();
     }
 
+    @Test
+    void testSingleCountStar() {
+        Alias count = new Count().alias("countStar");
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), count))
+                .build();
+        // Smoke test: count(*) with a non-empty group-by; the rewritten tree is printed, not asserted.
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new PushdownCountThroughJoin())
+                .printlnTree();
+    }
+
+    @Test
+    void testSingleCountStarEmptyGroupBy() {
+        Alias count = new Count().alias("countStar");
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .aggGroupUsingIndex(ImmutableList.of(), ImmutableList.of(count))
+                .build();
+        // Expected: no rewrite. The rule bails out on count(*) when no group-by key exists on either side.
+        // NOTE(review): the tree is only printed, not asserted -- confirm the plan really stays unchanged.
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new PushdownCountThroughJoin())
+                .printlnTree();
+    }
+
+    @Test
+    void testBothSideCountAndCountStar() {
+        Alias leftCnt = new Count(scan1.getOutput().get(0)).alias("leftCnt");
+        Alias rightCnt = new Count(scan2.getOutput().get(0)).alias("rightCnt");
+        Alias countStar = new Count().alias("countStar");
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .aggGroupUsingIndex(ImmutableList.of(0),
+                        ImmutableList.of(scan1.getOutput().get(0), leftCnt, rightCnt, countStar))
+                .build();
+        // count(left slot), count(right slot) and count(*) in one aggregate; tree is printed, not asserted.
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new PushdownCountThroughJoin())
+                .printlnTree();
+    }
 }