diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/BuildAggForUnion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/BuildAggForUnion.java index 3958f58bb2..225e4283b7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/BuildAggForUnion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/BuildAggForUnion.java @@ -22,7 +22,6 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; -import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; import com.google.common.collect.ImmutableList; @@ -34,12 +33,9 @@ import java.util.Optional; public class BuildAggForUnion extends OneRewriteRuleFactory { @Override public Rule build() { - return logicalUnion().whenNot(LogicalUnion::hasBuildAgg).then(union -> { - if (union.getQualifier() == Qualifier.DISTINCT) { - return new LogicalAggregate<>(ImmutableList.copyOf(union.getOutputs()), union.getOutputs(), - true, Optional.empty(), union.withHasBuildAgg()); - } - return union; - }).toRule(RuleType.BUILD_AGG_FOR_UNION); + return logicalUnion().when(union -> union.getQualifier() == Qualifier.DISTINCT) + .then(union -> new LogicalAggregate<>(ImmutableList.copyOf(union.getOutputs()), union.getOutputs(), + true, Optional.empty(), union.withAllQualifier())) + .toRule(RuleType.BUILD_AGG_FOR_UNION); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java index 1a0ff53645..1f53f89d6d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java @@ -33,45 +33,32 @@ import java.util.Optional; * Logical Union. */ public class LogicalUnion extends LogicalSetOperation implements OutputPrunable { - - // When the union is DISTINCT, an additional LogicalAggregation needs to be created, - // so add this flag to judge whether agg has been created to avoid repeated creation - private final boolean hasBuildAgg; - // When there is an agg on the union and there is a filter on the agg, // it is necessary to keep the filter on the agg and push the filter down to each child of the union. private final boolean hasPushedFilter; public LogicalUnion(Qualifier qualifier, List inputs) { super(PlanType.LOGICAL_UNION, qualifier, inputs); - this.hasBuildAgg = false; this.hasPushedFilter = false; } - public LogicalUnion(Qualifier qualifier, List outputs, - boolean hasBuildAgg, boolean hasPushedFilter, - List inputs) { + public LogicalUnion(Qualifier qualifier, List outputs, boolean hasPushedFilter, + List inputs) { super(PlanType.LOGICAL_UNION, qualifier, outputs, inputs); - this.hasBuildAgg = hasBuildAgg; this.hasPushedFilter = hasPushedFilter; } - public LogicalUnion(Qualifier qualifier, List outputs, - boolean hasBuildAgg, boolean hasPushedFilter, + public LogicalUnion(Qualifier qualifier, List outputs, boolean hasPushedFilter, Optional groupExpression, Optional logicalProperties, List inputs) { super(PlanType.LOGICAL_UNION, qualifier, outputs, groupExpression, logicalProperties, inputs); - this.hasBuildAgg = hasBuildAgg; this.hasPushedFilter = hasPushedFilter; } @Override public String toString() { - return Utils.toSqlString("LogicalUnion", - "qualifier", qualifier, - "outputs", outputs, - "hasBuildAgg", hasBuildAgg, - "hasPushedFilter", hasPushedFilter); + return Utils.toSqlString("LogicalUnion", "qualifier", qualifier, "outputs", outputs, "hasPushedFilter", + hasPushedFilter); } @Override @@ -83,14 +70,12 @@ public class LogicalUnion extends LogicalSetOperation implements OutputPrunable return false; } LogicalUnion that = (LogicalUnion) o; - return super.equals(that) - && hasBuildAgg == that.hasBuildAgg - && hasPushedFilter == that.hasPushedFilter; + return super.equals(that) && hasPushedFilter == that.hasPushedFilter; } @Override public int hashCode() { - return Objects.hash(super.hashCode(), hasBuildAgg, hasPushedFilter); + return Objects.hash(super.hashCode(), hasPushedFilter); } @Override @@ -100,34 +85,27 @@ public class LogicalUnion extends LogicalSetOperation implements OutputPrunable @Override public LogicalUnion withChildren(List children) { - return new LogicalUnion(qualifier, outputs, hasBuildAgg, hasPushedFilter, children); + return new LogicalUnion(qualifier, outputs, hasPushedFilter, children); } @Override public LogicalUnion withGroupExpression(Optional groupExpression) { - return new LogicalUnion(qualifier, outputs, hasBuildAgg, hasPushedFilter, groupExpression, + return new LogicalUnion(qualifier, outputs, hasPushedFilter, groupExpression, Optional.of(getLogicalProperties()), children); } @Override public LogicalUnion withLogicalProperties(Optional logicalProperties) { - return new LogicalUnion(qualifier, outputs, hasBuildAgg, hasPushedFilter, - Optional.empty(), logicalProperties, children); + return new LogicalUnion(qualifier, outputs, hasPushedFilter, Optional.empty(), logicalProperties, children); } @Override public LogicalUnion withNewOutputs(List newOutputs) { - return new LogicalUnion(qualifier, newOutputs, hasBuildAgg, hasPushedFilter, - Optional.empty(), Optional.empty(), children); + return new LogicalUnion(qualifier, newOutputs, hasPushedFilter, Optional.empty(), Optional.empty(), children); } - public boolean hasBuildAgg() { - return hasBuildAgg; - } - - public LogicalUnion withHasBuildAgg() { - return new LogicalUnion(qualifier, outputs, true, hasPushedFilter, - Optional.empty(), Optional.empty(), children); + public LogicalUnion withAllQualifier() { + return new LogicalUnion(Qualifier.ALL, outputs, hasPushedFilter, Optional.empty(), Optional.empty(), children); } public boolean hasPushedFilter() { @@ -135,8 +113,7 @@ public class LogicalUnion extends LogicalSetOperation implements OutputPrunable } public LogicalUnion withHasPushedFilter() { - return new LogicalUnion(qualifier, outputs, hasBuildAgg, true, - Optional.empty(), Optional.empty(), children); + return new LogicalUnion(qualifier, outputs, true, Optional.empty(), Optional.empty(), children); } @Override