[opt](Nereids) use 1 as narrowest column when do column pruning on union (#41719) (#41975)

pick from master #41719

just like previous PR #41548

This PR processes the union node to ensure that it does not require any
columns from its children when its parent requires an empty slot set.
This commit is contained in:
morrySnow
2024-10-17 15:28:27 +08:00
committed by GitHub
parent b4875c2789
commit 3fcd64366f
4 changed files with 57 additions and 21 deletions

View File

@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
@ -41,6 +42,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.trees.plans.logical.OutputPrunable;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
@ -314,6 +316,8 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
}
List<NamedExpression> prunedOutputs = Lists.newArrayList();
List<List<NamedExpression>> constantExprsList = union.getConstantExprsList();
List<List<SlotReference>> regularChildrenOutputs = union.getRegularChildrenOutputs();
List<Plan> children = union.children();
List<Integer> extractColumnIndex = Lists.newArrayList();
for (int i = 0; i < originOutput.size(); i++) {
NamedExpression output = originOutput.get(i);
@ -322,31 +326,41 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
extractColumnIndex.add(i);
}
}
if (prunedOutputs.isEmpty()) {
List<NamedExpression> candidates = Lists.newArrayList(originOutput);
candidates.retainAll(keys);
if (candidates.isEmpty()) {
candidates = originOutput;
}
NamedExpression minimumColumn = ExpressionUtils.selectMinimumColumn(candidates);
prunedOutputs = ImmutableList.of(minimumColumn);
extractColumnIndex.add(originOutput.indexOf(minimumColumn));
}
int len = extractColumnIndex.size();
ImmutableList.Builder<List<NamedExpression>> prunedConstantExprsList
= ImmutableList.builderWithExpectedSize(constantExprsList.size());
for (List<NamedExpression> row : constantExprsList) {
ImmutableList.Builder<NamedExpression> newRow = ImmutableList.builderWithExpectedSize(len);
for (int idx : extractColumnIndex) {
newRow.add(row.get(idx));
if (prunedOutputs.isEmpty()) {
// process prune all columns
NamedExpression originSlot = originOutput.get(0);
prunedOutputs = ImmutableList.of(new SlotReference(originSlot.getExprId(), originSlot.getName(),
TinyIntType.INSTANCE, false, originSlot.getQualifier()));
regularChildrenOutputs = Lists.newArrayListWithCapacity(regularChildrenOutputs.size());
children = Lists.newArrayListWithCapacity(children.size());
for (int i = 0; i < union.getArity(); i++) {
LogicalProject<?> project = new LogicalProject<>(
ImmutableList.of(new Alias(new TinyIntLiteral((byte) 1))), union.child(i));
regularChildrenOutputs.add((List) project.getOutput());
children.add(project);
}
for (int i = 0; i < constantExprsList.size(); i++) {
prunedConstantExprsList.add(ImmutableList.of(new Alias(new TinyIntLiteral((byte) 1))));
}
} else {
int len = extractColumnIndex.size();
for (List<NamedExpression> row : constantExprsList) {
ImmutableList.Builder<NamedExpression> newRow = ImmutableList.builderWithExpectedSize(len);
for (int idx : extractColumnIndex) {
newRow.add(row.get(idx));
}
prunedConstantExprsList.add(newRow.build());
}
prunedConstantExprsList.add(newRow.build());
}
if (prunedOutputs.equals(originOutput)) {
if (prunedOutputs.equals(originOutput) && !context.requiredSlots.isEmpty()) {
return union;
} else {
return union.withNewOutputsAndConstExprsList(prunedOutputs, prunedConstantExprsList.build());
return union.withNewOutputsChildrenAndConstExprsList(prunedOutputs, children,
regularChildrenOutputs, prunedConstantExprsList.build());
}
}

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.types.DoubleType;
@ -313,6 +314,21 @@ public class ColumnPruningTest extends TestWithFeService implements MemoPatternM
);
}
/**
 * Checks the plan shape produced by {@link ColumnPruning} for an argument-less
 * {@code count()} over a {@code UNION ALL} subquery.
 *
 * <p>The expected pattern is a project on top of a union with two children, where the
 * outer project and both union children are projects that each have exactly one project
 * expression whose first child is a {@link TinyIntLiteral}.
 */
@Test
public void pruneUnionAllWithCount() {
PlanChecker.from(connectContext)
// count() has no argument; presumably the aggregate therefore needs no column of subquery t
.analyze("select count() from (select 1, 2 union all select id, age from student) t")
// run only the ColumnPruning rewriter on the analyzed plan
.customRewrite(new ColumnPruning())
.matches(
logicalProject(
logicalUnion(
// each union child: a project with exactly one expression wrapping a TinyIntLiteral
logicalProject().when(p -> p.getProjects().size() == 1 && p.getProjects().get(0).child(0) instanceof TinyIntLiteral),
logicalProject().when(p -> p.getProjects().size() == 1 && p.getProjects().get(0).child(0) instanceof TinyIntLiteral)
)
// the project above the union must satisfy the same single-TinyIntLiteral predicate
).when(p -> p.getProjects().size() == 1 && p.getProjects().get(0).child(0) instanceof TinyIntLiteral)
);
}
private List<String> getOutputQualifiedNames(LogicalProject<? extends Plan> p) {
return getOutputQualifiedNames(p.getOutputs());
}

View File

@ -18,6 +18,7 @@
suite("const_expr_column_pruning") {
sql """SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'"""
// should only keep one column in union
sql "select count(1) from(select 3, 6 union all select 1, 3) t"
sql "select count(a) from(select 3 a, 6 union all select 1, 3) t"
}
sql """select count(1) from(select 3, 6 union all select 1, 3) t"""
sql """select count(1) from(select 3, 6 union all select "1", 3) t"""
sql """select count(a) from(select 3 a, 6 union all select "1", 3) t"""
}

View File

@ -56,5 +56,10 @@ suite("window_column_pruning") {
sql "select id from (select id, rank() over() px from window_column_pruning union all select id, rank() over() px from window_column_pruning) a"
notContains "rank"
}
explain {
sql "select count() from (select row_number() over(partition by id) from window_column_pruning) tmp"
notContains "row_number"
}
}