[improve](agg)support push down min/max on unique table (#29242)

This commit is contained in:
zhangstar333
2024-01-02 19:40:23 +08:00
committed by GitHub
parent 3eca457edd
commit af39217d14
5 changed files with 473 additions and 0 deletions

View File

@ -378,6 +378,8 @@ public enum RuleType {
STORAGE_LAYER_AGGREGATE_WITH_PROJECT(RuleTypeClass.IMPLEMENTATION),
STORAGE_LAYER_AGGREGATE_WITHOUT_PROJECT_FOR_FILE_SCAN(RuleTypeClass.IMPLEMENTATION),
STORAGE_LAYER_AGGREGATE_WITH_PROJECT_FOR_FILE_SCAN(RuleTypeClass.IMPLEMENTATION),
STORAGE_LAYER_AGGREGATE_MINMAX_ON_UNIQUE(RuleTypeClass.IMPLEMENTATION),
STORAGE_LAYER_AGGREGATE_MINMAX_ON_UNIQUE_WITHOUT_PROJECT(RuleTypeClass.IMPLEMENTATION),
COUNT_ON_INDEX(RuleTypeClass.IMPLEMENTATION),
COUNT_ON_INDEX_WITHOUT_PROJECT(RuleTypeClass.IMPLEMENTATION),
ONE_PHASE_AGGREGATE_WITHOUT_DISTINCT(RuleTypeClass.IMPLEMENTATION),

View File

@ -40,12 +40,15 @@ import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctSum;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
@ -140,6 +143,72 @@ public class AggregateStrategies implements ImplementationRuleFactory {
return pushdownCountOnIndex(agg, project, filter, olapScan, ctx.cascadesContext);
})
),
// MIN/MAX push-down on a unique-key table without a project:
// matches Aggregate -> Filter -> OlapScan, where the filter holds exactly one conjunct
// whose first operand is a slot bound to the delete-sign column.
RuleType.STORAGE_LAYER_AGGREGATE_MINMAX_ON_UNIQUE_WITHOUT_PROJECT.build(
    logicalAggregate(
        logicalFilter(
            logicalOlapScan().when(this::isUniqueKeyTable))
            .when(filter -> {
                // Any filter other than a single conjunct is rejected.
                if (filter.getConjuncts().size() != 1) {
                    return false;
                }
                Expression conjunct = filter.getConjuncts().iterator().next();
                // A leaf conjunct has no operand to inspect; reject it here instead of
                // letting children().get(0) throw IndexOutOfBoundsException.
                if (conjunct.children().isEmpty()) {
                    return false;
                }
                Expression childExpr = conjunct.children().get(0);
                if (!(childExpr instanceof SlotReference)) {
                    return false;
                }
                Optional<Column> column = ((SlotReference) childExpr).getColumn();
                return column.isPresent() && column.get().isDeleteSignColumn();
            })
    )
        // Gated by the session variable, so the rule is inert unless it is switched on.
        .when(agg -> enablePushDownMinMaxOnUnique())
        // Only aggregates without GROUP BY keys are considered.
        .when(agg -> agg.getGroupByExpressions().isEmpty())
        // Every aggregate function must be MIN or MAX, and there must be at least one.
        .when(agg -> {
            Set<AggregateFunction> funcs = agg.getAggregateFunctions();
            return !funcs.isEmpty() && funcs.stream()
                    .allMatch(f -> (f instanceof Min) || (f instanceof Max));
        })
        .thenApply(ctx -> {
            LogicalAggregate<LogicalFilter<LogicalOlapScan>> agg = ctx.root;
            LogicalFilter<LogicalOlapScan> filter = agg.child();
            LogicalOlapScan olapScan = filter.child();
            // No project in this plan shape, hence the null project argument.
            return pushdownMinMaxOnUniqueTable(agg, null, filter, olapScan,
                    ctx.cascadesContext);
        })
),
// MIN/MAX push-down on a unique-key table with a project:
// matches Aggregate -> Project -> Filter -> OlapScan, where the filter holds exactly one
// conjunct whose first operand is a slot bound to the delete-sign column.
RuleType.STORAGE_LAYER_AGGREGATE_MINMAX_ON_UNIQUE.build(
    logicalAggregate(
        logicalProject(
            logicalFilter(
                logicalOlapScan().when(this::isUniqueKeyTable))
                .when(filter -> {
                    // Any filter other than a single conjunct is rejected.
                    if (filter.getConjuncts().size() != 1) {
                        return false;
                    }
                    Expression conjunct = filter.getConjuncts().iterator().next();
                    // A leaf conjunct has no operand to inspect; reject it here instead
                    // of letting children().get(0) throw IndexOutOfBoundsException.
                    if (conjunct.children().isEmpty()) {
                        return false;
                    }
                    Expression childExpr = conjunct.children().get(0);
                    if (!(childExpr instanceof SlotReference)) {
                        return false;
                    }
                    Optional<Column> column = ((SlotReference) childExpr).getColumn();
                    return column.isPresent() && column.get().isDeleteSignColumn();
                }))
    )
        // Gated by the session variable, so the rule is inert unless it is switched on.
        .when(agg -> enablePushDownMinMaxOnUnique())
        // Only aggregates without GROUP BY keys are considered.
        .when(agg -> agg.getGroupByExpressions().isEmpty())
        // Every aggregate function must be MIN or MAX, and there must be at least one.
        .when(agg -> {
            Set<AggregateFunction> funcs = agg.getAggregateFunctions();
            return !funcs.isEmpty()
                    && funcs.stream().allMatch(f -> (f instanceof Min) || (f instanceof Max));
        })
        .thenApply(ctx -> {
            LogicalAggregate<LogicalProject<LogicalFilter<LogicalOlapScan>>> agg = ctx.root;
            LogicalProject<LogicalFilter<LogicalOlapScan>> project = agg.child();
            LogicalFilter<LogicalOlapScan> filter = project.child();
            LogicalOlapScan olapScan = filter.child();
            return pushdownMinMaxOnUniqueTable(agg, project, filter, olapScan,
                    ctx.cascadesContext);
        })
),
RuleType.STORAGE_LAYER_AGGREGATE_WITHOUT_PROJECT.build(
logicalAggregate(
logicalOlapScan()
@ -238,6 +307,19 @@ public class AggregateStrategies implements ImplementationRuleFactory {
);
}
/**
 * Reports whether the current session has min/max push-down on unique tables switched on.
 *
 * @return false when {@code ConnectContext.get()} yields null, otherwise the value of the
 *         session variable read through {@code isEnablePushDownMinMaxOnUnique()}
 */
private boolean enablePushDownMinMaxOnUnique() {
    ConnectContext ctx = ConnectContext.get();
    if (ctx == null) {
        return false;
    }
    return ctx.getSessionVariable().isEnablePushDownMinMaxOnUnique();
}
/**
 * Tells whether the scanned table uses the {@code UNIQUE_KEYS} keys type.
 *
 * @param logicalScan the scan to inspect; a null scan yields false
 * @return true only when the scan is non-null and its table's keys type is UNIQUE_KEYS
 */
private boolean isUniqueKeyTable(LogicalOlapScan logicalScan) {
    if (logicalScan == null) {
        return false;
    }
    return KeysType.UNIQUE_KEYS == logicalScan.getTable().getKeysType();
}
private boolean enablePushDownCountOnIndex() {
ConnectContext connectContext = ConnectContext.get();
return connectContext != null && connectContext.getSessionVariable().isEnablePushDownCountOnIndex();
@ -314,6 +396,90 @@ public class AggregateStrategies implements ImplementationRuleFactory {
}
}
// Rewrites the plan below a MIN/MAX aggregate on a unique table so that the scan is wrapped
// in a PhysicalStorageLayerAggregate carrying PushDownAggOp.MIN_MAX, e.g. for
//   select /*+SET_VAR(enable_pushdown_minmax_on_unique=true) */min(user_id) from table_unique;
// `project` is null for the Aggregate -> Filter -> Scan shape. When
// checkWhetherPushDownMinMax rejects the aggregate, the input aggregate is returned as-is.
private LogicalAggregate<? extends Plan> pushdownMinMaxOnUniqueTable(
        LogicalAggregate<? extends Plan> aggregate,
        @Nullable LogicalProject<? extends Plan> project,
        LogicalFilter<? extends Plan> filter,
        LogicalOlapScan olapScan,
        CascadesContext cascadesContext) {
    if (!checkWhetherPushDownMinMax(aggregate.getAggregateFunctions(), project, olapScan.getOutput())) {
        return aggregate;
    }

    // Convert the logical scan with the regular logical-to-physical scan rule and take
    // its first result.
    PhysicalOlapScan physicalScan = (PhysicalOlapScan) new LogicalOlapScanToPhysicalOlapScan()
            .build()
            .transform(olapScan, cascadesContext)
            .get(0);
    PhysicalStorageLayerAggregate storageLayerAgg =
            new PhysicalStorageLayerAggregate(physicalScan, PushDownAggOp.MIN_MAX);

    // Rebuild the original chain bottom-up on top of the new storage-layer aggregate.
    Plan rewrittenFilter = filter.withChildren(ImmutableList.of(storageLayerAgg));
    Plan aggChild = project == null
            ? rewrittenFilter
            : project.withChildren(ImmutableList.of(rewrittenFilter));
    return aggregate.withChildren(ImmutableList.of(aggChild));
}
/**
 * Decides whether the given aggregate functions may be pushed down as MIN/MAX.
 *
 * <p>Push-down is allowed only when all of the following hold:
 * <ul>
 *   <li>every argument of every aggregate function is a {@link SlotReference};</li>
 *   <li>if a project is present, each argument resolved through it (with an alias
 *       replaced by its child) is still a {@link SlotReference};</li>
 *   <li>every slot found in {@code outPutSlots} for those arguments carries a column,
 *       and that column's primitive type is not complex, HLL or bitmap.</li>
 * </ul>
 *
 * @param aggregateFunctions aggregate functions of the aggregate node
 * @param project project between the aggregate and the filter, or null if there is none
 * @param outPutSlots output slots of the scan
 * @return true if the aggregate functions can be pushed down, false otherwise
 */
private boolean checkWhetherPushDownMinMax(Set<AggregateFunction> aggregateFunctions,
        @Nullable LogicalProject<? extends Plan> project, List<Slot> outPutSlots) {
    List<Expression> argumentsOfAggregateFunction = aggregateFunctions.stream()
            .flatMap(aggregateFunction -> aggregateFunction.getArguments().stream())
            .collect(ImmutableList.toImmutableList());
    if (!argumentsOfAggregateFunction.stream().allMatch(SlotReference.class::isInstance)) {
        return false;
    }

    if (project != null) {
        // Resolve each argument through the project; an alias is replaced by its child.
        argumentsOfAggregateFunction = Project.findProject(
                        argumentsOfAggregateFunction, project.getProjects())
                .stream()
                .map(p -> p instanceof Alias ? p.child(0) : p)
                .collect(ImmutableList.toImmutableList());
        // The resolved arguments must still be plain slots. Without a project the list is
        // unchanged, so the check above already covers it.
        if (!argumentsOfAggregateFunction.stream().allMatch(SlotReference.class::isInstance)) {
            return false;
        }
    }

    Set<SlotReference> aggUsedSlots = ExpressionUtils.collect(argumentsOfAggregateFunction,
            SlotReference.class::isInstance);

    List<SlotReference> usedSlotInTable = (List<SlotReference>) Project.findProject(aggUsedSlots,
            outPutSlots);

    for (SlotReference slot : usedSlotInTable) {
        Optional<Column> column = slot.getColumn();
        // Without a column the type cannot be checked, so refuse the push-down
        // instead of failing on Optional.get().
        if (!column.isPresent()) {
            return false;
        }
        PrimitiveType colType = column.get().getType().getPrimitiveType();
        if (colType.isComplexType() || colType.isHllType() || colType.isBitmapType()) {
            return false;
        }
    }
    return true;
}
/**
* sql: select count(*) from tbl
* <p>

View File

@ -480,6 +480,8 @@ public class SessionVariable implements Serializable, Writable {
public static final String MATERIALIZED_VIEW_REWRITE_ENABLE_CONTAIN_FOREIGN_TABLE
= "materialized_view_rewrite_enable_contain_foreign_table";
// Name of the session variable that controls min/max pushdown on unique tables.
public static final String ENABLE_PUSHDOWN_MINMAX_ON_UNIQUE = "enable_pushdown_minmax_on_unique";
// When use fix replica is enabled, the fixed replica may be bad; if this session variable
// is set to true, try to use a healthy replica instead.
public static final String FALLBACK_OTHER_REPLICA_WHEN_FIXED_CORRUPT = "fallback_other_replica_when_fixed_corrupt";
@ -1221,6 +1223,11 @@ public class SessionVariable implements Serializable, Writable {
"是否启用count_on_index pushdown。", "Set whether to pushdown count_on_index."})
public boolean enablePushDownCountOnIndex = true;
// Whether to push min/max aggregation down to the scan node of a unique table.
// Initialized to false, so the pushdown stays off unless the variable is set.
@VariableMgr.VarAttr(name = ENABLE_PUSHDOWN_MINMAX_ON_UNIQUE, needForward = true, description = {
"是否启用pushdown minmax on unique table。", "Set whether to pushdown minmax on unique table."})
public boolean enablePushDownMinMaxOnUnique = false;
// Whether drop table when create table as select insert data appear error.
@VariableMgr.VarAttr(name = DROP_TABLE_IF_CTAS_FAILED, needForward = true)
public boolean dropTableIfCtasFailed = true;
@ -2438,6 +2445,14 @@ public class SessionVariable implements Serializable, Writable {
this.disableJoinReorder = disableJoinReorder;
}
/**
 * Returns the current value of the {@code enablePushDownMinMaxOnUnique} field.
 *
 * @return the field value as stored in this object
 */
public boolean isEnablePushDownMinMaxOnUnique() {
return enablePushDownMinMaxOnUnique;
}
/**
 * Stores a new value in the {@code enablePushDownMinMaxOnUnique} field.
 *
 * @param enablePushDownMinMaxOnUnique the value to store
 */
public void setEnablePushDownMinMaxOnUnique(boolean enablePushDownMinMaxOnUnique) {
this.enablePushDownMinMaxOnUnique = enablePushDownMinMaxOnUnique;
}
/**
* Nereids only support vectorized engine.
*