[fix](Nereids) dead loop in FillUpMissingSlots (#18902)

FillUpMissingSlots don't handle some cornel case, sometime we don't need fillup, we should return null
This commit is contained in:
jakevin
2023-04-26 13:31:51 +08:00
committed by GitHub
parent a7773d16d6
commit aa88083c1e
4 changed files with 66 additions and 48 deletions

View File

@ -31,12 +31,12 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunctio
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Streams;
@ -58,37 +58,43 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
return ImmutableList.of(
RuleType.FILL_UP_SORT_PROJECT.build(
logicalSort(logicalProject())
.when(this::checkSort)
.then(sort -> {
final Builder<NamedExpression> projectionsBuilder = ImmutableList.builder();
projectionsBuilder.addAll(sort.child().getProjects());
Set<Slot> notExistedInProject = sort.getExpressions().stream()
LogicalProject<Plan> project = sort.child();
Set<Slot> projectOutputSet = project.getOutputSet();
Set<Slot> notExistedInProject = sort.getOrderKeys().stream()
.map(OrderKey::getExpr)
.map(Expression::getInputSlots)
.flatMap(Set::stream)
.filter(s -> !sort.child().getOutputSet().contains(s))
.filter(s -> !projectOutputSet.contains(s))
.collect(Collectors.toSet());
projectionsBuilder.addAll(notExistedInProject);
return new LogicalProject(sort.child().getOutput(),
new LogicalSort<>(sort.getOrderKeys(),
new LogicalProject<>(projectionsBuilder.build(),
sort.child().child())));
if (notExistedInProject.size() == 0) {
return null;
}
List<NamedExpression> projects = ImmutableList.<NamedExpression>builder()
.addAll(project.getProjects()).addAll(notExistedInProject).build();
return new LogicalProject<>(ImmutableList.copyOf(project.getOutput()),
sort.withChildren(new LogicalProject<>(projects, project.child())));
})
),
RuleType.FILL_UP_SORT_AGGREGATE.build(
logicalSort(aggregate())
.when(this::checkSort)
.then(sort -> {
Aggregate aggregate = sort.child();
Resolver resolver = new Resolver(aggregate);
Aggregate<Plan> agg = sort.child();
Resolver resolver = new Resolver(agg);
sort.getExpressions().forEach(resolver::resolve);
return createPlan(resolver, sort.child(), (r, a) -> {
return createPlan(resolver, agg, (r, a) -> {
List<OrderKey> newOrderKeys = sort.getOrderKeys().stream()
.map(ok -> new OrderKey(
ExpressionUtils.replace(ok.getExpr(), r.getSubstitution()),
ok.isAsc(),
ok.isNullFirst()))
.collect(ImmutableList.toImmutableList());
return new LogicalSort<>(newOrderKeys, a);
boolean notChanged = newOrderKeys.equals(sort.getOrderKeys());
if (notChanged && a.equals(agg)) {
return null;
}
return notChanged ? sort.withChildren(a) : new LogicalSort<>(newOrderKeys, a);
});
})
),
@ -96,35 +102,42 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
logicalSort(logicalHaving(aggregate()))
.when(this::checkSort)
.then(sort -> {
Aggregate aggregate = sort.child().child();
Resolver resolver = new Resolver(aggregate);
LogicalHaving<Aggregate<Plan>> having = sort.child();
Aggregate<Plan> agg = having.child();
Resolver resolver = new Resolver(agg);
sort.getExpressions().forEach(resolver::resolve);
return createPlan(resolver, sort.child().child(), (r, a) -> {
return createPlan(resolver, agg, (r, a) -> {
List<OrderKey> newOrderKeys = sort.getOrderKeys().stream()
.map(ok -> new OrderKey(
ExpressionUtils.replace(ok.getExpr(), r.getSubstitution()),
ok.isAsc(),
ok.isNullFirst()))
.map(key -> key.withExpression(
ExpressionUtils.replace(key.getExpr(), r.getSubstitution())))
.collect(ImmutableList.toImmutableList());
return new LogicalSort<>(newOrderKeys, sort.child().withChildren(a));
boolean notChanged = newOrderKeys.equals(sort.getOrderKeys());
if (notChanged && a.equals(agg)) {
return null;
}
return notChanged ? sort.withChildren(a) : new LogicalSort<>(newOrderKeys, a);
});
})
),
RuleType.FILL_UP_HAVING_AGGREGATE.build(
logicalHaving(aggregate()).then(having -> {
Aggregate aggregate = having.child();
Resolver resolver = new Resolver(aggregate);
Aggregate<Plan> agg = having.child();
Resolver resolver = new Resolver(agg);
having.getConjuncts().forEach(resolver::resolve);
return createPlan(resolver, having.child(), (r, a) -> {
return createPlan(resolver, agg, (r, a) -> {
Set<Expression> newConjuncts = ExpressionUtils.replace(
having.getConjuncts(), r.getSubstitution());
return new LogicalFilter<>(newConjuncts, a);
boolean notChanged = newConjuncts.equals(having.getConjuncts());
if (notChanged && a.equals(agg)) {
return null;
}
return notChanged ? having.withChildren(a) : new LogicalHaving<>(newConjuncts, a);
});
})
),
// Convert having to filter
RuleType.FILL_UP_HAVING_PROJECT.build(
logicalHaving(logicalProject()).then(having -> new LogicalFilter<>(having.getConjuncts(),
having.child()))
logicalHaving().then(having -> new LogicalFilter<>(having.getConjuncts(), having.child()))
)
);
}
@ -244,18 +257,27 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
}
private Plan createPlan(Resolver resolver, Aggregate<? extends Plan> aggregate, PlanGenerator planGenerator) {
Aggregate<? extends Plan> newAggregate;
if (resolver.getNewOutputSlots().isEmpty()) {
newAggregate = aggregate;
} else {
List<NamedExpression> newOutputExpressions = Streams
.concat(aggregate.getOutputExpressions().stream(), resolver.getNewOutputSlots().stream())
.collect(ImmutableList.toImmutableList());
newAggregate = aggregate.withAggOutput(newOutputExpressions);
}
Plan plan = planGenerator.apply(resolver, newAggregate);
if (plan == null) {
return null;
}
List<NamedExpression> projections = aggregate.getOutputExpressions().stream()
.map(NamedExpression::toSlot).collect(ImmutableList.toImmutableList());
List<NamedExpression> newOutputExpressions = Streams
.concat(aggregate.getOutputExpressions().stream(), resolver.getNewOutputSlots().stream())
.collect(ImmutableList.toImmutableList());
Aggregate newAggregate = aggregate.withAggOutput(newOutputExpressions);
Plan plan = planGenerator.apply(resolver, newAggregate);
return new LogicalProject<>(projections, plan);
}
private boolean checkSort(LogicalSort<? extends Plan> logicalSort) {
return logicalSort.getExpressions().stream()
return logicalSort.getOrderKeys().stream()
.map(OrderKey::getExpr)
.map(Expression::getInputSlots)
.flatMap(Set::stream)
.anyMatch(s -> !logicalSort.child().getOutputSet().contains(s))

View File

@ -228,11 +228,11 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
: Optional.of(prunedOutputs);
}
private final <P extends Plan> P pruneChildren(P plan) {
private <P extends Plan> P pruneChildren(P plan) {
return pruneChildren(plan, ImmutableSet.of());
}
private final <P extends Plan> P pruneChildren(P plan, Set<Slot> parentRequiredSlots) {
private <P extends Plan> P pruneChildren(P plan, Set<Slot> parentRequiredSlots) {
if (plan.arity() == 0) {
// leaf
return plan;

View File

@ -48,7 +48,7 @@ public class LogicalHaving<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_T
this(conjuncts, Optional.empty(), Optional.empty(), child);
}
public LogicalHaving(Set<Expression> conjuncts, Optional<GroupExpression> groupExpression,
private LogicalHaving(Set<Expression> conjuncts, Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties, CHILD_TYPE child) {
super(PlanType.LOGICAL_HAVING, groupExpression, logicalProperties, child);
this.conjuncts = ImmutableSet.copyOf(Objects.requireNonNull(conjuncts, "conjuncts can not be null"));

View File

@ -86,12 +86,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
);
PlanChecker.from(connectContext).analyze(sql)
.matchesFromRoot(
logicalProject(
logicalFilter(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(a1, new TinyIntLiteral((byte) 0)))))));
logicalFilter(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1)))));
sql = "SELECT a1 as value FROM t1 GROUP BY a1 HAVING a1 > 0";
a1 = new SlotReference(
@ -113,12 +111,11 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
PlanChecker.from(connectContext).analyze(sql)
.applyBottomUp(new ExpressionRewrite(FunctionBinder.INSTANCE))
.matchesFromRoot(
logicalProject(
logicalFilter(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0)))))));
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0))))));
sql = "SELECT SUM(a2) FROM t1 GROUP BY a1 HAVING a1 > 0";
a1 = new SlotReference(
@ -197,12 +194,11 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 HAVING value > 0";
PlanChecker.from(connectContext).analyze(sql)
.matchesFromRoot(
logicalProject(
logicalFilter(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L)))))));
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L))))));
sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 HAVING MIN(pk) > 0";
a1 = new SlotReference(