diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index 4b488b6cfe..2c0e57b715 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -98,16 +98,14 @@ import org.apache.doris.nereids.rules.rewrite.PullUpProjectUnderTopN;
import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoEsScan;
import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoJdbcScan;
import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoOdbcScan;
+import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOneSide;
import org.apache.doris.nereids.rules.rewrite.PushDownCountThroughJoin;
-import org.apache.doris.nereids.rules.rewrite.PushDownCountThroughJoinOneSide;
import org.apache.doris.nereids.rules.rewrite.PushDownDistinctThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughProject;
import org.apache.doris.nereids.rules.rewrite.PushDownLimit;
import org.apache.doris.nereids.rules.rewrite.PushDownLimitDistinctThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownLimitDistinctThroughUnion;
-import org.apache.doris.nereids.rules.rewrite.PushDownMinMaxThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownSumThroughJoin;
-import org.apache.doris.nereids.rules.rewrite.PushDownSumThroughJoinOneSide;
import org.apache.doris.nereids.rules.rewrite.PushDownTopNDistinctThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownTopNDistinctThroughUnion;
import org.apache.doris.nereids.rules.rewrite.PushDownTopNThroughJoin;
@@ -291,13 +289,9 @@ public class Rewriter extends AbstractBatchJobExecutor {
topic("Eager aggregation",
topDown(
new PushDownSumThroughJoin(),
- new PushDownMinMaxThroughJoin(),
+ new PushDownAggThroughJoinOneSide(),
new PushDownCountThroughJoin()
),
- topDown(
- new PushDownSumThroughJoinOneSide(),
- new PushDownCountThroughJoinOneSide()
- ),
custom(RuleType.PUSH_DOWN_DISTINCT_THROUGH_JOIN, PushDownDistinctThroughJoin::new)
),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 6a994c1b6e..58947760b4 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -167,13 +167,10 @@ public enum RuleType {
COLUMN_PRUNING(RuleTypeClass.REWRITE),
ELIMINATE_SORT(RuleTypeClass.REWRITE),
- PUSH_DOWN_MIN_MAX_THROUGH_JOIN(RuleTypeClass.REWRITE),
+ PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE(RuleTypeClass.REWRITE),
PUSH_DOWN_SUM_THROUGH_JOIN(RuleTypeClass.REWRITE),
PUSH_DOWN_COUNT_THROUGH_JOIN(RuleTypeClass.REWRITE),
- PUSH_DOWN_SUM_THROUGH_JOIN_ONE_SIDE(RuleTypeClass.REWRITE),
- PUSH_DOWN_COUNT_THROUGH_JOIN_ONE_SIDE(RuleTypeClass.REWRITE),
-
TRANSPOSE_LOGICAL_SEMI_JOIN_LOGICAL_JOIN(RuleTypeClass.REWRITE),
TRANSPOSE_LOGICAL_SEMI_JOIN_LOGICAL_JOIN_PROJECT(RuleTypeClass.REWRITE),
LOGICAL_SEMI_JOIN_COMMUTE(RuleTypeClass.REWRITE),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java
similarity index 81%
rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoin.java
rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java
index 3057f1eafc..f32bf8ea91 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java
@@ -24,8 +24,10 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
@@ -46,22 +48,22 @@ import java.util.Set;
* TODO: distinct
* Related paper "Eager aggregation and lazy aggregation".
*
- * aggregate: Min/Max(x)
+ * aggregate: Min/Max/Sum(x)
* |
* join
* | \
* | *
* (x)
* ->
- * aggregate: Min/Max(min1)
+ * aggregate: Min/Max/Sum(min1)
* |
* join
* | \
* | *
- * aggregate: Min/Max(x) as min1
+ * aggregate: Min/Max/Sum(x) as min1
*
*/
-public class PushDownMinMaxThroughJoin implements RewriteRuleFactory {
+public class PushDownAggThroughJoinOneSide implements RewriteRuleFactory {
@Override
public List buildRules() {
return ImmutableList.of(
@@ -71,19 +73,20 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory {
.when(agg -> {
Set funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
- .allMatch(f -> (f instanceof Min || f instanceof Max) && !f.isDistinct() && f.child(
- 0) instanceof Slot);
+ .allMatch(f -> (f instanceof Min || f instanceof Max || f instanceof Sum
+ || (f instanceof Count && !((Count) f).isCountStar())) && !f.isDistinct()
+ && f.child(0) instanceof Slot);
})
.thenApply(ctx -> {
Set enableNereidsRules = ctx.cascadesContext.getConnectContext()
.getSessionVariable().getEnableNereidsRules();
- if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_MIN_MAX_THROUGH_JOIN.type())) {
+ if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE.type())) {
return null;
}
LogicalAggregate> agg = ctx.root;
- return pushMinMaxSum(agg, agg.child(), ImmutableList.of());
+ return pushMinMaxSumCount(agg, agg.child(), ImmutableList.of());
})
- .toRule(RuleType.PUSH_DOWN_MIN_MAX_THROUGH_JOIN),
+ .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE),
logicalAggregate(logicalProject(innerLogicalJoin()))
.when(agg -> agg.child().isAllSlots())
.when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty())
@@ -91,27 +94,27 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory {
.when(agg -> {
Set funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
- .allMatch(
- f -> (f instanceof Min || f instanceof Max) && !f.isDistinct() && f.child(
- 0) instanceof Slot);
+ .allMatch(f -> (f instanceof Min || f instanceof Max || f instanceof Sum
+ || (f instanceof Count && (!((Count) f).isCountStar()))) && !f.isDistinct()
+ && f.child(0) instanceof Slot);
})
.thenApply(ctx -> {
Set enableNereidsRules = ctx.cascadesContext.getConnectContext()
.getSessionVariable().getEnableNereidsRules();
- if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_MIN_MAX_THROUGH_JOIN.type())) {
+ if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE.type())) {
return null;
}
LogicalAggregate>> agg = ctx.root;
- return pushMinMaxSum(agg, agg.child().child(), agg.child().getProjects());
+ return pushMinMaxSumCount(agg, agg.child().child(), agg.child().getProjects());
})
- .toRule(RuleType.PUSH_DOWN_MIN_MAX_THROUGH_JOIN)
+ .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE)
);
}
/**
* Push down Min/Max/Sum through join.
*/
- public static LogicalAggregate pushMinMaxSum(LogicalAggregate extends Plan> agg,
+ public static LogicalAggregate pushMinMaxSumCount(LogicalAggregate extends Plan> agg,
LogicalJoin join, List projects) {
List leftOutput = join.left().getOutput();
List rightOutput = join.right().getOutput();
@@ -183,21 +186,22 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory {
Preconditions.checkState(left != join.left() || right != join.right());
Plan newJoin = join.withChildren(left, right);
- // top agg
+ // top agg TODO: AVG
// replace
// min(x) -> min(min#)
// max(x) -> max(max#)
// sum(x) -> sum(sum#)
+ // count(x) -> sum(count#)
List newOutputExprs = new ArrayList<>();
for (NamedExpression ne : agg.getOutputExpressions()) {
if (ne instanceof Alias && ((Alias) ne).child() instanceof AggregateFunction) {
AggregateFunction func = (AggregateFunction) ((Alias) ne).child();
Slot slot = (Slot) func.child(0);
if (leftSlotToOutput.containsKey(slot)) {
- Expression newFunc = func.withChildren(leftSlotToOutput.get(slot).toSlot());
+ Expression newFunc = replaceAggFunc(func, leftSlotToOutput.get(slot).toSlot());
newOutputExprs.add((NamedExpression) ne.withChildren(newFunc));
} else if (rightSlotToOutput.containsKey(slot)) {
- Expression newFunc = func.withChildren(rightSlotToOutput.get(slot).toSlot());
+ Expression newFunc = replaceAggFunc(func, rightSlotToOutput.get(slot).toSlot());
newOutputExprs.add((NamedExpression) ne.withChildren(newFunc));
} else {
throw new IllegalStateException("Slot " + slot + " not found in join output");
@@ -210,4 +214,12 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory {
// TODO: column prune project
return agg.withAggOutputChild(newOutputExprs, newJoin);
}
+
+ private static Expression replaceAggFunc(AggregateFunction func, Slot inputSlot) {
+ if (func instanceof Count) {
+ return new Sum(inputSlot);
+ } else {
+ return func.withChildren(inputSlot);
+ }
+ }
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinOneSide.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinOneSide.java
deleted file mode 100644
index 5abe33fb14..0000000000
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinOneSide.java
+++ /dev/null
@@ -1,216 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.rules.rewrite;
-
-import org.apache.doris.nereids.rules.Rule;
-import org.apache.doris.nereids.rules.RuleType;
-import org.apache.doris.nereids.trees.expressions.Alias;
-import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.Slot;
-import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
-import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
-import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
-import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
-
-import com.google.common.base.Preconditions;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableList.Builder;
-
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
-/**
- * TODO: distinct | just push one level
- * Support Pushdown Count(col).
- * Count(col) -> Sum( cnt )
- *
- * Related paper "Eager aggregation and lazy aggregation".
- *
- * aggregate: count(x)
- * |
- * join
- * | \
- * | *
- * (x)
- * ->
- * aggregate: Sum( cnt )
- * |
- * join
- * | \
- * | *
- * aggregate: count(x) as cnt
- *
- * Notice: rule can't optimize condition that groupby is empty when Count(*) exists.
- */
-public class PushDownCountThroughJoinOneSide implements RewriteRuleFactory {
- @Override
- public List buildRules() {
- return ImmutableList.of(
- logicalAggregate(innerLogicalJoin())
- .when(agg -> agg.child().getOtherJoinConjuncts().isEmpty())
- .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
- .when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot))
- .when(agg -> {
- Set funcs = agg.getAggregateFunctions();
- return !funcs.isEmpty() && funcs.stream()
- .allMatch(f -> f instanceof Count && !f.isDistinct()
- && (!((Count) f).isCountStar() && f.child(0) instanceof Slot));
- })
- .thenApply(ctx -> {
- Set enableNereidsRules = ctx.cascadesContext.getConnectContext()
- .getSessionVariable().getEnableNereidsRules();
- if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN_ONE_SIDE.type())) {
- return null;
- }
- LogicalAggregate> agg = ctx.root;
- return pushCount(agg, agg.child(), ImmutableList.of());
- })
- .toRule(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN_ONE_SIDE),
- logicalAggregate(logicalProject(innerLogicalJoin()))
- .when(agg -> agg.child().isAllSlots())
- .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty())
- .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
- .when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot))
- .when(agg -> {
- Set funcs = agg.getAggregateFunctions();
- return !funcs.isEmpty() && funcs.stream()
- .allMatch(f -> f instanceof Count && !f.isDistinct()
- && (!((Count) f).isCountStar() && f.child(0) instanceof Slot));
- })
- .thenApply(ctx -> {
- Set enableNereidsRules = ctx.cascadesContext.getConnectContext()
- .getSessionVariable().getEnableNereidsRules();
- if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN_ONE_SIDE.type())) {
- return null;
- }
- LogicalAggregate>> agg = ctx.root;
- return pushCount(agg, agg.child().child(), agg.child().getProjects());
- })
- .toRule(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN_ONE_SIDE)
- );
- }
-
- private LogicalAggregate pushCount(LogicalAggregate extends Plan> agg,
- LogicalJoin join, List projects) {
- List leftOutput = join.left().getOutput();
- List rightOutput = join.right().getOutput();
-
- List leftCounts = new ArrayList<>();
- List rightCounts = new ArrayList<>();
- for (AggregateFunction f : agg.getAggregateFunctions()) {
- Count count = (Count) f;
- Slot slot = (Slot) count.child(0);
- if (leftOutput.contains(slot)) {
- leftCounts.add(count);
- } else if (rightOutput.contains(slot)) {
- rightCounts.add(count);
- } else {
- throw new IllegalStateException("Slot " + slot + " not found in join output");
- }
- }
-
- Set leftGroupBy = new HashSet<>();
- Set rightGroupBy = new HashSet<>();
- for (Expression e : agg.getGroupByExpressions()) {
- Slot slot = (Slot) e;
- if (leftOutput.contains(slot)) {
- leftGroupBy.add(slot);
- } else if (rightOutput.contains(slot)) {
- rightGroupBy.add(slot);
- } else {
- return null;
- }
- }
- join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> {
- if (leftOutput.contains(slot)) {
- leftGroupBy.add(slot);
- } else if (rightOutput.contains(slot)) {
- rightGroupBy.add(slot);
- } else {
- throw new IllegalStateException("Slot " + slot + " not found in join output");
- }
- }));
-
- Plan left = join.left();
- Plan right = join.right();
-
- Map leftCntSlotToOutput = new HashMap<>();
- Map rightCntSlotToOutput = new HashMap<>();
-
- // left Count agg
- if (!leftCounts.isEmpty()) {
- Builder leftCntAggOutputBuilder = ImmutableList.builder()
- .addAll(leftGroupBy);
- leftCounts.forEach(func -> {
- Alias alias = func.alias(func.getName());
- leftCntSlotToOutput.put((Slot) func.child(0), alias);
- leftCntAggOutputBuilder.add(alias);
- });
- left = new LogicalAggregate<>(ImmutableList.copyOf(leftGroupBy), leftCntAggOutputBuilder.build(),
- join.left());
- }
-
- // right Count agg
- if (!rightCounts.isEmpty()) {
- Builder rightCntAggOutputBuilder = ImmutableList.builder()
- .addAll(rightGroupBy);
- rightCounts.forEach(func -> {
- Alias alias = func.alias(func.getName());
- rightCntSlotToOutput.put((Slot) func.child(0), alias);
- rightCntAggOutputBuilder.add(alias);
- });
-
- right = new LogicalAggregate<>(ImmutableList.copyOf(rightGroupBy), rightCntAggOutputBuilder.build(),
- join.right());
- }
-
- Preconditions.checkState(left != join.left() || right != join.right());
- Plan newJoin = join.withChildren(left, right);
-
- // top Sum agg
- // count(slot) -> sum( count(slot) as cnt )
- List newOutputExprs = new ArrayList<>();
- for (NamedExpression ne : agg.getOutputExpressions()) {
- if (ne instanceof Alias && ((Alias) ne).child() instanceof Count) {
- Count oldTopCnt = (Count) ((Alias) ne).child();
-
- Slot slot = (Slot) oldTopCnt.child(0);
- if (leftCntSlotToOutput.containsKey(slot)) {
- Expression expr = new Sum(leftCntSlotToOutput.get(slot).toSlot());
- newOutputExprs.add((NamedExpression) ne.withChildren(expr));
- } else if (rightCntSlotToOutput.containsKey(slot)) {
- Expression expr = new Sum(rightCntSlotToOutput.get(slot).toSlot());
- newOutputExprs.add((NamedExpression) ne.withChildren(expr));
- } else {
- throw new IllegalStateException("Slot " + slot + " not found in join output");
- }
- } else {
- newOutputExprs.add(ne);
- }
- }
- return agg.withAggOutputChild(newOutputExprs, newJoin);
- }
-}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSide.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSide.java
deleted file mode 100644
index 88b13b383a..0000000000
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSide.java
+++ /dev/null
@@ -1,98 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.rules.rewrite;
-
-import org.apache.doris.nereids.rules.Rule;
-import org.apache.doris.nereids.rules.RuleType;
-import org.apache.doris.nereids.trees.expressions.Slot;
-import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
-import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
-import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
-import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
-
-import com.google.common.collect.ImmutableList;
-
-import java.util.List;
-import java.util.Set;
-
-/**
- * TODO: distinct
- * Related paper "Eager aggregation and lazy aggregation".
- *
- * aggregate: Sum(x)
- * |
- * join
- * | \
- * | *
- * (x)
- * ->
- * aggregate: Sum(sum1)
- * |
- * join
- * | \
- * | *
- * aggregate: Sum(x) as sum1
- *
- */
-public class PushDownSumThroughJoinOneSide implements RewriteRuleFactory {
- @Override
- public List buildRules() {
- return ImmutableList.of(
- logicalAggregate(innerLogicalJoin())
- .when(agg -> agg.child().getOtherJoinConjuncts().isEmpty())
- .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
- .when(agg -> {
- Set funcs = agg.getAggregateFunctions();
- return !funcs.isEmpty() && funcs.stream()
- .allMatch(f -> f instanceof Sum && !f.isDistinct() && f.child(0) instanceof Slot);
- })
- .thenApply(ctx -> {
- Set enableNereidsRules = ctx.cascadesContext.getConnectContext()
- .getSessionVariable().getEnableNereidsRules();
- if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN_ONE_SIDE.type())) {
- return null;
- }
- LogicalAggregate> agg = ctx.root;
- return PushDownMinMaxThroughJoin.pushMinMaxSum(agg, agg.child(), ImmutableList.of());
- })
- .toRule(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN),
- logicalAggregate(logicalProject(innerLogicalJoin()))
- .when(agg -> agg.child().isAllSlots())
- .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty())
- .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
- .when(agg -> {
- Set funcs = agg.getAggregateFunctions();
- return !funcs.isEmpty() && funcs.stream()
- .allMatch(f -> f instanceof Sum && !f.isDistinct() && f.child(0) instanceof Slot);
- })
- .thenApply(ctx -> {
- Set enableNereidsRules = ctx.cascadesContext.getConnectContext()
- .getSessionVariable().getEnableNereidsRules();
- if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN_ONE_SIDE.type())) {
- return null;
- }
- LogicalAggregate>> agg = ctx.root;
- return PushDownMinMaxThroughJoin.pushMinMaxSum(agg, agg.child().child(),
- agg.child().getProjects());
- })
- .toRule(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN)
- );
- }
-}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinOneSideTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinOneSideTest.java
deleted file mode 100644
index 3106eb30f4..0000000000
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinOneSideTest.java
+++ /dev/null
@@ -1,139 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.rules.rewrite;
-
-import org.apache.doris.common.Pair;
-import org.apache.doris.nereids.rules.RuleType;
-import org.apache.doris.nereids.trees.expressions.Alias;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
-import org.apache.doris.nereids.trees.plans.JoinType;
-import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
-import org.apache.doris.nereids.util.LogicalPlanBuilder;
-import org.apache.doris.nereids.util.MemoPatternMatchSupported;
-import org.apache.doris.nereids.util.MemoTestUtils;
-import org.apache.doris.nereids.util.PlanChecker;
-import org.apache.doris.nereids.util.PlanConstructor;
-import org.apache.doris.qe.SessionVariable;
-
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableSet;
-import mockit.Mock;
-import mockit.MockUp;
-import org.junit.jupiter.api.Test;
-
-import java.util.Set;
-
-class PushDownCountThroughJoinOneSideTest implements MemoPatternMatchSupported {
- private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
- private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
- private MockUp mockUp = new MockUp() {
- @Mock
- public Set getEnableNereidsRules() {
- return ImmutableSet.of(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN_ONE_SIDE.type());
- }
- };
-
- @Test
- void testSingleCount() {
- Alias count = new Count(scan1.getOutput().get(0)).alias("count");
- LogicalPlan plan = new LogicalPlanBuilder(scan1)
- .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
- .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), count))
- .build();
-
- PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .applyTopDown(new PushDownCountThroughJoinOneSide())
- .printlnTree()
- .matches(
- logicalAggregate(
- logicalJoin(
- logicalAggregate(),
- logicalOlapScan()
- )
- )
- );
- }
-
- @Test
- void testMultiCount() {
- Alias leftCnt1 = new Count(scan1.getOutput().get(0)).alias("leftCnt1");
- Alias leftCnt2 = new Count(scan1.getOutput().get(1)).alias("leftCnt2");
- Alias rightCnt1 = new Count(scan2.getOutput().get(1)).alias("rightCnt1");
- LogicalPlan plan = new LogicalPlanBuilder(scan1)
- .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
- .aggGroupUsingIndex(ImmutableList.of(0),
- ImmutableList.of(scan1.getOutput().get(0), leftCnt1, leftCnt2, rightCnt1))
- .build();
-
- PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .applyTopDown(new PushDownCountThroughJoinOneSide())
- .matches(
- logicalAggregate(
- logicalJoin(
- logicalAggregate(),
- logicalAggregate()
- )
- )
- );
- }
-
- @Test
- void testSingleCountStar() {
- Alias count = new Count().alias("countStar");
- LogicalPlan plan = new LogicalPlanBuilder(scan1)
- .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
- .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), count))
- .build();
-
- PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .applyTopDown(new PushDownCountThroughJoinOneSide())
- .printlnTree()
- .matches(
- logicalAggregate(
- logicalJoin(
- logicalOlapScan(),
- logicalOlapScan()
- )
- )
- );
- }
-
- @Test
- void testBothSideCountAndCountStar() {
- Alias leftCnt = new Count(scan1.getOutput().get(0)).alias("leftCnt");
- Alias rightCnt = new Count(scan2.getOutput().get(0)).alias("rightCnt");
- Alias countStar = new Count().alias("countStar");
- LogicalPlan plan = new LogicalPlanBuilder(scan1)
- .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
- .aggGroupUsingIndex(ImmutableList.of(0),
- ImmutableList.of(scan1.getOutput().get(0), leftCnt, rightCnt, countStar))
- .build();
-
- PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .applyTopDown(new PushDownCountThroughJoinOneSide())
- .matches(
- logicalAggregate(
- logicalJoin(
- logicalOlapScan(),
- logicalOlapScan()
- )
- )
- );
- }
-}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java
new file mode 100644
index 0000000000..58ab7fbe9e
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java
@@ -0,0 +1,357 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+import org.apache.doris.qe.SessionVariable;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import mockit.Mock;
+import mockit.MockUp;
+import org.junit.jupiter.api.Test;
+
+import java.util.Set;
+
+class PushDownMinMaxSumThroughJoinTest implements MemoPatternMatchSupported {
+ private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
+ private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
+ private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
+ private final LogicalOlapScan scan4 = PlanConstructor.newLogicalOlapScan(3, "t4", 0);
+ private MockUp mockUp = new MockUp() {
+ @Mock
+ public Set getEnableNereidsRules() {
+ return ImmutableSet.of(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE.type());
+ }
+ };
+
+ @Test
+ void testSingleJoin() {
+ Alias min = new Min(scan1.getOutput().get(0)).alias("min");
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), min))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new PushDownAggThroughJoinOneSide())
+ .matches(
+ logicalAggregate(
+ logicalJoin(
+ logicalAggregate(),
+ logicalOlapScan()
+ )
+ )
+ );
+ }
+
+ @Test
+ void testMultiJoin() {
+ Alias min = new Min(scan1.getOutput().get(0)).alias("min");
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .join(scan4, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), min))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new PushDownAggThroughJoinOneSide())
+ .printlnTree()
+ .matches(
+ logicalAggregate(
+ logicalJoin(
+ logicalAggregate(
+ logicalJoin(
+ logicalAggregate(
+ logicalJoin(
+ logicalAggregate(),
+ logicalOlapScan()
+ )
+ ),
+ logicalOlapScan()
+ )
+ ),
+ logicalOlapScan()
+ )
+ )
+ );
+ }
+
+ @Test
+ void testAggNotOutputGroupBy() {
+ // agg don't output group by
+ Alias min = new Min(scan1.getOutput().get(0)).alias("min");
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(min))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new PushDownAggThroughJoinOneSide())
+ .matches(
+ logicalAggregate(
+ logicalJoin(
+ logicalAggregate(
+ logicalJoin(
+ logicalAggregate(),
+ logicalOlapScan()
+ )
+ ),
+ logicalOlapScan()
+ )
+ )
+ );
+ }
+
+ @Test
+ void testBothSideSingleJoin() {
+ Alias min = new Min(scan1.getOutput().get(1)).alias("min");
+ Alias max = new Max(scan2.getOutput().get(1)).alias("max");
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), min, max))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .printlnTree()
+ .applyTopDown(new PushDownAggThroughJoinOneSide())
+ .matches(
+ logicalAggregate(
+ logicalJoin(
+ logicalAggregate(),
+ logicalAggregate()
+ )
+ )
+ );
+ }
+
+ @Test
+ void testBothSide() {
+ Alias min = new Min(scan1.getOutput().get(1)).alias("min");
+ Alias max = new Max(scan3.getOutput().get(1)).alias("max");
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(min, max))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new PushDownAggThroughJoinOneSide())
+ .matches(
+ logicalAggregate(
+ logicalJoin(
+ logicalAggregate(
+ logicalJoin(
+ logicalAggregate(),
+ logicalOlapScan()
+ )
+ ),
+ logicalAggregate()
+ )
+ )
+ );
+ }
+
+ @Test
+ void testSingleJoinLeftSum() {
+ Alias sum = new Sum(scan1.getOutput().get(1)).alias("sum");
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), sum))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new PushDownAggThroughJoinOneSide())
+ .matches(
+ logicalAggregate(
+ logicalJoin(
+ logicalAggregate(),
+ logicalOlapScan()
+ )
+ )
+ );
+ }
+
+ @Test
+ void testSingleJoinRightSum() {
+ Alias sum = new Sum(scan2.getOutput().get(1)).alias("sum");
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), sum))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new PushDownAggThroughJoinOneSide())
+ .matches(
+ logicalAggregate(
+ logicalJoin(
+ logicalOlapScan(),
+ logicalAggregate()
+ )
+ )
+ );
+ }
+
+ @Test
+ void testSumAggNotOutputGroupBy() {
+ // agg don't output group by
+ Alias sum = new Sum(scan1.getOutput().get(1)).alias("sum");
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(sum))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new PushDownAggThroughJoinOneSide())
+ .matches(
+ logicalAggregate(
+ logicalJoin(
+ logicalAggregate(),
+ logicalOlapScan()
+ )
+ )
+ );
+ }
+
+ @Test
+ void testMultiSum() {
+ Alias leftSum1 = new Sum(scan1.getOutput().get(0)).alias("leftSum1");
+ Alias leftSum2 = new Sum(scan1.getOutput().get(1)).alias("leftSum2");
+ Alias rightSum1 = new Sum(scan2.getOutput().get(1)).alias("rightSum1");
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .aggGroupUsingIndex(ImmutableList.of(0),
+ ImmutableList.of(scan1.getOutput().get(0), leftSum1, leftSum2, rightSum1))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new PushDownAggThroughJoinOneSide())
+ .matches(
+ logicalAggregate(
+ logicalJoin(
+ logicalAggregate(),
+ logicalAggregate()
+ )
+ )
+ );
+ }
+
+ @Test
+ void testSingleCount() {
+ Alias count = new Count(scan1.getOutput().get(0)).alias("count");
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), count))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new PushDownAggThroughJoinOneSide())
+ .printlnTree()
+ .matches(
+ logicalAggregate(
+ logicalJoin(
+ logicalAggregate(),
+ logicalOlapScan()
+ )
+ )
+ );
+ }
+
+ @Test
+ void testMultiCount() {
+ Alias leftCnt1 = new Count(scan1.getOutput().get(0)).alias("leftCnt1");
+ Alias leftCnt2 = new Count(scan1.getOutput().get(1)).alias("leftCnt2");
+ Alias rightCnt1 = new Count(scan2.getOutput().get(1)).alias("rightCnt1");
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .aggGroupUsingIndex(ImmutableList.of(0),
+ ImmutableList.of(scan1.getOutput().get(0), leftCnt1, leftCnt2, rightCnt1))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new PushDownAggThroughJoinOneSide())
+ .matches(
+ logicalAggregate(
+ logicalJoin(
+ logicalAggregate(),
+ logicalAggregate()
+ )
+ )
+ );
+ }
+
+ @Test
+ void testSingleCountStar() {
+ Alias count = new Count().alias("countStar");
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), count))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new PushDownAggThroughJoinOneSide())
+ .printlnTree()
+ .matches(
+ logicalAggregate(
+ logicalJoin(
+ logicalOlapScan(),
+ logicalOlapScan()
+ )
+ )
+ );
+ }
+
+ @Test
+ void testBothSideCountAndCountStar() {
+ Alias leftCnt = new Count(scan1.getOutput().get(0)).alias("leftCnt");
+ Alias rightCnt = new Count(scan2.getOutput().get(0)).alias("rightCnt");
+ Alias countStar = new Count().alias("countStar");
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .aggGroupUsingIndex(ImmutableList.of(0),
+ ImmutableList.of(scan1.getOutput().get(0), leftCnt, rightCnt, countStar))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new PushDownAggThroughJoinOneSide())
+ .matches(
+ logicalAggregate(
+ logicalJoin(
+ logicalOlapScan(),
+ logicalOlapScan()
+ )
+ )
+ );
+ }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoinTest.java
deleted file mode 100644
index cf28954a47..0000000000
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoinTest.java
+++ /dev/null
@@ -1,183 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.rules.rewrite;
-
-import org.apache.doris.common.Pair;
-import org.apache.doris.nereids.rules.RuleType;
-import org.apache.doris.nereids.trees.expressions.Alias;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
-import org.apache.doris.nereids.trees.plans.JoinType;
-import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
-import org.apache.doris.nereids.util.LogicalPlanBuilder;
-import org.apache.doris.nereids.util.MemoPatternMatchSupported;
-import org.apache.doris.nereids.util.MemoTestUtils;
-import org.apache.doris.nereids.util.PlanChecker;
-import org.apache.doris.nereids.util.PlanConstructor;
-import org.apache.doris.qe.SessionVariable;
-
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableSet;
-import mockit.Mock;
-import mockit.MockUp;
-import org.junit.jupiter.api.Test;
-
-import java.util.Set;
-
-class PushDownMinMaxThroughJoinTest implements MemoPatternMatchSupported {
- private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
- private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
- private static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
- private static final LogicalOlapScan scan4 = PlanConstructor.newLogicalOlapScan(3, "t4", 0);
- private MockUp mockUp = new MockUp() {
- @Mock
- public Set getEnableNereidsRules() {
- return ImmutableSet.of(RuleType.PUSH_DOWN_MIN_MAX_THROUGH_JOIN.type());
- }
- };
-
- @Test
- void testSingleJoin() {
- Alias min = new Min(scan1.getOutput().get(0)).alias("min");
- LogicalPlan plan = new LogicalPlanBuilder(scan1)
- .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
- .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), min))
- .build();
-
- PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .applyTopDown(new PushDownMinMaxThroughJoin())
- .matches(
- logicalAggregate(
- logicalJoin(
- logicalAggregate(),
- logicalOlapScan()
- )
- )
- );
- }
-
- @Test
- void testMultiJoin() {
- Alias min = new Min(scan1.getOutput().get(0)).alias("min");
- LogicalPlan plan = new LogicalPlanBuilder(scan1)
- .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
- .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0))
- .join(scan4, JoinType.INNER_JOIN, Pair.of(0, 0))
- .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), min))
- .build();
-
- PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .applyTopDown(new PushDownMinMaxThroughJoin())
- .printlnTree()
- .matches(
- logicalAggregate(
- logicalJoin(
- logicalAggregate(
- logicalJoin(
- logicalAggregate(
- logicalJoin(
- logicalAggregate(),
- logicalOlapScan()
- )
- ),
- logicalOlapScan()
- )
- ),
- logicalOlapScan()
- )
- )
- );
- }
-
- @Test
- void testAggNotOutputGroupBy() {
- // agg don't output group by
- Alias min = new Min(scan1.getOutput().get(0)).alias("min");
- LogicalPlan plan = new LogicalPlanBuilder(scan1)
- .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
- .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0))
- .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(min))
- .build();
-
- PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .applyTopDown(new PushDownMinMaxThroughJoin())
- .matches(
- logicalAggregate(
- logicalJoin(
- logicalAggregate(
- logicalJoin(
- logicalAggregate(),
- logicalOlapScan()
- )
- ),
- logicalOlapScan()
- )
- )
- );
- }
-
- @Test
- void testBothSideSingleJoin() {
- Alias min = new Min(scan1.getOutput().get(1)).alias("min");
- Alias max = new Max(scan2.getOutput().get(1)).alias("max");
- LogicalPlan plan = new LogicalPlanBuilder(scan1)
- .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
- .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), min, max))
- .build();
-
- PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .printlnTree()
- .applyTopDown(new PushDownMinMaxThroughJoin())
- .matches(
- logicalAggregate(
- logicalJoin(
- logicalAggregate(),
- logicalAggregate()
- )
- )
- );
- }
-
- @Test
- void testBothSide() {
- Alias min = new Min(scan1.getOutput().get(1)).alias("min");
- Alias max = new Max(scan3.getOutput().get(1)).alias("max");
- LogicalPlan plan = new LogicalPlanBuilder(scan1)
- .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
- .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0))
- .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(min, max))
- .build();
-
- PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .applyTopDown(new PushDownMinMaxThroughJoin())
- .matches(
- logicalAggregate(
- logicalJoin(
- logicalAggregate(
- logicalJoin(
- logicalAggregate(),
- logicalOlapScan()
- )
- ),
- logicalAggregate()
- )
- )
- );
- }
-}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSideTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSideTest.java
deleted file mode 100644
index 2e0f124b81..0000000000
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSideTest.java
+++ /dev/null
@@ -1,135 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.rules.rewrite;
-
-import org.apache.doris.common.Pair;
-import org.apache.doris.nereids.rules.RuleType;
-import org.apache.doris.nereids.trees.expressions.Alias;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
-import org.apache.doris.nereids.trees.plans.JoinType;
-import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
-import org.apache.doris.nereids.util.LogicalPlanBuilder;
-import org.apache.doris.nereids.util.MemoPatternMatchSupported;
-import org.apache.doris.nereids.util.MemoTestUtils;
-import org.apache.doris.nereids.util.PlanChecker;
-import org.apache.doris.nereids.util.PlanConstructor;
-import org.apache.doris.qe.SessionVariable;
-
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableSet;
-import mockit.Mock;
-import mockit.MockUp;
-import org.junit.jupiter.api.Test;
-
-import java.util.Set;
-
-class PushDownSumThroughJoinOneSideTest implements MemoPatternMatchSupported {
- private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
- private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
- private MockUp mockUp = new MockUp() {
- @Mock
- public Set getEnableNereidsRules() {
- return ImmutableSet.of(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN_ONE_SIDE.type());
- }
- };
-
- @Test
- void testSingleJoinLeftSum() {
- Alias sum = new Sum(scan1.getOutput().get(1)).alias("sum");
- LogicalPlan plan = new LogicalPlanBuilder(scan1)
- .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
- .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), sum))
- .build();
-
- PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .applyTopDown(new PushDownSumThroughJoinOneSide())
- .matches(
- logicalAggregate(
- logicalJoin(
- logicalAggregate(),
- logicalOlapScan()
- )
- )
- );
- }
-
- @Test
- void testSingleJoinRightSum() {
- Alias sum = new Sum(scan2.getOutput().get(1)).alias("sum");
- LogicalPlan plan = new LogicalPlanBuilder(scan1)
- .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
- .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), sum))
- .build();
-
- PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .applyTopDown(new PushDownSumThroughJoinOneSide())
- .matches(
- logicalAggregate(
- logicalJoin(
- logicalOlapScan(),
- logicalAggregate()
- )
- )
- );
- }
-
- @Test
- void testAggNotOutputGroupBy() {
- // agg don't output group by
- Alias sum = new Sum(scan1.getOutput().get(1)).alias("sum");
- LogicalPlan plan = new LogicalPlanBuilder(scan1)
- .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
- .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(sum))
- .build();
-
- PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .applyTopDown(new PushDownSumThroughJoinOneSide())
- .matches(
- logicalAggregate(
- logicalJoin(
- logicalAggregate(),
- logicalOlapScan()
- )
- )
- );
- }
-
- @Test
- void testMultiSum() {
- Alias leftSum1 = new Sum(scan1.getOutput().get(0)).alias("leftSum1");
- Alias leftSum2 = new Sum(scan1.getOutput().get(1)).alias("leftSum2");
- Alias rightSum1 = new Sum(scan2.getOutput().get(1)).alias("rightSum1");
- LogicalPlan plan = new LogicalPlanBuilder(scan1)
- .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
- .aggGroupUsingIndex(ImmutableList.of(0),
- ImmutableList.of(scan1.getOutput().get(0), leftSum1, leftSum2, rightSum1))
- .build();
-
- PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .applyTopDown(new PushDownSumThroughJoinOneSide())
- .matches(
- logicalAggregate(
- logicalJoin(
- logicalAggregate(),
- logicalAggregate()
- )
- )
- );
- }
-}
diff --git a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out
index 59c57e460e..0de2a12166 100644
--- a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out
+++ b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out
@@ -148,8 +148,10 @@ PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id) and (t1.name = t2.name)) otherCondition=()
---------PhysicalOlapScan[count_t_one_side]
---------PhysicalOlapScan[count_t_one_side]
+--------hashAgg[LOCAL]
+----------PhysicalOlapScan[count_t_one_side]
+--------hashAgg[LOCAL]
+----------PhysicalOlapScan[count_t_one_side]
-- !groupby_pushdown_equal_conditions_non_aggregate_selection --
PhysicalResultSink
diff --git a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_max_through_join.out b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_max_through_join.out
index bd4430fcb6..9a7cfa6a4f 100644
--- a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_max_through_join.out
+++ b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_max_through_join.out
@@ -148,8 +148,10 @@ PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id) and (t1.name = t2.name)) otherCondition=()
---------PhysicalOlapScan[max_t]
---------PhysicalOlapScan[max_t]
+--------hashAgg[LOCAL]
+----------PhysicalOlapScan[max_t]
+--------hashAgg[LOCAL]
+----------PhysicalOlapScan[max_t]
-- !groupby_pushdown_equal_conditions_non_aggregate_selection --
PhysicalResultSink
diff --git a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_min_through_join.out b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_min_through_join.out
index a0a2acd944..3e2ccc6f43 100644
--- a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_min_through_join.out
+++ b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_min_through_join.out
@@ -148,8 +148,10 @@ PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id) and (t1.name = t2.name)) otherCondition=()
---------PhysicalOlapScan[min_t]
---------PhysicalOlapScan[min_t]
+--------hashAgg[LOCAL]
+----------PhysicalOlapScan[min_t]
+--------hashAgg[LOCAL]
+----------PhysicalOlapScan[min_t]
-- !groupby_pushdown_equal_conditions_non_aggregate_selection --
PhysicalResultSink
diff --git a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_sum_through_join_one_side.out b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_sum_through_join_one_side.out
index 8046cec6d9..65d3a7b68f 100644
--- a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_sum_through_join_one_side.out
+++ b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_sum_through_join_one_side.out
@@ -148,8 +148,10 @@ PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id) and (t1.name = t2.name)) otherCondition=()
---------PhysicalOlapScan[sum_t_one_side]
---------PhysicalOlapScan[sum_t_one_side]
+--------hashAgg[LOCAL]
+----------PhysicalOlapScan[sum_t_one_side]
+--------hashAgg[LOCAL]
+----------PhysicalOlapScan[sum_t_one_side]
-- !groupby_pushdown_equal_conditions_non_aggregate_selection --
PhysicalResultSink
diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/basic.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/basic.groovy
index afa64135d3..58d50b3add 100644
--- a/regression-test/suites/nereids_rules_p0/eager_aggregate/basic.groovy
+++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/basic.groovy
@@ -21,7 +21,7 @@ suite("eager_aggregate_basic") {
sql "SET enable_fallback_to_original_planner=false"
sql "SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'"
- sql "SET ENABLE_NEREIDS_RULES=push_down_min_max_through_join"
+ sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join_one_side"
sql "SET ENABLE_NEREIDS_RULES=push_down_sum_through_join"
sql "SET ENABLE_NEREIDS_RULES=push_down_count_through_join"
diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/basic_one_side.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/basic_one_side.groovy
index cb84e0cc1e..cc1c0c8c73 100644
--- a/regression-test/suites/nereids_rules_p0/eager_aggregate/basic_one_side.groovy
+++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/basic_one_side.groovy
@@ -21,9 +21,7 @@ suite("eager_aggregate_basic_one_side") {
sql "SET enable_fallback_to_original_planner=false"
sql "SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'"
- sql "SET ENABLE_NEREIDS_RULES=push_down_min_max_through_join_one_side"
- sql "SET ENABLE_NEREIDS_RULES=push_down_sum_through_join_one_side"
- sql "SET ENABLE_NEREIDS_RULES=push_down_count_through_join_one_side"
+ sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join_one_side"
sql """
DROP TABLE IF EXISTS shunt_log_com_dd_library_one_side;
diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy
index 037368f051..8886287436 100644
--- a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy
+++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy
@@ -48,7 +48,7 @@ suite("push_down_count_through_join_one_side") {
sql "insert into count_t_one_side values (9, 3, null)"
sql "insert into count_t_one_side values (10, null, null)"
- sql "SET ENABLE_NEREIDS_RULES=push_down_count_through_join_one_side"
+ sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join_one_side"
qt_groupby_pushdown_basic """
explain shape plan select count(t1.score) from count_t_one_side t1, count_t_one_side t2 where t1.id = t2.id group by t1.name;
diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_max_through_join.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_max_through_join.groovy
index 68d1946b35..26772637fe 100644
--- a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_max_through_join.groovy
+++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_max_through_join.groovy
@@ -48,7 +48,7 @@ suite("push_down_max_through_join") {
sql "insert into max_t values (9, 3, null)"
sql "insert into max_t values (10, null, null)"
- sql "SET ENABLE_NEREIDS_RULES=push_down_min_max_through_join"
+ sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join_one_side"
qt_groupby_pushdown_basic """
explain shape plan select max(t1.score) from max_t t1, max_t t2 where t1.id = t2.id group by t1.name;
diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_min_through_join.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_min_through_join.groovy
index 560bf1c0d7..7942fbd28c 100644
--- a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_min_through_join.groovy
+++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_min_through_join.groovy
@@ -48,7 +48,7 @@ suite("push_down_min_through_join") {
sql "insert into min_t values (9, 3, null)"
sql "insert into min_t values (10, null, null)"
- sql "SET ENABLE_NEREIDS_RULES=push_down_min_max_through_join"
+ sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join_one_side"
qt_groupby_pushdown_basic """
explain shape plan select min(t1.score) from min_t t1, min_t t2 where t1.id = t2.id group by t1.name;
diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_sum_through_join_one_side.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_sum_through_join_one_side.groovy
index 1ecc6aa48a..fecf141026 100644
--- a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_sum_through_join_one_side.groovy
+++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_sum_through_join_one_side.groovy
@@ -48,7 +48,7 @@ suite("push_down_sum_through_join_one_side") {
sql "insert into sum_t_one_side values (9, 3, null)"
sql "insert into sum_t_one_side values (10, null, null)"
- sql "SET ENABLE_NEREIDS_RULES=push_down_sum_through_join_one_side"
+ sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join_one_side"
qt_groupby_pushdown_basic """
explain shape plan select sum(t1.score) from sum_t_one_side t1, sum_t_one_side t2 where t1.id = t2.id group by t1.name;