pick from master #41719 just like previous PR #41548 this PR process union node to ensure not require any column from its children when it is required by its parent with empty slot set
This commit is contained in:
@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
|
||||
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
|
||||
@ -41,6 +42,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
|
||||
import org.apache.doris.nereids.trees.plans.logical.OutputPrunable;
|
||||
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
|
||||
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
|
||||
import org.apache.doris.nereids.types.TinyIntType;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
@ -314,6 +316,8 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
|
||||
}
|
||||
List<NamedExpression> prunedOutputs = Lists.newArrayList();
|
||||
List<List<NamedExpression>> constantExprsList = union.getConstantExprsList();
|
||||
List<List<SlotReference>> regularChildrenOutputs = union.getRegularChildrenOutputs();
|
||||
List<Plan> children = union.children();
|
||||
List<Integer> extractColumnIndex = Lists.newArrayList();
|
||||
for (int i = 0; i < originOutput.size(); i++) {
|
||||
NamedExpression output = originOutput.get(i);
|
||||
@ -322,31 +326,41 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
|
||||
extractColumnIndex.add(i);
|
||||
}
|
||||
}
|
||||
if (prunedOutputs.isEmpty()) {
|
||||
List<NamedExpression> candidates = Lists.newArrayList(originOutput);
|
||||
candidates.retainAll(keys);
|
||||
if (candidates.isEmpty()) {
|
||||
candidates = originOutput;
|
||||
}
|
||||
NamedExpression minimumColumn = ExpressionUtils.selectMinimumColumn(candidates);
|
||||
prunedOutputs = ImmutableList.of(minimumColumn);
|
||||
extractColumnIndex.add(originOutput.indexOf(minimumColumn));
|
||||
}
|
||||
|
||||
int len = extractColumnIndex.size();
|
||||
ImmutableList.Builder<List<NamedExpression>> prunedConstantExprsList
|
||||
= ImmutableList.builderWithExpectedSize(constantExprsList.size());
|
||||
for (List<NamedExpression> row : constantExprsList) {
|
||||
ImmutableList.Builder<NamedExpression> newRow = ImmutableList.builderWithExpectedSize(len);
|
||||
for (int idx : extractColumnIndex) {
|
||||
newRow.add(row.get(idx));
|
||||
if (prunedOutputs.isEmpty()) {
|
||||
// process prune all columns
|
||||
NamedExpression originSlot = originOutput.get(0);
|
||||
prunedOutputs = ImmutableList.of(new SlotReference(originSlot.getExprId(), originSlot.getName(),
|
||||
TinyIntType.INSTANCE, false, originSlot.getQualifier()));
|
||||
regularChildrenOutputs = Lists.newArrayListWithCapacity(regularChildrenOutputs.size());
|
||||
children = Lists.newArrayListWithCapacity(children.size());
|
||||
for (int i = 0; i < union.getArity(); i++) {
|
||||
LogicalProject<?> project = new LogicalProject<>(
|
||||
ImmutableList.of(new Alias(new TinyIntLiteral((byte) 1))), union.child(i));
|
||||
regularChildrenOutputs.add((List) project.getOutput());
|
||||
children.add(project);
|
||||
}
|
||||
for (int i = 0; i < constantExprsList.size(); i++) {
|
||||
prunedConstantExprsList.add(ImmutableList.of(new Alias(new TinyIntLiteral((byte) 1))));
|
||||
}
|
||||
} else {
|
||||
int len = extractColumnIndex.size();
|
||||
for (List<NamedExpression> row : constantExprsList) {
|
||||
ImmutableList.Builder<NamedExpression> newRow = ImmutableList.builderWithExpectedSize(len);
|
||||
for (int idx : extractColumnIndex) {
|
||||
newRow.add(row.get(idx));
|
||||
}
|
||||
prunedConstantExprsList.add(newRow.build());
|
||||
}
|
||||
prunedConstantExprsList.add(newRow.build());
|
||||
}
|
||||
if (prunedOutputs.equals(originOutput)) {
|
||||
|
||||
if (prunedOutputs.equals(originOutput) && !context.requiredSlots.isEmpty()) {
|
||||
return union;
|
||||
} else {
|
||||
return union.withNewOutputsAndConstExprsList(prunedOutputs, prunedConstantExprsList.build());
|
||||
return union.withNewOutputsChildrenAndConstExprsList(prunedOutputs, children,
|
||||
regularChildrenOutputs, prunedConstantExprsList.build());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.types.DoubleType;
|
||||
@ -313,6 +314,21 @@ public class ColumnPruningTest extends TestWithFeService implements MemoPatternM
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void pruneUnionAllWithCount() {
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze("select count() from (select 1, 2 union all select id, age from student) t")
|
||||
.customRewrite(new ColumnPruning())
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalUnion(
|
||||
logicalProject().when(p -> p.getProjects().size() == 1 && p.getProjects().get(0).child(0) instanceof TinyIntLiteral),
|
||||
logicalProject().when(p -> p.getProjects().size() == 1 && p.getProjects().get(0).child(0) instanceof TinyIntLiteral)
|
||||
)
|
||||
).when(p -> p.getProjects().size() == 1 && p.getProjects().get(0).child(0) instanceof TinyIntLiteral)
|
||||
);
|
||||
}
|
||||
|
||||
private List<String> getOutputQualifiedNames(LogicalProject<? extends Plan> p) {
|
||||
return getOutputQualifiedNames(p.getOutputs());
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
suite("const_expr_column_pruning") {
|
||||
sql """SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'"""
|
||||
// should only keep one column in union
|
||||
sql "select count(1) from(select 3, 6 union all select 1, 3) t"
|
||||
sql "select count(a) from(select 3 a, 6 union all select 1, 3) t"
|
||||
}
|
||||
sql """select count(1) from(select 3, 6 union all select 1, 3) t"""
|
||||
sql """select count(1) from(select 3, 6 union all select "1", 3) t"""
|
||||
sql """select count(a) from(select 3 a, 6 union all select "1", 3) t"""
|
||||
}
|
||||
|
||||
@ -56,5 +56,10 @@ suite("window_column_pruning") {
|
||||
sql "select id from (select id, rank() over() px from window_column_pruning union all select id, rank() over() px from window_column_pruning) a"
|
||||
notContains "rank"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "select count() from (select row_number() over(partition by id) from window_column_pruning) tmp"
|
||||
notContains "row_number"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user