diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java index 284ac52e14..cf94caa25c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java @@ -80,8 +80,41 @@ import java.util.stream.IntStream; * plan */ public class ColumnPruning extends DefaultPlanRewriter implements CustomRewriter { + private Set keys; + + /** + * collect all columns used in expressions, which should not be pruned + */ + public static class KeyColumnCollector + extends DefaultPlanRewriter implements CustomRewriter { + public Set keys = Sets.newHashSet(); + + @Override + public Plan rewriteRoot(Plan plan, JobContext jobContext) { + return plan.accept(this, jobContext); + } + + @Override + public Plan visit(Plan plan, JobContext jobContext) { + for (Plan child : plan.children()) { + child.accept(this, jobContext); + } + plan.getExpressions().stream().filter( + expression -> !(expression instanceof SlotReference) + ).forEach( + expression -> { + keys.addAll(expression.getInputSlots()); + } + ); + return plan; + } + } + @Override public Plan rewriteRoot(Plan plan, JobContext jobContext) { + KeyColumnCollector keyColumnCollector = new KeyColumnCollector(); + plan.accept(keyColumnCollector, jobContext); + keys = keyColumnCollector.keys; return plan.accept(this, new PruneContext(plan.getOutputSet(), null)); } @@ -244,7 +277,7 @@ public class ColumnPruning extends DefaultPlanRewriter implements } /** prune output */ - public static

P pruneOutput(P plan, List originOutput, + public

P pruneOutput(P plan, List originOutput, Function, P> withPrunedOutput, PruneContext context) { if (originOutput.isEmpty()) { return plan; @@ -254,7 +287,12 @@ public class ColumnPruning extends DefaultPlanRewriter implements .collect(ImmutableList.toImmutableList()); if (prunedOutputs.isEmpty()) { - NamedExpression minimumColumn = ExpressionUtils.selectMinimumColumn(originOutput); + List candidates = Lists.newArrayList(originOutput); + candidates.retainAll(keys); + if (candidates.isEmpty()) { + candidates = originOutput; + } + NamedExpression minimumColumn = ExpressionUtils.selectMinimumColumn(candidates); prunedOutputs = ImmutableList.of(minimumColumn); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/TPCHTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/TPCHTest.java index ef2db355b4..3e88ac73bd 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/TPCHTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/TPCHTest.java @@ -19,11 +19,13 @@ package org.apache.doris.nereids.jobs.joinorder; import org.apache.doris.nereids.datasets.tpch.TPCHTestBase; import org.apache.doris.nereids.datasets.tpch.TPCHUtils; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.PlanChecker; import org.junit.jupiter.api.Test; -public class TPCHTest extends TPCHTestBase { +public class TPCHTest extends TPCHTestBase implements MemoPatternMatchSupported { @Test void testQ5() { PlanChecker.from(connectContext) @@ -33,4 +35,40 @@ public class TPCHTest extends TPCHTestBase { .dpHypOptimize() .printlnBestPlanTree(); } + + + // count(*) projects on children key columns + @Test + void testCountStarProject() { + String sql = "select\n" + + " count(*) as order_count\n" + + "from\n" + + " orders\n" + + "where\n" + + " o_orderdate >= date '1993-07-01'\n" + + " and o_orderdate < date '1993-07-01' + interval '3' month\n" + + " and exists (\n" + + " select\n" + + " *\n" + + " from\n" + + " lineitem\n" + + " where\n" + + " l_orderkey = o_orderkey\n" + + " and l_commitdate < l_receiptdate\n" + + " );"; + + // o_orderstatus is smaller than o_orderdate, but o_orderstatus is not used in this sql + // it is better to choose the column which is already used to represent count(*) + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches( + logicalResultSink( + logicalAggregate( + logicalProject().when( + project -> project.getProjects().size() == 1 + && project.getProjects().get(0) instanceof SlotReference + && "o_orderdate".equals(project.getProjects().get(0).toSql())))) + ); + } }