[enhancement](Nereids) support 4 phases distinct aggregate with full distribution (#36016)
cherry pick from #35871
This commit is contained in:
@ -113,16 +113,6 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Boolean, Void> {
|
||||
// this means one stage gather agg, usually bad pattern
|
||||
return false;
|
||||
}
|
||||
// forbid three- or four-stage distinct agg that is split by a distribute
|
||||
if (agg.getAggMode() == AggMode.BUFFER_TO_BUFFER && children.get(0).getPlan() instanceof PhysicalDistribute) {
|
||||
// if distinct without group by key, we prefer three or four stage distinct agg
|
||||
// because the second phase of multi-distinct has only one instance, which is generally slow.
|
||||
if (agg.getGroupByExpressions().size() == 1
|
||||
&& agg.getOutputExpressions().size() == 1) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// forbid TWO_PHASE_AGGREGATE_WITH_DISTINCT after shuffle
|
||||
// TODO: this mistakenly forbids a good plan after CTE reuse
|
||||
|
||||
@ -443,6 +443,7 @@ public enum RuleType {
|
||||
TWO_PHASE_AGGREGATE_WITH_MULTI_DISTINCT(RuleTypeClass.IMPLEMENTATION),
|
||||
THREE_PHASE_AGGREGATE_WITH_DISTINCT(RuleTypeClass.IMPLEMENTATION),
|
||||
FOUR_PHASE_AGGREGATE_WITH_DISTINCT(RuleTypeClass.IMPLEMENTATION),
|
||||
FOUR_PHASE_AGGREGATE_WITH_DISTINCT_WITH_FULL_DISTRIBUTE(RuleTypeClass.IMPLEMENTATION),
|
||||
LOGICAL_UNION_TO_PHYSICAL_UNION(RuleTypeClass.IMPLEMENTATION),
|
||||
LOGICAL_EXCEPT_TO_PHYSICAL_EXCEPT(RuleTypeClass.IMPLEMENTATION),
|
||||
LOGICAL_INTERSECT_TO_PHYSICAL_INTERSECT(RuleTypeClass.IMPLEMENTATION),
|
||||
|
||||
@ -75,6 +75,7 @@ import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.TypeCoercionUtils;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.base.Function;
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
@ -293,15 +294,89 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
// .thenApplyMulti(ctx -> twoPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
|
||||
// ),
|
||||
RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build(
|
||||
basePattern
|
||||
.when(agg -> agg.getDistinctArguments().size() == 1)
|
||||
.thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
|
||||
basePattern
|
||||
.when(agg -> agg.getDistinctArguments().size() == 1)
|
||||
.thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
|
||||
),
|
||||
/*
|
||||
* sql:
|
||||
* select count(distinct name), sum(age) from student;
|
||||
* <p>
|
||||
* 4 phase plan
|
||||
* DISTINCT_GLOBAL(BUFFER_TO_RESULT, groupBy(),
|
||||
* output[count(partial_count(name)), sum(partial_sum(partial_sum(age)))],
|
||||
* GATHER)
|
||||
* +--DISTINCT_LOCAL(INPUT_TO_BUFFER, groupBy(),
|
||||
* output(partial_count(name), partial_sum(partial_sum(age))),
|
||||
* hash distribute by name)
|
||||
* +--GLOBAL(BUFFER_TO_BUFFER, groupBy(name),
|
||||
* output(name, partial_sum(age)),
|
||||
* hash_distribute by name)
|
||||
* +--LOCAL(INPUT_TO_BUFFER, groupBy(name), output(name, partial_sum(age)))
|
||||
* +--scan(name, age)
|
||||
*/
|
||||
RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT.build(
|
||||
basePattern
|
||||
.when(agg -> agg.getDistinctArguments().size() == 1)
|
||||
.when(agg -> agg.getGroupByExpressions().isEmpty())
|
||||
.thenApplyMulti(ctx -> fourPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
|
||||
basePattern
|
||||
.when(agg -> agg.getDistinctArguments().size() == 1)
|
||||
.when(agg -> agg.getGroupByExpressions().isEmpty())
|
||||
.thenApplyMulti(ctx -> {
|
||||
Function<List<Expression>, RequireProperties> secondPhaseRequireDistinctHash =
|
||||
groupByAndDistinct -> RequireProperties.of(
|
||||
PhysicalProperties.createHash(
|
||||
ctx.root.getDistinctArguments(), ShuffleType.REQUIRE
|
||||
)
|
||||
);
|
||||
Function<LogicalAggregate<? extends Plan>, RequireProperties> fourPhaseRequireGather =
|
||||
agg -> RequireProperties.of(PhysicalProperties.GATHER);
|
||||
return fourPhaseAggregateWithDistinct(
|
||||
ctx.root, ctx.connectContext,
|
||||
secondPhaseRequireDistinctHash, fourPhaseRequireGather
|
||||
);
|
||||
})
|
||||
),
|
||||
/*
|
||||
* sql:
|
||||
* select age, count(distinct name) from student group by age;
|
||||
* <p>
|
||||
* 4 phase plan
|
||||
* DISTINCT_GLOBAL(BUFFER_TO_RESULT, groupBy(age),
|
||||
* output[age, sum(partial_count(name))],
|
||||
* hash distribute by age)
|
||||
* +--DISTINCT_LOCAL(INPUT_TO_BUFFER, groupBy(age),
|
||||
* output(age, partial_count(name)),
|
||||
* hash distribute by age, name)
|
||||
* +--GLOBAL(BUFFER_TO_BUFFER, groupBy(age, name),
|
||||
* output(age, name),
|
||||
* hash_distribute by age, name)
|
||||
* +--LOCAL(INPUT_TO_BUFFER, groupBy(age, name), output(age, name))
|
||||
* +--scan(age, name)
|
||||
*/
|
||||
RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT_WITH_FULL_DISTRIBUTE.build(
|
||||
basePattern
|
||||
.when(agg -> agg.everyDistinctArgumentNumIsOne() && !agg.getGroupByExpressions().isEmpty())
|
||||
.when(agg ->
|
||||
ImmutableSet.builder()
|
||||
.addAll(agg.getGroupByExpressions())
|
||||
.addAll(agg.getDistinctArguments())
|
||||
.build().size() > agg.getGroupByExpressions().size()
|
||||
)
|
||||
.thenApplyMulti(ctx -> {
|
||||
Function<List<Expression>, RequireProperties> secondPhaseRequireGroupByAndDistinctHash =
|
||||
groupByAndDistinct -> RequireProperties.of(
|
||||
PhysicalProperties.createHash(groupByAndDistinct, ShuffleType.REQUIRE)
|
||||
);
|
||||
|
||||
Function<LogicalAggregate<? extends Plan>, RequireProperties> fourPhaseRequireGroupByHash =
|
||||
agg -> RequireProperties.of(
|
||||
PhysicalProperties.createHash(
|
||||
agg.getGroupByExpressions(), ShuffleType.REQUIRE
|
||||
)
|
||||
);
|
||||
return fourPhaseAggregateWithDistinct(
|
||||
ctx.root, ctx.connectContext,
|
||||
secondPhaseRequireGroupByAndDistinctHash, fourPhaseRequireGroupByHash
|
||||
);
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
@ -1649,19 +1724,10 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
return connectContext == null || connectContext.getSessionVariable().enablePushDownNoGroupAgg();
|
||||
}
|
||||
|
||||
/**
|
||||
* sql:
|
||||
* select count(distinct name), sum(age) from student;
|
||||
* <p>
|
||||
* 4 phase plan
|
||||
* DISTINCT_GLOBAL, BUFFER_TO_RESULT groupBy(), output[count(name), sum(age#5)], [GATHER]
|
||||
* +--DISTINCT_LOCAL, INPUT_TO_BUFFER, groupBy()), output(count(name), partial_sum(age)), hash distribute by name
|
||||
* +--GLOBAL, BUFFER_TO_BUFFER, groupBy(name), output(name, partial_sum(age)), hash_distribute by name
|
||||
* +--LOCAL, INPUT_TO_BUFFER, groupBy(name), output(name, partial_sum(age))
|
||||
* +--scan(name, age)
|
||||
*/
|
||||
private List<PhysicalHashAggregate<? extends Plan>> fourPhaseAggregateWithDistinct(
|
||||
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) {
|
||||
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext,
|
||||
Function<List<Expression>, RequireProperties> secondPhaseRequireSupplier,
|
||||
Function<LogicalAggregate<? extends Plan>, RequireProperties> fourPhaseRequireSupplier) {
|
||||
boolean couldBanned = couldConvertToMulti(logicalAgg);
|
||||
|
||||
Set<AggregateFunction> aggregateFunctions = logicalAgg.getAggregateFunctions();
|
||||
@ -1734,16 +1800,13 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
globalAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE)));
|
||||
}
|
||||
|
||||
RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER);
|
||||
|
||||
RequireProperties requireDistinctHash = RequireProperties.of(
|
||||
PhysicalProperties.createHash(logicalAgg.getDistinctArguments(), ShuffleType.REQUIRE));
|
||||
RequireProperties secondPhaseRequire = secondPhaseRequireSupplier.apply(localAggGroupBy);
|
||||
|
||||
//phase 2
|
||||
PhysicalHashAggregate<? extends Plan> anyLocalHashGlobalAgg = new PhysicalHashAggregate<>(
|
||||
localAggGroupBy, globalAggOutput, Optional.of(ImmutableList.copyOf(logicalAgg.getDistinctArguments())),
|
||||
bufferToBufferParam, false, logicalAgg.getLogicalProperties(),
|
||||
requireDistinctHash, anyLocalAgg);
|
||||
secondPhaseRequire, anyLocalAgg);
|
||||
|
||||
// phase 3
|
||||
AggregateParam distinctLocalParam = new AggregateParam(
|
||||
@ -1787,7 +1850,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
PhysicalHashAggregate<? extends Plan> distinctLocal = new PhysicalHashAggregate<>(
|
||||
logicalAgg.getGroupByExpressions(), localDistinctOutput, Optional.empty(),
|
||||
distinctLocalParam, false, logicalAgg.getLogicalProperties(),
|
||||
requireDistinctHash, anyLocalHashGlobalAgg);
|
||||
secondPhaseRequire, anyLocalHashGlobalAgg);
|
||||
|
||||
//phase 4
|
||||
AggregateParam distinctGlobalParam = new AggregateParam(
|
||||
@ -1801,7 +1864,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
if (aggregateFunction.isDistinct()) {
|
||||
Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children());
|
||||
Preconditions.checkArgument(aggChild.size() == 1
|
||||
|| aggregateFunction.getDistinctArguments().size() == 1,
|
||||
|| aggregateFunction.getDistinctArguments().size() == 1,
|
||||
"cannot process more than one child in aggregate distinct function: "
|
||||
+ aggregateFunction);
|
||||
AggregateFunction nonDistinct = aggregateFunction
|
||||
@ -1821,10 +1884,12 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
});
|
||||
globalDistinctOutput.add(outputExprPhase4);
|
||||
}
|
||||
|
||||
RequireProperties fourPhaseRequire = fourPhaseRequireSupplier.apply(logicalAgg);
|
||||
PhysicalHashAggregate<? extends Plan> distinctGlobal = new PhysicalHashAggregate<>(
|
||||
logicalAgg.getGroupByExpressions(), globalDistinctOutput, Optional.empty(),
|
||||
distinctGlobalParam, false, logicalAgg.getLogicalProperties(),
|
||||
requireGather, distinctLocal);
|
||||
fourPhaseRequire, distinctLocal);
|
||||
|
||||
return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
|
||||
.add(distinctGlobal)
|
||||
|
||||
@ -29,6 +29,7 @@ import com.google.common.collect.ImmutableSet;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
/**
|
||||
* Common interface for logical/physical Aggregate.
|
||||
@ -68,4 +69,25 @@ public interface Aggregate<CHILD_TYPE extends Plan> extends UnaryPlan<CHILD_TYPE
|
||||
}
|
||||
return distinctArguments.build();
|
||||
}
|
||||
|
||||
/** everyDistinctArgumentNumIsOne */
|
||||
default boolean everyDistinctArgumentNumIsOne() {
|
||||
AtomicBoolean hasDistinctArguments = new AtomicBoolean(false);
|
||||
for (NamedExpression outputExpression : getOutputExpressions()) {
|
||||
boolean distinctArgumentSizeNotOne = outputExpression.anyMatch(expr -> {
|
||||
if (expr instanceof AggregateFunction) {
|
||||
AggregateFunction aggFun = (AggregateFunction) expr;
|
||||
if (aggFun.isDistinct()) {
|
||||
hasDistinctArguments.set(true);
|
||||
return aggFun.getDistinctArguments().size() != 1;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
});
|
||||
if (distinctArgumentSizeNotOne) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return hasDistinctArguments.get();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user