diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java index 56e94eb740..f6672aa7ee 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java @@ -27,16 +27,18 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate; public class LogicalAggToPhysicalHashAgg extends OneImplementationRuleFactory { @Override public Rule build() { - return logicalAggregate().then(agg -> new PhysicalAggregate<>( - // TODO: for use a function to judge whether use stream - agg.getGroupByExpressions(), - agg.getOutputExpressions(), - agg.getPartitionExpressions(), - agg.getAggPhase(), - false, - agg.isFinalPhase(), - agg.getLogicalProperties(), - agg.child()) - ).toRule(RuleType.LOGICAL_AGG_TO_PHYSICAL_HASH_AGG_RULE); + return logicalAggregate().thenApply(ctx -> { + boolean useStreamAgg = !ctx.connectContext.getSessionVariable().disableStreamPreaggregations + && !ctx.root.getGroupByExpressions().isEmpty(); + return new PhysicalAggregate<>( + ctx.root.getGroupByExpressions(), + ctx.root.getOutputExpressions(), + ctx.root.getPartitionExpressions(), + ctx.root.getAggPhase(), + useStreamAgg, + ctx.root.isFinalPhase(), + ctx.root.getLogicalProperties(), + ctx.root.child()); + }).toRule(RuleType.LOGICAL_AGG_TO_PHYSICAL_HASH_AGG_RULE); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/implementation/ImplementationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/implementation/ImplementationTest.java index d61922db91..1a624f32d9 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/implementation/ImplementationTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/implementation/ImplementationTest.java @@ -49,6 +49,7 @@ import org.junit.jupiter.api.Test; import java.util.List; import java.util.Map; +@SuppressWarnings({"unchecked", "unused"}) public class ImplementationTest { private static final Map rulesMap = ImmutableMap.builder() @@ -82,7 +83,7 @@ public class ImplementationTest { PhysicalPlan physicalPlan = executeImplementationRule(project); Assertions.assertEquals(PlanType.PHYSICAL_PROJECT, physicalPlan.getType()); - PhysicalProject physicalProject = (PhysicalProject) physicalPlan; + PhysicalProject physicalProject = (PhysicalProject) physicalPlan; Assertions.assertEquals(2, physicalProject.getExpressions().size()); Assertions.assertEquals(col1, physicalProject.getExpressions().get(0)); Assertions.assertEquals(col2, physicalProject.getExpressions().get(1)); @@ -98,7 +99,7 @@ public class ImplementationTest { PhysicalPlan physicalPlan = executeImplementationRule(topN); Assertions.assertEquals(PlanType.PHYSICAL_TOP_N, physicalPlan.getType()); - PhysicalTopN physicalTopN = (PhysicalTopN) physicalPlan; + PhysicalTopN physicalTopN = (PhysicalTopN) physicalPlan; Assertions.assertEquals(limit, physicalTopN.getLimit()); Assertions.assertEquals(offset, physicalTopN.getOffset()); Assertions.assertEquals(2, physicalTopN.getOrderKeys().size()); @@ -110,10 +111,10 @@ public class ImplementationTest { public void toPhysicalLimitTest() { int limit = 10; int offset = 100; - LogicalLimit logicalLimit = new LogicalLimit<>(limit, offset, groupPlan); + LogicalLimit logicalLimit = new LogicalLimit<>(limit, offset, groupPlan); PhysicalPlan physicalPlan = executeImplementationRule(logicalLimit); Assertions.assertEquals(PlanType.PHYSICAL_LIMIT, physicalPlan.getType()); - PhysicalLimit physicalLimit = (PhysicalLimit) physicalPlan; + PhysicalLimit physicalLimit = (PhysicalLimit) physicalPlan; Assertions.assertEquals(limit, physicalLimit.getLimit()); Assertions.assertEquals(offset, physicalLimit.getOffset()); }