diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java index 6d5bf8b75f..4f0f63e547 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java @@ -43,8 +43,28 @@ import java.util.Map; import java.util.Set; /** - * Count(*) - * Count(col) + * TODO: distinct | just push one level + * Support Pushdown Count(*)/Count(col). + * Count(col) -> Sum( cnt * cntStar ) + * Count(*) -> Sum( leftCntStar * rightCntStar ) + *
+ * Related paper "Eager aggregation and lazy aggregation". + *
+ * aggregate: count(x) + * | + * join + * | \ + * | * + * (x) + * -> + * aggregate: Sum( cnt * cntStar ) + * | + * join + * | \ + * | aggregate: count(*) as cntStar + * aggregate: count(x) as cnt + *+ * Notice: when Count(*) exists, group by mustn't be empty. */ public class PushdownCountThroughJoin implements RewriteRuleFactory { @Override @@ -57,7 +77,8 @@ public class PushdownCountThroughJoin implements RewriteRuleFactory { .when(agg -> { Set
* aggregate: Min/Max(x)
@@ -69,7 +70,8 @@ public class PushdownMinMaxThroughJoin implements RewriteRuleFactory {
.when(agg -> {
Set funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
- .allMatch(f -> (f instanceof Min || f instanceof Max) && f.child(0) instanceof Slot);
+ .allMatch(f -> (f instanceof Min || f instanceof Max) && !f.isDistinct() && f.child(
+ 0) instanceof Slot);
})
.then(agg -> pushMinMax(agg, agg.child(), ImmutableList.of()))
.toRule(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN),
@@ -80,7 +82,9 @@ public class PushdownMinMaxThroughJoin implements RewriteRuleFactory {
.when(agg -> {
Set funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
- .allMatch(f -> (f instanceof Min || f instanceof Max) && f.child(0) instanceof Slot);
+ .allMatch(
+ f -> (f instanceof Min || f instanceof Max) && !f.isDistinct() && f.child(
+ 0) instanceof Slot);
})
.then(agg -> pushMinMax(agg, agg.child().child(), agg.child().getProjects()))
.toRule(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java
index 1319200220..81e655ab85 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java
@@ -42,6 +42,7 @@ import java.util.Map;
import java.util.Set;
/**
+ * TODO: distinct
* Related paper "Eager aggregation and lazy aggregation".
*
* aggregate: Sum(x)
@@ -69,7 +70,7 @@ public class PushdownSumThroughJoin implements RewriteRuleFactory {
.when(agg -> {
Set funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
- .allMatch(f -> f instanceof Sum && f.child(0) instanceof Slot);
+ .allMatch(f -> f instanceof Sum && !f.isDistinct() && f.child(0) instanceof Slot);
})
.then(agg -> pushSum(agg, agg.child(), ImmutableList.of()))
.toRule(RuleType.PUSHDOWN_SUM_THROUGH_JOIN),
@@ -80,7 +81,7 @@ public class PushdownSumThroughJoin implements RewriteRuleFactory {
.when(agg -> {
Set funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
- .allMatch(f -> f instanceof Sum && f.child(0) instanceof Slot);
+ .allMatch(f -> f instanceof Sum && !f.isDistinct() && f.child(0) instanceof Slot);
})
.then(agg -> pushSum(agg, agg.child().child(), agg.child().getProjects()))
.toRule(RuleType.PUSHDOWN_SUM_THROUGH_JOIN)
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java
index 1a39a7c5ff..10f814e9ba 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java
@@ -65,4 +65,46 @@ class PushdownCountThroughJoinTest implements MemoPatternMatchSupported {
.printlnTree();
}
+ @Test
+ void testSingleCountStar() {
+ Alias count = new Count().alias("countStar");
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), count))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new PushdownCountThroughJoin())
+ .printlnTree();
+ }
+
+ @Test
+ void testSingleCountStarEmptyGroupBy() {
+ Alias count = new Count().alias("countStar");
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .aggGroupUsingIndex(ImmutableList.of(), ImmutableList.of(count))
+ .build();
+
+ // shouldn't rewrite.
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new PushdownCountThroughJoin())
+ .printlnTree();
+ }
+
+ @Test
+ void testBothSideCountAndCountStar() {
+ Alias leftCnt = new Count(scan1.getOutput().get(0)).alias("leftCnt");
+ Alias rightCnt = new Count(scan2.getOutput().get(0)).alias("rightCnt");
+ Alias countStar = new Count().alias("countStar");
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .aggGroupUsingIndex(ImmutableList.of(0),
+ ImmutableList.of(scan1.getOutput().get(0), leftCnt, rightCnt, countStar))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new PushdownCountThroughJoin())
+ .printlnTree();
+ }
}