prune for agg with constant expr (#12274)

Currently, nereids doesn't support aggregate function with no slot reference in query, since all the column would be pruned, e.g.

SELECT COUNT(1) FROM t;

This PR reserve the column with the smallest amount of data when doing column prune under this situation.

To be noticed, this PR ONLY handle aggregate functions. So projection with no slot reference need to be handled in future.
This commit is contained in:
Kikyou1997
2022-09-05 19:09:00 +08:00
committed by GitHub
parent 8bfb89c100
commit dadfd85c40
25 changed files with 314 additions and 18 deletions

View File

@ -347,8 +347,4 @@ public class SlotDescriptor {
return parent.getTable() instanceof OlapTable;
}
public void setMaterialized(boolean materialized) {
isMaterialized = materialized;
}
}

View File

@ -21,6 +21,7 @@ import org.apache.doris.analysis.AggregateInfo;
import org.apache.doris.analysis.BaseTableRef;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.FunctionCallExpr;
import org.apache.doris.analysis.SlotDescriptor;
import org.apache.doris.analysis.SlotId;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.SortInfo;
@ -31,7 +32,6 @@ import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.Table;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
@ -354,8 +354,12 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
.map(e -> ExpressionTranslator.translate(e, context))
.collect(Collectors.toList());
TupleDescriptor leftTuple = context.getTupleDesc(leftPlanRoot);
TupleDescriptor rightTuple = context.getTupleDesc(rightPlanRoot);
TupleDescriptor leftChildOutputTupleDesc = leftPlanRoot.getOutputTupleDesc();
TupleDescriptor leftTuple =
leftChildOutputTupleDesc != null ? leftChildOutputTupleDesc : context.getTupleDesc(leftPlanRoot);
TupleDescriptor rightChildOutputTupleDesc = rightPlanRoot.getOutputTupleDesc();
TupleDescriptor rightTuple =
rightChildOutputTupleDesc != null ? rightChildOutputTupleDesc : context.getTupleDesc(rightPlanRoot);
// Nereids does not care about output order of join,
// but BE need left child's output must be before right child's output.
@ -414,12 +418,6 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
public PlanFragment visitPhysicalProject(PhysicalProject<? extends Plan> project, PlanTranslatorContext context) {
PlanFragment inputFragment = project.child(0).accept(this, context);
// TODO: handle p.child(0) is not NamedExpression.
project.getProjects().stream().filter(Alias.class::isInstance).forEach(p -> {
SlotRef ref = context.findSlotRef(((NamedExpression) p.child(0)).getExprId());
context.addExprIdSlotRefPair(p.getExprId(), ref);
});
List<Expr> execExprList = project.getProjects()
.stream()
.map(e -> ExpressionTranslator.translate(e, context))
@ -464,11 +462,17 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
Set<Integer> slotIdSet = slotRefSet.stream()
.map(SlotRef::getSlotId).map(SlotId::asInt).collect(Collectors.toSet());
slotIdSet.addAll(requiredSlotIdList);
execPlan.getTupleIds().stream()
boolean noneMaterialized = execPlan.getTupleIds().stream()
.map(context::getTupleDesc)
.map(TupleDescriptor::getSlots)
.flatMap(List::stream)
.forEach(s -> s.setIsMaterialized(slotIdSet.contains(s.getId().asInt())));
.peek(s -> s.setIsMaterialized(slotIdSet.contains(s.getId().asInt())))
.filter(SlotDescriptor::isMaterialized)
.count() == 0;
if (noneMaterialized) {
context.getDescTable()
.getTupleDesc(execPlan.getTupleIds().get(0)).getSlots().get(0).setIsMaterialized(true);
}
}
@Override

View File

@ -60,8 +60,9 @@ public class RewriteJob extends BatchRulesJob {
.add(topDownBatch(ImmutableList.of(new ReorderJoin())))
.add(topDownBatch(ImmutableList.of(new FindHashConditionForJoin())))
.add(topDownBatch(ImmutableList.of(new PushPredicateThroughJoin())))
.add(topDownBatch(ImmutableList.of(new AggregateDisassemble())))
.add(topDownBatch(ImmutableList.of(new NormalizeAggregate())))
.add(topDownBatch(ImmutableList.of(new ColumnPruning())))
.add(topDownBatch(ImmutableList.of(new AggregateDisassemble())))
.add(topDownBatch(ImmutableList.of(new SwapFilterAndProject())))
.add(bottomUpBatch(ImmutableList.of(new MergeConsecutiveProjects())))
.add(topDownBatch(ImmutableList.of(new MergeConsecutiveFilters())))

View File

@ -101,7 +101,7 @@ public class BindFunction implements AnalysisRuleFactory {
if (arguments.size() > 1 || (arguments.size() == 0 && !unboundFunction.isStar())) {
return unboundFunction;
}
if (unboundFunction.isStar()) {
if (unboundFunction.isStar() || arguments.stream().allMatch(Expression::isConstant)) {
return new Count();
}
return new Count(unboundFunction.getArguments().get(0));

View File

@ -23,6 +23,9 @@ import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.SlotExtractor;
@ -54,10 +57,18 @@ public class PruneAggChildColumns extends OneRewriteRuleFactory {
@Override
public Rule build() {
return RuleType.COLUMN_PRUNE_AGGREGATION_CHILD.build(logicalAggregate().then(agg -> {
List<Slot> childOutput = agg.child().getOutput();
if (isAggregateWithConstant(agg)) {
Slot slot = selectMinimumColumn(childOutput);
if (childOutput.size() == 1 && childOutput.get(0).equals(slot)) {
return agg;
}
return agg.withChildren(ImmutableList.of(new LogicalProject<>(ImmutableList.of(slot), agg.child())));
}
List<Expression> slots = Lists.newArrayList();
slots.addAll(agg.getExpressions());
Set<Slot> outputs = SlotExtractor.extractSlot(slots);
List<NamedExpression> prunedOutputs = agg.child().getOutput().stream().filter(outputs::contains)
List<NamedExpression> prunedOutputs = childOutput.stream().filter(outputs::contains)
.collect(Collectors.toList());
if (prunedOutputs.size() == agg.child().getOutput().size()) {
return agg;
@ -65,4 +76,32 @@ public class PruneAggChildColumns extends OneRewriteRuleFactory {
return agg.withChildren(ImmutableList.of(new LogicalProject<>(prunedOutputs, agg.child())));
}));
}
/**
* For these aggregate function with constant param. Such as:
* count(*), count(1), sum(1)..etc.
* @return null, if there exists an aggregation function that its parameters contains non-constant expr.
* else return a slot with min data type.
*/
private boolean isAggregateWithConstant(LogicalAggregate<GroupPlan> agg) {
for (NamedExpression output : agg.getOutputExpressions()) {
if (output.anyMatch(SlotReference.class::isInstance)) {
return false;
}
}
return true;
}
private Slot selectMinimumColumn(List<Slot> outputList) {
Slot minSlot = null;
for (Slot slot : outputList) {
if (minSlot == null) {
minSlot = slot;
} else {
int slotDataTypeWidth = slot.getDataType().width();
minSlot = minSlot.getDataType().width() > slotDataTypeWidth ? slot : minSlot;
}
}
return minSlot;
}
}

View File

@ -44,6 +44,7 @@ import java.util.stream.Stream;
* |
* scan(k1,k2,k3,v1)
* transformed:
*  project(k1)
* |
* filter(k2 > 3)
* |

View File

@ -24,8 +24,11 @@ import org.apache.doris.nereids.types.coercion.IntegralType;
* BigInt data type in Nereids.
*/
public class BigIntType extends IntegralType {
public static BigIntType INSTANCE = new BigIntType();
private static final int WIDTH = 8;
private BigIntType() {
}
@ -48,4 +51,9 @@ public class BigIntType extends IntegralType {
public DataType defaultConcreteType() {
return this;
}
@Override
public int width() {
return WIDTH;
}
}

View File

@ -26,6 +26,8 @@ import org.apache.doris.nereids.types.coercion.PrimitiveType;
public class BooleanType extends PrimitiveType {
public static BooleanType INSTANCE = new BooleanType();
private static int WIDTH = 1;
private BooleanType() {
}
@ -38,4 +40,9 @@ public class BooleanType extends PrimitiveType {
public String simpleString() {
return "boolean";
}
@Override
public int width() {
return WIDTH;
}
}

View File

@ -227,4 +227,6 @@ public abstract class DataType implements AbstractDataType {
return this;
}
}
public abstract int width();
}

View File

@ -27,6 +27,8 @@ public class DateTimeType extends PrimitiveType {
public static DateTimeType INSTANCE = new DateTimeType();
private static final int WIDTH = 16;
private DateTimeType() {
}
@ -39,4 +41,9 @@ public class DateTimeType extends PrimitiveType {
public boolean equals(Object o) {
return o instanceof DateTimeType;
}
@Override
public int width() {
return WIDTH;
}
}

View File

@ -27,6 +27,8 @@ public class DateType extends PrimitiveType {
public static DateType INSTANCE = new DateType();
private static final int WIDTH = 16;
private DateType() {
}
@ -34,5 +36,10 @@ public class DateType extends PrimitiveType {
public Type toCatalogDataType() {
return Type.DATE;
}
@Override
public int width() {
return WIDTH;
}
}

View File

@ -46,6 +46,8 @@ public class DecimalType extends FractionalType {
private static final DecimalType FLOAT_DECIMAL = new DecimalType(14, 7);
private static final DecimalType DOUBLE_DECIMAL = new DecimalType(30, 15);
private static final int WIDTH = 16;
private static final Map<DataType, DecimalType> FOR_TYPE_MAP = ImmutableMap.<DataType, DecimalType>builder()
.put(TinyIntType.INSTANCE, TINYINT_DECIMAL)
.put(SmallIntType.INSTANCE, SMALLINT_DECIMAL)
@ -160,5 +162,11 @@ public class DecimalType extends FractionalType {
public int hashCode() {
return Objects.hash(super.hashCode(), precision, scale);
}
@Override
public int width() {
return WIDTH;
}
}

View File

@ -26,6 +26,8 @@ import org.apache.doris.nereids.types.coercion.FractionalType;
public class DoubleType extends FractionalType {
public static DoubleType INSTANCE = new DoubleType();
private static final int WIDTH = 8;
private DoubleType() {
}
@ -53,4 +55,9 @@ public class DoubleType extends FractionalType {
public DataType defaultConcreteType() {
return this;
}
@Override
public int width() {
return WIDTH;
}
}

View File

@ -26,6 +26,8 @@ import org.apache.doris.nereids.types.coercion.FractionalType;
public class FloatType extends FractionalType {
public static FloatType INSTANCE = new FloatType();
private static final int WIDTH = 4;
private FloatType() {
}
@ -53,4 +55,9 @@ public class FloatType extends FractionalType {
public String simpleString() {
return "float";
}
@Override
public int width() {
return WIDTH;
}
}

View File

@ -26,6 +26,8 @@ import org.apache.doris.nereids.types.coercion.IntegralType;
public class IntegerType extends IntegralType {
public static IntegerType INSTANCE = new IntegerType();
private static final int WIDTH = 4;
private IntegerType() {
}
@ -53,4 +55,9 @@ public class IntegerType extends IntegralType {
public DataType defaultConcreteType() {
return this;
}
@Override
public int width() {
return WIDTH;
}
}

View File

@ -26,6 +26,8 @@ import org.apache.doris.nereids.types.coercion.IntegralType;
public class LargeIntType extends IntegralType {
public static LargeIntType INSTANCE = new LargeIntType();
private static final int WIDTH = 16;
private LargeIntType() {
}
@ -53,4 +55,9 @@ public class LargeIntType extends IntegralType {
public DataType defaultConcreteType() {
return this;
}
@Override
public int width() {
return WIDTH;
}
}

View File

@ -26,6 +26,8 @@ import org.apache.doris.nereids.types.coercion.PrimitiveType;
public class NullType extends PrimitiveType {
public static NullType INSTANCE = new NullType();
private static final int WIDTH = 1;
private NullType() {
}
@ -33,4 +35,9 @@ public class NullType extends PrimitiveType {
public Type toCatalogDataType() {
return Type.NULL;
}
@Override
public int width() {
return WIDTH;
}
}

View File

@ -26,6 +26,8 @@ import org.apache.doris.nereids.types.coercion.IntegralType;
public class SmallIntType extends IntegralType {
public static SmallIntType INSTANCE = new SmallIntType();
private static final int WIDTH = 2;
private SmallIntType() {
}
@ -53,4 +55,9 @@ public class SmallIntType extends IntegralType {
public DataType defaultConcreteType() {
return this;
}
@Override
public int width() {
return WIDTH;
}
}

View File

@ -26,6 +26,8 @@ import org.apache.doris.nereids.types.coercion.IntegralType;
public class TinyIntType extends IntegralType {
public static TinyIntType INSTANCE = new TinyIntType();
private static final int WIDTH = 1;
private TinyIntType() {
}
@ -53,4 +55,9 @@ public class TinyIntType extends IntegralType {
public DataType defaultConcreteType() {
return this;
}
@Override
public int width() {
return WIDTH;
}
}

View File

@ -28,6 +28,8 @@ public class CharacterType extends PrimitiveType {
public static final CharacterType INSTANCE = new CharacterType(-1);
private static final int WIDTH = 16;
protected final int len;
public CharacterType(int len) {
@ -52,4 +54,9 @@ public class CharacterType extends PrimitiveType {
public DataType defaultConcreteType() {
return StringType.INSTANCE;
}
@Override
public int width() {
return WIDTH;
}
}

View File

@ -40,4 +40,9 @@ public class FractionalType extends NumericType {
public String simpleString() {
return "fractional";
}
@Override
public int width() {
throw new RuntimeException("Unimplemented exception");
}
}

View File

@ -29,4 +29,9 @@ public abstract class PrimitiveType extends DataType {
public String toSql() {
return simpleString().toUpperCase(Locale.ROOT);
}
@Override
public int width() {
throw new RuntimeException("Unimplemented exception");
}
}

View File

@ -20,6 +20,8 @@ package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.PatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.utframe.TestWithFeService;
@ -177,6 +179,96 @@ public class ColumnPruningTest extends TestWithFeService implements PatternMatch
);
}
@Test
public void pruneCountStarStmt() {
PlanChecker.from(connectContext)
.analyze("SELECT COUNT(*) FROM test.course")
.applyTopDown(new ColumnPruning())
.matchesFromRoot(
logicalAggregate(
logicalProject(
logicalOlapScan()
).when(p -> p.getProjects().get(0).getDataType().equals(IntegerType.INSTANCE)
&& p.getProjects().size() == 1)
)
);
}
@Test
public void pruneCountConstantStmt() {
PlanChecker.from(connectContext)
.analyze("SELECT COUNT(1) FROM test.course")
.applyTopDown(new ColumnPruning())
.matchesFromRoot(
logicalAggregate(
logicalProject(
logicalOlapScan()
).when(p -> p.getProjects().get(0).getDataType().equals(IntegerType.INSTANCE)
&& p.getProjects().size() == 1)
)
);
}
@Test
public void pruneCountConstantAndSumConstantStmt() {
PlanChecker.from(connectContext)
.analyze("SELECT COUNT(1), SUM(2) FROM test.course")
.applyTopDown(new ColumnPruning())
.matchesFromRoot(
logicalAggregate(
logicalProject(
logicalOlapScan()
).when(p -> p.getProjects().get(0).getDataType().equals(IntegerType.INSTANCE)
&& p.getProjects().size() == 1)
)
);
}
@Test
public void pruneCountStarAndSumConstantStmt() {
PlanChecker.from(connectContext)
.analyze("SELECT COUNT(*), SUM(2) FROM test.course")
.applyTopDown(new ColumnPruning())
.matchesFromRoot(
logicalAggregate(
logicalProject(
logicalOlapScan()
).when(p -> p.getProjects().get(0).getDataType().equals(IntegerType.INSTANCE)
&& p.getProjects().size() == 1)
)
);
}
@Test
public void pruneCountStarAndSumColumnStmt() {
PlanChecker.from(connectContext)
.analyze("SELECT COUNT(*), SUM(grade) FROM test.score")
.applyTopDown(new ColumnPruning())
.matchesFromRoot(
logicalAggregate(
logicalProject(
logicalOlapScan()
).when(p -> p.getProjects().get(0).getDataType().equals(DoubleType.INSTANCE)
&& p.getProjects().size() == 1)
)
);
}
@Test
public void pruneCountStarAndSumColumnAndSumConstantStmt() {
PlanChecker.from(connectContext)
.analyze("SELECT COUNT(*), SUM(grade) + SUM(2) FROM test.score")
.applyTopDown(new ColumnPruning())
.matchesFromRoot(
logicalAggregate(
logicalProject(
logicalOlapScan()
).when(p -> p.getProjects().get(0).getDataType().equals(DoubleType.INSTANCE)
&& p.getProjects().size() == 1)
)
);
}
private List<String> getOutputQualifiedNames(LogicalProject<? extends Plan> p) {
return p.getProjects().stream().map(NamedExpression::getQualifiedName).collect(Collectors.toList());
}

View File

@ -0,0 +1,10 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !select --
2 1996
-- !select --
1
-- !select --
2 4

View File

@ -0,0 +1,48 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
suite("agg_with_const") {
sql """
DROP TABLE IF EXISTS t1
"""
sql """CREATE TABLE t1 (col1 int not null, col2 int not null, col3 int not null)
DISTRIBUTED BY HASH(col3)
BUCKETS 1
PROPERTIES(
"replication_num"="1"
)
"""
sql """
insert into t1 values(1994, 1994, 1995)
"""
qt_select """
select count(2) + 1, sum(2) + sum(col1) from t1
"""
qt_select """
select count(*) from t1
"""
qt_select """
select count(2) + 1, sum(2) + sum(2) from t1
"""
}