diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java index 045ac37aed..c05b1cdf5e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java @@ -31,12 +31,12 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunctio import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.algebra.Aggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalSort; import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Streams; @@ -58,37 +58,43 @@ public class FillUpMissingSlots implements AnalysisRuleFactory { return ImmutableList.of( RuleType.FILL_UP_SORT_PROJECT.build( logicalSort(logicalProject()) - .when(this::checkSort) .then(sort -> { - final Builder projectionsBuilder = ImmutableList.builder(); - projectionsBuilder.addAll(sort.child().getProjects()); - Set notExistedInProject = sort.getExpressions().stream() + LogicalProject project = sort.child(); + Set projectOutputSet = project.getOutputSet(); + Set notExistedInProject = sort.getOrderKeys().stream() + .map(OrderKey::getExpr) .map(Expression::getInputSlots) .flatMap(Set::stream) - .filter(s -> !sort.child().getOutputSet().contains(s)) + .filter(s -> !projectOutputSet.contains(s)) .collect(Collectors.toSet()); - projectionsBuilder.addAll(notExistedInProject); - return new LogicalProject(sort.child().getOutput(), - new LogicalSort<>(sort.getOrderKeys(), - new LogicalProject<>(projectionsBuilder.build(), - sort.child().child()))); + if (notExistedInProject.size() == 0) { + return null; + } + List projects = ImmutableList.builder() + .addAll(project.getProjects()).addAll(notExistedInProject).build(); + return new LogicalProject<>(ImmutableList.copyOf(project.getOutput()), + sort.withChildren(new LogicalProject<>(projects, project.child()))); }) ), RuleType.FILL_UP_SORT_AGGREGATE.build( logicalSort(aggregate()) .when(this::checkSort) .then(sort -> { - Aggregate aggregate = sort.child(); - Resolver resolver = new Resolver(aggregate); + Aggregate agg = sort.child(); + Resolver resolver = new Resolver(agg); sort.getExpressions().forEach(resolver::resolve); - return createPlan(resolver, sort.child(), (r, a) -> { + return createPlan(resolver, agg, (r, a) -> { List newOrderKeys = sort.getOrderKeys().stream() .map(ok -> new OrderKey( ExpressionUtils.replace(ok.getExpr(), r.getSubstitution()), ok.isAsc(), ok.isNullFirst())) .collect(ImmutableList.toImmutableList()); - return new LogicalSort<>(newOrderKeys, a); + boolean notChanged = newOrderKeys.equals(sort.getOrderKeys()); + if (notChanged && a.equals(agg)) { + return null; + } + return notChanged ? sort.withChildren(a) : new LogicalSort<>(newOrderKeys, a); }); }) ), @@ -96,35 +102,42 @@ public class FillUpMissingSlots implements AnalysisRuleFactory { logicalSort(logicalHaving(aggregate())) .when(this::checkSort) .then(sort -> { - Aggregate aggregate = sort.child().child(); - Resolver resolver = new Resolver(aggregate); + LogicalHaving> having = sort.child(); + Aggregate agg = having.child(); + Resolver resolver = new Resolver(agg); sort.getExpressions().forEach(resolver::resolve); - return createPlan(resolver, sort.child().child(), (r, a) -> { + return createPlan(resolver, agg, (r, a) -> { List newOrderKeys = sort.getOrderKeys().stream() - .map(ok -> new OrderKey( - ExpressionUtils.replace(ok.getExpr(), r.getSubstitution()), - ok.isAsc(), - ok.isNullFirst())) + .map(key -> key.withExpression( + ExpressionUtils.replace(key.getExpr(), r.getSubstitution()))) .collect(ImmutableList.toImmutableList()); - return new LogicalSort<>(newOrderKeys, sort.child().withChildren(a)); + boolean notChanged = newOrderKeys.equals(sort.getOrderKeys()); + if (notChanged && a.equals(agg)) { + return null; + } + return notChanged ? sort.withChildren(a) : new LogicalSort<>(newOrderKeys, a); }); }) ), RuleType.FILL_UP_HAVING_AGGREGATE.build( logicalHaving(aggregate()).then(having -> { - Aggregate aggregate = having.child(); - Resolver resolver = new Resolver(aggregate); + Aggregate agg = having.child(); + Resolver resolver = new Resolver(agg); having.getConjuncts().forEach(resolver::resolve); - return createPlan(resolver, having.child(), (r, a) -> { + return createPlan(resolver, agg, (r, a) -> { Set newConjuncts = ExpressionUtils.replace( having.getConjuncts(), r.getSubstitution()); - return new LogicalFilter<>(newConjuncts, a); + boolean notChanged = newConjuncts.equals(having.getConjuncts()); + if (notChanged && a.equals(agg)) { + return null; + } + return notChanged ? having.withChildren(a) : new LogicalHaving<>(newConjuncts, a); }); }) ), + // Convert having to filter RuleType.FILL_UP_HAVING_PROJECT.build( - logicalHaving(logicalProject()).then(having -> new LogicalFilter<>(having.getConjuncts(), - having.child())) + logicalHaving().then(having -> new LogicalFilter<>(having.getConjuncts(), having.child())) ) ); } @@ -244,18 +257,27 @@ public class FillUpMissingSlots implements AnalysisRuleFactory { } private Plan createPlan(Resolver resolver, Aggregate aggregate, PlanGenerator planGenerator) { + Aggregate newAggregate; + if (resolver.getNewOutputSlots().isEmpty()) { + newAggregate = aggregate; + } else { + List newOutputExpressions = Streams + .concat(aggregate.getOutputExpressions().stream(), resolver.getNewOutputSlots().stream()) + .collect(ImmutableList.toImmutableList()); + newAggregate = aggregate.withAggOutput(newOutputExpressions); + } + Plan plan = planGenerator.apply(resolver, newAggregate); + if (plan == null) { + return null; + } List projections = aggregate.getOutputExpressions().stream() .map(NamedExpression::toSlot).collect(ImmutableList.toImmutableList()); - List newOutputExpressions = Streams - .concat(aggregate.getOutputExpressions().stream(), resolver.getNewOutputSlots().stream()) - .collect(ImmutableList.toImmutableList()); - Aggregate newAggregate = aggregate.withAggOutput(newOutputExpressions); - Plan plan = planGenerator.apply(resolver, newAggregate); return new LogicalProject<>(projections, plan); } private boolean checkSort(LogicalSort logicalSort) { - return logicalSort.getExpressions().stream() + return logicalSort.getOrderKeys().stream() + .map(OrderKey::getExpr) .map(Expression::getInputSlots) .flatMap(Set::stream) .anyMatch(s -> !logicalSort.child().getOutputSet().contains(s)) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ColumnPruning.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ColumnPruning.java index 63cea0ff4c..bca8e43e2c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ColumnPruning.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ColumnPruning.java @@ -228,11 +228,11 @@ public class ColumnPruning extends DefaultPlanRewriter implements : Optional.of(prunedOutputs); } - private final

P pruneChildren(P plan) { + private

P pruneChildren(P plan) { return pruneChildren(plan, ImmutableSet.of()); } - private final

P pruneChildren(P plan, Set parentRequiredSlots) { + private

P pruneChildren(P plan, Set parentRequiredSlots) { if (plan.arity() == 0) { // leaf return plan; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHaving.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHaving.java index d6f0ae6d59..a47f6abc8f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHaving.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHaving.java @@ -48,7 +48,7 @@ public class LogicalHaving extends LogicalUnary conjuncts, Optional groupExpression, + private LogicalHaving(Set conjuncts, Optional groupExpression, Optional logicalProperties, CHILD_TYPE child) { super(PlanType.LOGICAL_HAVING, groupExpression, logicalProperties, child); this.conjuncts = ImmutableSet.copyOf(Objects.requireNonNull(conjuncts, "conjuncts can not be null")); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java index 4c1fec2ba5..51342b9dba 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java @@ -86,12 +86,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo ); PlanChecker.from(connectContext).analyze(sql) .matchesFromRoot( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(a1, new TinyIntLiteral((byte) 0))))))); + logicalFilter( + logicalAggregate( + logicalOlapScan() + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1))))); sql = "SELECT a1 as value FROM t1 GROUP BY a1 HAVING a1 > 0"; a1 = new SlotReference( @@ -113,12 +111,11 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo PlanChecker.from(connectContext).analyze(sql) .applyBottomUp(new ExpressionRewrite(FunctionBinder.INSTANCE)) .matchesFromRoot( - logicalProject( logicalFilter( logicalAggregate( logicalOlapScan() ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0))))))); + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0)))))); sql = "SELECT SUM(a2) FROM t1 GROUP BY a1 HAVING a1 > 0"; a1 = new SlotReference( @@ -197,12 +194,11 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 HAVING value > 0"; PlanChecker.from(connectContext).analyze(sql) .matchesFromRoot( - logicalProject( logicalFilter( logicalAggregate( logicalOlapScan() ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L))))))); + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L)))))); sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 HAVING MIN(pk) > 0"; a1 = new SlotReference(