[fix](Nereids) dead loop in FillUpMissingSlots (#18902)
FillUpMissingSlots don't handle some cornel case, sometime we don't need fillup, we should return null
This commit is contained in:
@ -31,12 +31,12 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunctio
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableList.Builder;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Maps;
|
||||
import com.google.common.collect.Streams;
|
||||
@ -58,37 +58,43 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
|
||||
return ImmutableList.of(
|
||||
RuleType.FILL_UP_SORT_PROJECT.build(
|
||||
logicalSort(logicalProject())
|
||||
.when(this::checkSort)
|
||||
.then(sort -> {
|
||||
final Builder<NamedExpression> projectionsBuilder = ImmutableList.builder();
|
||||
projectionsBuilder.addAll(sort.child().getProjects());
|
||||
Set<Slot> notExistedInProject = sort.getExpressions().stream()
|
||||
LogicalProject<Plan> project = sort.child();
|
||||
Set<Slot> projectOutputSet = project.getOutputSet();
|
||||
Set<Slot> notExistedInProject = sort.getOrderKeys().stream()
|
||||
.map(OrderKey::getExpr)
|
||||
.map(Expression::getInputSlots)
|
||||
.flatMap(Set::stream)
|
||||
.filter(s -> !sort.child().getOutputSet().contains(s))
|
||||
.filter(s -> !projectOutputSet.contains(s))
|
||||
.collect(Collectors.toSet());
|
||||
projectionsBuilder.addAll(notExistedInProject);
|
||||
return new LogicalProject(sort.child().getOutput(),
|
||||
new LogicalSort<>(sort.getOrderKeys(),
|
||||
new LogicalProject<>(projectionsBuilder.build(),
|
||||
sort.child().child())));
|
||||
if (notExistedInProject.size() == 0) {
|
||||
return null;
|
||||
}
|
||||
List<NamedExpression> projects = ImmutableList.<NamedExpression>builder()
|
||||
.addAll(project.getProjects()).addAll(notExistedInProject).build();
|
||||
return new LogicalProject<>(ImmutableList.copyOf(project.getOutput()),
|
||||
sort.withChildren(new LogicalProject<>(projects, project.child())));
|
||||
})
|
||||
),
|
||||
RuleType.FILL_UP_SORT_AGGREGATE.build(
|
||||
logicalSort(aggregate())
|
||||
.when(this::checkSort)
|
||||
.then(sort -> {
|
||||
Aggregate aggregate = sort.child();
|
||||
Resolver resolver = new Resolver(aggregate);
|
||||
Aggregate<Plan> agg = sort.child();
|
||||
Resolver resolver = new Resolver(agg);
|
||||
sort.getExpressions().forEach(resolver::resolve);
|
||||
return createPlan(resolver, sort.child(), (r, a) -> {
|
||||
return createPlan(resolver, agg, (r, a) -> {
|
||||
List<OrderKey> newOrderKeys = sort.getOrderKeys().stream()
|
||||
.map(ok -> new OrderKey(
|
||||
ExpressionUtils.replace(ok.getExpr(), r.getSubstitution()),
|
||||
ok.isAsc(),
|
||||
ok.isNullFirst()))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
return new LogicalSort<>(newOrderKeys, a);
|
||||
boolean notChanged = newOrderKeys.equals(sort.getOrderKeys());
|
||||
if (notChanged && a.equals(agg)) {
|
||||
return null;
|
||||
}
|
||||
return notChanged ? sort.withChildren(a) : new LogicalSort<>(newOrderKeys, a);
|
||||
});
|
||||
})
|
||||
),
|
||||
@ -96,35 +102,42 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
|
||||
logicalSort(logicalHaving(aggregate()))
|
||||
.when(this::checkSort)
|
||||
.then(sort -> {
|
||||
Aggregate aggregate = sort.child().child();
|
||||
Resolver resolver = new Resolver(aggregate);
|
||||
LogicalHaving<Aggregate<Plan>> having = sort.child();
|
||||
Aggregate<Plan> agg = having.child();
|
||||
Resolver resolver = new Resolver(agg);
|
||||
sort.getExpressions().forEach(resolver::resolve);
|
||||
return createPlan(resolver, sort.child().child(), (r, a) -> {
|
||||
return createPlan(resolver, agg, (r, a) -> {
|
||||
List<OrderKey> newOrderKeys = sort.getOrderKeys().stream()
|
||||
.map(ok -> new OrderKey(
|
||||
ExpressionUtils.replace(ok.getExpr(), r.getSubstitution()),
|
||||
ok.isAsc(),
|
||||
ok.isNullFirst()))
|
||||
.map(key -> key.withExpression(
|
||||
ExpressionUtils.replace(key.getExpr(), r.getSubstitution())))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
return new LogicalSort<>(newOrderKeys, sort.child().withChildren(a));
|
||||
boolean notChanged = newOrderKeys.equals(sort.getOrderKeys());
|
||||
if (notChanged && a.equals(agg)) {
|
||||
return null;
|
||||
}
|
||||
return notChanged ? sort.withChildren(a) : new LogicalSort<>(newOrderKeys, a);
|
||||
});
|
||||
})
|
||||
),
|
||||
RuleType.FILL_UP_HAVING_AGGREGATE.build(
|
||||
logicalHaving(aggregate()).then(having -> {
|
||||
Aggregate aggregate = having.child();
|
||||
Resolver resolver = new Resolver(aggregate);
|
||||
Aggregate<Plan> agg = having.child();
|
||||
Resolver resolver = new Resolver(agg);
|
||||
having.getConjuncts().forEach(resolver::resolve);
|
||||
return createPlan(resolver, having.child(), (r, a) -> {
|
||||
return createPlan(resolver, agg, (r, a) -> {
|
||||
Set<Expression> newConjuncts = ExpressionUtils.replace(
|
||||
having.getConjuncts(), r.getSubstitution());
|
||||
return new LogicalFilter<>(newConjuncts, a);
|
||||
boolean notChanged = newConjuncts.equals(having.getConjuncts());
|
||||
if (notChanged && a.equals(agg)) {
|
||||
return null;
|
||||
}
|
||||
return notChanged ? having.withChildren(a) : new LogicalHaving<>(newConjuncts, a);
|
||||
});
|
||||
})
|
||||
),
|
||||
// Convert having to filter
|
||||
RuleType.FILL_UP_HAVING_PROJECT.build(
|
||||
logicalHaving(logicalProject()).then(having -> new LogicalFilter<>(having.getConjuncts(),
|
||||
having.child()))
|
||||
logicalHaving().then(having -> new LogicalFilter<>(having.getConjuncts(), having.child()))
|
||||
)
|
||||
);
|
||||
}
|
||||
@ -244,18 +257,27 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
|
||||
}
|
||||
|
||||
private Plan createPlan(Resolver resolver, Aggregate<? extends Plan> aggregate, PlanGenerator planGenerator) {
|
||||
Aggregate<? extends Plan> newAggregate;
|
||||
if (resolver.getNewOutputSlots().isEmpty()) {
|
||||
newAggregate = aggregate;
|
||||
} else {
|
||||
List<NamedExpression> newOutputExpressions = Streams
|
||||
.concat(aggregate.getOutputExpressions().stream(), resolver.getNewOutputSlots().stream())
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
newAggregate = aggregate.withAggOutput(newOutputExpressions);
|
||||
}
|
||||
Plan plan = planGenerator.apply(resolver, newAggregate);
|
||||
if (plan == null) {
|
||||
return null;
|
||||
}
|
||||
List<NamedExpression> projections = aggregate.getOutputExpressions().stream()
|
||||
.map(NamedExpression::toSlot).collect(ImmutableList.toImmutableList());
|
||||
List<NamedExpression> newOutputExpressions = Streams
|
||||
.concat(aggregate.getOutputExpressions().stream(), resolver.getNewOutputSlots().stream())
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
Aggregate newAggregate = aggregate.withAggOutput(newOutputExpressions);
|
||||
Plan plan = planGenerator.apply(resolver, newAggregate);
|
||||
return new LogicalProject<>(projections, plan);
|
||||
}
|
||||
|
||||
private boolean checkSort(LogicalSort<? extends Plan> logicalSort) {
|
||||
return logicalSort.getExpressions().stream()
|
||||
return logicalSort.getOrderKeys().stream()
|
||||
.map(OrderKey::getExpr)
|
||||
.map(Expression::getInputSlots)
|
||||
.flatMap(Set::stream)
|
||||
.anyMatch(s -> !logicalSort.child().getOutputSet().contains(s))
|
||||
|
||||
@ -228,11 +228,11 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
|
||||
: Optional.of(prunedOutputs);
|
||||
}
|
||||
|
||||
private final <P extends Plan> P pruneChildren(P plan) {
|
||||
private <P extends Plan> P pruneChildren(P plan) {
|
||||
return pruneChildren(plan, ImmutableSet.of());
|
||||
}
|
||||
|
||||
private final <P extends Plan> P pruneChildren(P plan, Set<Slot> parentRequiredSlots) {
|
||||
private <P extends Plan> P pruneChildren(P plan, Set<Slot> parentRequiredSlots) {
|
||||
if (plan.arity() == 0) {
|
||||
// leaf
|
||||
return plan;
|
||||
|
||||
@ -48,7 +48,7 @@ public class LogicalHaving<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_T
|
||||
this(conjuncts, Optional.empty(), Optional.empty(), child);
|
||||
}
|
||||
|
||||
public LogicalHaving(Set<Expression> conjuncts, Optional<GroupExpression> groupExpression,
|
||||
private LogicalHaving(Set<Expression> conjuncts, Optional<GroupExpression> groupExpression,
|
||||
Optional<LogicalProperties> logicalProperties, CHILD_TYPE child) {
|
||||
super(PlanType.LOGICAL_HAVING, groupExpression, logicalProperties, child);
|
||||
this.conjuncts = ImmutableSet.copyOf(Objects.requireNonNull(conjuncts, "conjuncts can not be null"));
|
||||
|
||||
@ -86,12 +86,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
|
||||
);
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matchesFromRoot(
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1)))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(a1, new TinyIntLiteral((byte) 0)))))));
|
||||
logicalFilter(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1)))));
|
||||
|
||||
sql = "SELECT a1 as value FROM t1 GROUP BY a1 HAVING a1 > 0";
|
||||
a1 = new SlotReference(
|
||||
@ -113,12 +111,11 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.applyBottomUp(new ExpressionRewrite(FunctionBinder.INSTANCE))
|
||||
.matchesFromRoot(
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value)))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0)))))));
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0))))));
|
||||
|
||||
sql = "SELECT SUM(a2) FROM t1 GROUP BY a1 HAVING a1 > 0";
|
||||
a1 = new SlotReference(
|
||||
@ -197,12 +194,11 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
|
||||
sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 HAVING value > 0";
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matchesFromRoot(
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L)))))));
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L))))));
|
||||
|
||||
sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 HAVING MIN(pk) > 0";
|
||||
a1 = new SlotReference(
|
||||
|
||||
Reference in New Issue
Block a user