[enhancement](Nereids) support 4 phases distinct aggregate with full distribution (#36016)

cherry pick from #35871
This commit is contained in:
924060929
2024-06-07 21:08:33 +08:00
committed by GitHub
parent 1715bae26f
commit 67f4d88988
7 changed files with 144 additions and 38 deletions

View File

@ -113,16 +113,6 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Boolean, Void> {
// this means one stage gather agg, usually bad pattern
return false;
}
// forbid three or four stage distinct agg inter by distribute
if (agg.getAggMode() == AggMode.BUFFER_TO_BUFFER && children.get(0).getPlan() instanceof PhysicalDistribute) {
// if distinct without group by key, we prefer three or four stage distinct agg
// because the second phase of multi-distinct only have one instance, and it is slow generally.
if (agg.getGroupByExpressions().size() == 1
&& agg.getOutputExpressions().size() == 1) {
return true;
}
return false;
}
// forbid TWO_PHASE_AGGREGATE_WITH_DISTINCT after shuffle
// TODO: this mistakenly forbids good plans after CTE reuse

View File

@ -443,6 +443,7 @@ public enum RuleType {
TWO_PHASE_AGGREGATE_WITH_MULTI_DISTINCT(RuleTypeClass.IMPLEMENTATION),
THREE_PHASE_AGGREGATE_WITH_DISTINCT(RuleTypeClass.IMPLEMENTATION),
FOUR_PHASE_AGGREGATE_WITH_DISTINCT(RuleTypeClass.IMPLEMENTATION),
FOUR_PHASE_AGGREGATE_WITH_DISTINCT_WITH_FULL_DISTRIBUTE(RuleTypeClass.IMPLEMENTATION),
LOGICAL_UNION_TO_PHYSICAL_UNION(RuleTypeClass.IMPLEMENTATION),
LOGICAL_EXCEPT_TO_PHYSICAL_EXCEPT(RuleTypeClass.IMPLEMENTATION),
LOGICAL_INTERSECT_TO_PHYSICAL_INTERSECT(RuleTypeClass.IMPLEMENTATION),

View File

@ -75,6 +75,7 @@ import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import org.apache.doris.qe.ConnectContext;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
@ -293,15 +294,89 @@ public class AggregateStrategies implements ImplementationRuleFactory {
// .thenApplyMulti(ctx -> twoPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
// ),
RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
),
/*
* sql:
* select count(distinct name), sum(age) from student;
* <p>
* 4 phase plan
* DISTINCT_GLOBAL(BUFFER_TO_RESULT, groupBy(),
* output[count(partial_count(name)), sum(partial_sum(partial_sum(age)))],
* GATHER)
* +--DISTINCT_LOCAL(INPUT_TO_BUFFER, groupBy(),
* output(partial_count(name), partial_sum(partial_sum(age))),
* hash distribute by name)
* +--GLOBAL(BUFFER_TO_BUFFER, groupBy(name),
* output(name, partial_sum(age)),
* hash_distribute by name)
* +--LOCAL(INPUT_TO_BUFFER, groupBy(name), output(name, partial_sum(age)))
* +--scan(name, age)
*/
RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.when(agg -> agg.getGroupByExpressions().isEmpty())
.thenApplyMulti(ctx -> fourPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.when(agg -> agg.getGroupByExpressions().isEmpty())
.thenApplyMulti(ctx -> {
Function<List<Expression>, RequireProperties> secondPhaseRequireDistinctHash =
groupByAndDistinct -> RequireProperties.of(
PhysicalProperties.createHash(
ctx.root.getDistinctArguments(), ShuffleType.REQUIRE
)
);
Function<LogicalAggregate<? extends Plan>, RequireProperties> fourPhaseRequireGather =
agg -> RequireProperties.of(PhysicalProperties.GATHER);
return fourPhaseAggregateWithDistinct(
ctx.root, ctx.connectContext,
secondPhaseRequireDistinctHash, fourPhaseRequireGather
);
})
),
/*
* sql:
* select age, count(distinct name) from student group by age;
* <p>
* 4 phase plan
* DISTINCT_GLOBAL(BUFFER_TO_RESULT, groupBy(age),
* output[age, sum(partial_count(name))],
* hash distribute by age)
* +--DISTINCT_LOCAL(INPUT_TO_BUFFER, groupBy(age),
* output(age, partial_count(name)),
* hash distribute by age, name)
* +--GLOBAL(BUFFER_TO_BUFFER, groupBy(age, name),
* output(age, name),
* hash_distribute by age, name)
* +--LOCAL(INPUT_TO_BUFFER, groupBy(age, name), output(age, name))
* +--scan(age, name)
*/
RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT_WITH_FULL_DISTRIBUTE.build(
basePattern
.when(agg -> agg.everyDistinctArgumentNumIsOne() && !agg.getGroupByExpressions().isEmpty())
.when(agg ->
ImmutableSet.builder()
.addAll(agg.getGroupByExpressions())
.addAll(agg.getDistinctArguments())
.build().size() > agg.getGroupByExpressions().size()
)
.thenApplyMulti(ctx -> {
Function<List<Expression>, RequireProperties> secondPhaseRequireGroupByAndDistinctHash =
groupByAndDistinct -> RequireProperties.of(
PhysicalProperties.createHash(groupByAndDistinct, ShuffleType.REQUIRE)
);
Function<LogicalAggregate<? extends Plan>, RequireProperties> fourPhaseRequireGroupByHash =
agg -> RequireProperties.of(
PhysicalProperties.createHash(
agg.getGroupByExpressions(), ShuffleType.REQUIRE
)
);
return fourPhaseAggregateWithDistinct(
ctx.root, ctx.connectContext,
secondPhaseRequireGroupByAndDistinctHash, fourPhaseRequireGroupByHash
);
})
)
);
}
@ -1649,19 +1724,10 @@ public class AggregateStrategies implements ImplementationRuleFactory {
return connectContext == null || connectContext.getSessionVariable().enablePushDownNoGroupAgg();
}
/**
* sql:
* select count(distinct name), sum(age) from student;
* <p>
* 4 phase plan
* DISTINCT_GLOBAL, BUFFER_TO_RESULT groupBy(), output[count(name), sum(age#5)], [GATHER]
* +--DISTINCT_LOCAL, INPUT_TO_BUFFER, groupBy()), output(count(name), partial_sum(age)), hash distribute by name
* +--GLOBAL, BUFFER_TO_BUFFER, groupBy(name), output(name, partial_sum(age)), hash_distribute by name
* +--LOCAL, INPUT_TO_BUFFER, groupBy(name), output(name, partial_sum(age))
* +--scan(name, age)
*/
private List<PhysicalHashAggregate<? extends Plan>> fourPhaseAggregateWithDistinct(
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) {
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext,
Function<List<Expression>, RequireProperties> secondPhaseRequireSupplier,
Function<LogicalAggregate<? extends Plan>, RequireProperties> fourPhaseRequireSupplier) {
boolean couldBanned = couldConvertToMulti(logicalAgg);
Set<AggregateFunction> aggregateFunctions = logicalAgg.getAggregateFunctions();
@ -1734,16 +1800,13 @@ public class AggregateStrategies implements ImplementationRuleFactory {
globalAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE)));
}
RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER);
RequireProperties requireDistinctHash = RequireProperties.of(
PhysicalProperties.createHash(logicalAgg.getDistinctArguments(), ShuffleType.REQUIRE));
RequireProperties secondPhaseRequire = secondPhaseRequireSupplier.apply(localAggGroupBy);
//phase 2
PhysicalHashAggregate<? extends Plan> anyLocalHashGlobalAgg = new PhysicalHashAggregate<>(
localAggGroupBy, globalAggOutput, Optional.of(ImmutableList.copyOf(logicalAgg.getDistinctArguments())),
bufferToBufferParam, false, logicalAgg.getLogicalProperties(),
requireDistinctHash, anyLocalAgg);
secondPhaseRequire, anyLocalAgg);
// phase 3
AggregateParam distinctLocalParam = new AggregateParam(
@ -1787,7 +1850,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
PhysicalHashAggregate<? extends Plan> distinctLocal = new PhysicalHashAggregate<>(
logicalAgg.getGroupByExpressions(), localDistinctOutput, Optional.empty(),
distinctLocalParam, false, logicalAgg.getLogicalProperties(),
requireDistinctHash, anyLocalHashGlobalAgg);
secondPhaseRequire, anyLocalHashGlobalAgg);
//phase 4
AggregateParam distinctGlobalParam = new AggregateParam(
@ -1801,7 +1864,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
if (aggregateFunction.isDistinct()) {
Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1
|| aggregateFunction.getDistinctArguments().size() == 1,
|| aggregateFunction.getDistinctArguments().size() == 1,
"cannot process more than one child in aggregate distinct function: "
+ aggregateFunction);
AggregateFunction nonDistinct = aggregateFunction
@ -1821,10 +1884,12 @@ public class AggregateStrategies implements ImplementationRuleFactory {
});
globalDistinctOutput.add(outputExprPhase4);
}
RequireProperties fourPhaseRequire = fourPhaseRequireSupplier.apply(logicalAgg);
PhysicalHashAggregate<? extends Plan> distinctGlobal = new PhysicalHashAggregate<>(
logicalAgg.getGroupByExpressions(), globalDistinctOutput, Optional.empty(),
distinctGlobalParam, false, logicalAgg.getLogicalProperties(),
requireGather, distinctLocal);
fourPhaseRequire, distinctLocal);
return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
.add(distinctGlobal)

View File

@ -29,6 +29,7 @@ import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* Common interface for logical/physical Aggregate.
@ -68,4 +69,25 @@ public interface Aggregate<CHILD_TYPE extends Plan> extends UnaryPlan<CHILD_TYPE
}
return distinctArguments.build();
}
/**
 * Tells whether this aggregate contains at least one DISTINCT aggregate function in its
 * output expressions and no visited DISTINCT aggregate function has a distinct-argument
 * count other than one.
 *
 * @return {@code true} when a distinct aggregate function was seen and none of the visited
 *         ones reported a distinct-argument count different from one; {@code false} when a
 *         violating function is found or when no distinct aggregate function exists at all
 */
default boolean everyDistinctArgumentNumIsOne() {
    // Mutable holder because the flag is written from inside the lambda below.
    AtomicBoolean sawDistinctFunction = new AtomicBoolean(false);
    for (NamedExpression output : getOutputExpressions()) {
        boolean hasViolation = output.anyMatch(node -> {
            if (!(node instanceof AggregateFunction)) {
                return false;
            }
            AggregateFunction function = (AggregateFunction) node;
            if (!function.isDistinct()) {
                return false;
            }
            sawDistinctFunction.set(true);
            // A match here means "this distinct function does NOT have exactly one argument".
            return function.getDistinctArguments().size() != 1;
        });
        if (hasViolation) {
            return false;
        }
    }
    // Without any distinct aggregate function the answer is false, not vacuously true.
    return sawDistinctFunction.get();
}
}