[opt](Nereids) forbid some bad cases in agg plans (#21565)

1. forbid all candidates that require a gather, except where a gather is mandatory
2. forbid doing the local agg after a reshuffle in a two-phase distinct agg
3. forbid one-phase agg after a reshuffle
4. forbid three- or four-phase agg for distinct if any stage needs a reshuffle
5. forbid multi-distinct for a single-distinct agg if no reshuffle is needed
This commit is contained in:
morrySnow
2023-07-07 17:45:55 +08:00
committed by GitHub
parent b471cf2045
commit 2d445bbb6d
20 changed files with 314 additions and 215 deletions

View File

@ -1742,7 +1742,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
.map(e -> {
Expression function = e.child(0).child(0);
if (function instanceof AggregateFunction) {
AggregateParam param = AggregateParam.localResult();
AggregateParam param = AggregateParam.LOCAL_RESULT;
function = new AggregateExpression((AggregateFunction) function, param);
}
return ExpressionTranslator.translate(function, context);

View File

@ -23,7 +23,12 @@ import org.apache.doris.nereids.cost.CostCalculator;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
import org.apache.doris.nereids.trees.expressions.AggregateExpression;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinction;
import org.apache.doris.nereids.trees.plans.AggMode;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
@ -43,6 +48,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/**
* ensure child add enough distribute. update children properties if we do regular
@ -88,12 +94,55 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Boolean, Void> {
@Override
public Boolean visitPhysicalHashAggregate(PhysicalHashAggregate<? extends Plan> agg, Void context) {
// forbid one phase agg on distribute
if (agg.getAggMode() == AggMode.INPUT_TO_RESULT
if (!agg.getAggregateParam().canBeBanned) {
return true;
}
// forbid one phase agg on distribute and three or four stage distinct agg inter by distribute
if ((agg.getAggMode() == AggMode.INPUT_TO_RESULT || agg.getAggMode() == AggMode.BUFFER_TO_BUFFER)
&& children.get(0).getPlan() instanceof PhysicalDistribute) {
// this means one stage gather agg, usually bad pattern
return false;
}
// forbid TWO_PHASE_AGGREGATE_WITH_DISTINCT after shuffle
// TODO: this mistakenly forbids a good plan after CTE reuse
if (agg.getAggMode() == AggMode.INPUT_TO_BUFFER
&& requiredProperties.get(0).getDistributionSpec() instanceof DistributionSpecHash
&& children.get(0).getPlan() instanceof PhysicalDistribute) {
return false;
}
// forbid the multi-distinct optimization, which is worse than the multi-stage version
// when the multi-stage plan can be executed in one fragment
if (agg.getAggMode() == AggMode.INPUT_TO_BUFFER || agg.getAggMode() == AggMode.INPUT_TO_RESULT) {
List<MultiDistinction> multiDistinctions = agg.getOutputExpressions().stream()
.filter(Alias.class::isInstance)
.map(a -> ((Alias) a).child())
.filter(AggregateExpression.class::isInstance)
.map(a -> ((AggregateExpression) a).getFunction())
.filter(MultiDistinction.class::isInstance)
.map(MultiDistinction.class::cast)
.collect(Collectors.toList());
if (multiDistinctions.size() == 1) {
Expression distinctChild = multiDistinctions.get(0).child(0);
DistributionSpec childDistribution = childrenProperties.get(0).getDistributionSpec();
if (distinctChild instanceof SlotReference && childDistribution instanceof DistributionSpecHash) {
SlotReference slotReference = (SlotReference) distinctChild;
DistributionSpecHash distributionSpecHash = (DistributionSpecHash) childDistribution;
List<ExprId> groupByColumns = agg.getGroupByExpressions().stream()
.map(SlotReference.class::cast)
.map(SlotReference::getExprId)
.collect(Collectors.toList());
DistributionSpecHash groupByRequire = new DistributionSpecHash(
groupByColumns, ShuffleType.REQUIRE);
List<ExprId> distinctChildColumns = Lists.newArrayList(slotReference.getExprId());
distinctChildColumns.add(slotReference.getExprId());
DistributionSpecHash distinctChildRequire = new DistributionSpecHash(
distinctChildColumns, ShuffleType.REQUIRE);
if ((!groupByColumns.isEmpty() && distributionSpecHash.satisfy(groupByRequire))
|| (groupByColumns.isEmpty() && distributionSpecHash.satisfy(distinctChildRequire))) {
return false;
}
}
}
}
// process must shuffle
visit(agg, context);
// process agg

View File

@ -70,11 +70,13 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@ -123,41 +125,43 @@ public class AggregateStrategies implements ImplementationRuleFactory {
.when(agg -> agg.getDistinctArguments().size() == 0)
.thenApplyMulti(ctx -> twoPhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext))
),
RuleType.TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build(
basePattern
.when(this::containsCountDistinctMultiExpr)
.thenApplyMulti(ctx -> twoPhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext))
),
// RuleType.TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build(
// basePattern
// .when(this::containsCountDistinctMultiExpr)
// .thenApplyMulti(ctx -> twoPhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext))
// ),
RuleType.THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build(
basePattern
.when(this::containsCountDistinctMultiExpr)
.thenApplyMulti(ctx -> threePhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext))
),
RuleType.TWO_PHASE_AGGREGATE_WITH_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.thenApplyMulti(ctx -> twoPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
),
RuleType.ONE_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1 && enableSingleDistinctColumnOpt())
.when(agg -> agg.getDistinctArguments().size() == 1 && couldConvertToMulti(agg))
.thenApplyMulti(ctx -> onePhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext))
),
RuleType.TWO_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1 && enableSingleDistinctColumnOpt())
.when(agg -> agg.getDistinctArguments().size() == 1 && couldConvertToMulti(agg))
.thenApplyMulti(ctx -> twoPhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext))
),
RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
),
RuleType.TWO_PHASE_AGGREGATE_WITH_MULTI_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() > 1 && !containsCountDistinctMultiExpr(agg))
.when(agg -> agg.getDistinctArguments().size() > 1
&& !containsCountDistinctMultiExpr(agg)
&& couldConvertToMulti(agg))
.thenApplyMulti(ctx -> twoPhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext))
),
// RuleType.TWO_PHASE_AGGREGATE_WITH_DISTINCT.build(
// basePattern
// .when(agg -> agg.getDistinctArguments().size() == 1)
// .thenApplyMulti(ctx -> twoPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
// ),
RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
),
RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
@ -169,15 +173,15 @@ public class AggregateStrategies implements ImplementationRuleFactory {
/**
* sql: select count(*) from tbl
*
* <p>
* before:
*
* <p>
* LogicalAggregate(groupBy=[], output=[count(*)])
* |
* LogicalOlapScan(table=tbl)
*
* <p>
* after:
*
* <p>
* LogicalAggregate(groupBy=[], output=[count(*)])
* |
* PhysicalStorageLayerAggregate(pushAggOp=COUNT, table=PhysicalOlapScan(table=tbl))
@ -205,7 +209,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
.map(AggregateFunction::getClass)
.collect(Collectors.toSet());
Map<Class, PushDownAggOp> supportedAgg = PushDownAggOp.supportedFunctions();
Map<Class<? extends AggregateFunction>, PushDownAggOp> supportedAgg = PushDownAggOp.supportedFunctions();
if (!supportedAgg.keySet().containsAll(functionClasses)) {
return canNotPush;
}
@ -292,7 +296,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|| colType == PrimitiveType.STRING) {
return canNotPush;
}
if (colType.isCharFamily() && mergeOp != PushDownAggOp.COUNT && column.getType().getLength() > 512) {
if (colType.isCharFamily() && column.getType().getLength() > 512) {
return canNotPush;
}
}
@ -324,25 +328,25 @@ public class AggregateStrategies implements ImplementationRuleFactory {
/**
* sql: select count(*) from tbl group by id
*
* <p>
* before:
*
* <p>
* LogicalAggregate(groupBy=[id], output=[count(*)])
* |
* LogicalOlapScan(table=tbl)
*
* <p>
* after:
*
* <p>
* single node aggregate:
*
* <p>
* PhysicalHashAggregate(groupBy=[id], output=[count(*)])
* |
* PhysicalDistribute(distributionSpec=GATHER)
* |
* LogicalOlapScan(table=tbl)
*
* <p>
* distribute node aggregate:
*
* <p>
* PhysicalHashAggregate(groupBy=[id], output=[count(*)])
* |
* LogicalOlapScan(table=tbl, **already distribute by id**)
@ -351,7 +355,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
private List<PhysicalHashAggregate<Plan>> onePhaseAggregateWithoutDistinct(
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) {
RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER);
AggregateParam inputToResultParam = AggregateParam.localResult();
AggregateParam inputToResultParam = AggregateParam.LOCAL_RESULT;
List<NamedExpression> newOutput = ExpressionUtils.rewriteDownShortCircuit(
logicalAgg.getOutputExpressions(), outputChild -> {
if (outputChild instanceof AggregateFunction) {
@ -366,7 +370,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
requireGather, logicalAgg.child());
if (logicalAgg.getGroupByExpressions().isEmpty()) {
return ImmutableList.of(gatherLocalAgg);
// TODO: usually bad, disable it until we could do better cost computation.
// return ImmutableList.of(gatherLocalAgg);
return ImmutableList.of();
} else {
RequireProperties requireHash = RequireProperties.of(
PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE));
@ -383,17 +389,17 @@ public class AggregateStrategies implements ImplementationRuleFactory {
/**
* sql: select count(distinct id, name) from tbl group by name
*
* <p>
* before:
*
* <p>
* LogicalAggregate(groupBy=[name], output=[count(distinct id, name)])
* |
* LogicalOlapScan(table=tbl)
*
* <p>
* after:
*
* <p>
* single node aggregate:
*
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, null, name))])
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id])
@ -401,9 +407,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
* PhysicalDistribute(distributionSpec=GATHER)
* |
* LogicalOlapScan(table=tbl)
*
* <p>
* distribute node aggregate:
*
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, null, name))])
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id])
@ -415,7 +421,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
*/
private List<PhysicalHashAggregate<Plan>> twoPhaseAggregateWithCountDistinctMulti(
LogicalAggregate<? extends Plan> logicalAgg, CascadesContext cascadesContext) {
AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER);
AggregateParam inputToBufferParam = AggregateParam.LOCAL_BUFFER;
Collection<Expression> countDistinctArguments = logicalAgg.getDistinctArguments();
List<Expression> localAggGroupBy = ImmutableList.copyOf(ImmutableSet.<Expression>builder()
@ -487,7 +493,8 @@ public class AggregateStrategies implements ImplementationRuleFactory {
.withRequireTree(requireHash.withChildren(requireHash))
.withPartitionExpressions(logicalAgg.getGroupByExpressions());
return ImmutableList.<PhysicalHashAggregate<Plan>>builder()
.add(gatherLocalGatherDistinctAgg)
// TODO: usually bad, disable it until we could do better cost computation.
//.add(gatherLocalGatherDistinctAgg)
.add(hashLocalHashGlobalAgg)
.build();
}
@ -495,17 +502,17 @@ public class AggregateStrategies implements ImplementationRuleFactory {
/**
* sql: select count(distinct id, name) from tbl group by name
*
* <p>
* before:
*
* <p>
* LogicalAggregate(groupBy=[name], output=[count(distinct id, name)])
* |
* LogicalOlapScan(table=tbl)
*
* <p>
* after:
*
* <p>
* single node aggregate:
*
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, null, name))])
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER)
@ -515,9 +522,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER)
* |
* LogicalOlapScan(table=tbl)
*
* <p>
* distribute node aggregate:
*
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, null, name))])
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER)
@ -566,11 +573,17 @@ public class AggregateStrategies implements ImplementationRuleFactory {
List<Expression> globalAggGroupBy = localAggGroupBy;
AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER);
boolean hasCountDistinctMulti = logicalAgg.getAggregateFunctions().stream()
.filter(AggregateFunction::isDistinct)
.filter(Count.class::isInstance)
.anyMatch(c -> c.arity() > 1);
AggregateParam bufferToBufferParam = new AggregateParam(
AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER, !hasCountDistinctMulti);
Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase2 =
nonDistinctAggFunctionToAliasPhase1.entrySet()
.stream()
.collect(ImmutableMap.toImmutableMap(kv -> kv.getKey(), kv -> {
.collect(ImmutableMap.toImmutableMap(Entry::getKey, kv -> {
AggregateFunction originFunction = kv.getKey();
Alias localOutputAlias = kv.getValue();
AggregateExpression globalAggExpr = new AggregateExpression(
@ -596,7 +609,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
logicalAgg, cascadesContext).first;
AggregateParam distinctInputToResultParam
= new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT);
= new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT, !hasCountDistinctMulti);
AggregateParam globalBufferToResultParam
= new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT);
List<NamedExpression> distinctOutput = ExpressionUtils.rewriteDownShortCircuit(
@ -621,19 +634,19 @@ public class AggregateStrategies implements ImplementationRuleFactory {
logicalAgg.getLogicalProperties(), requireGather, anyLocalGatherGlobalAgg
);
RequireProperties requireDistinctHash = RequireProperties.of(
PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE));
PhysicalHashAggregate<? extends Plan> anyLocalHashGlobalGatherDistinctAgg
= anyLocalGatherGlobalGatherAgg.withChildren(ImmutableList.of(
anyLocalGatherGlobalAgg
.withRequire(requireDistinctHash)
.withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments()))
));
// RequireProperties requireDistinctHash = RequireProperties.of(
// PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE));
// PhysicalHashAggregate<? extends Plan> anyLocalHashGlobalGatherDistinctAgg
// = anyLocalGatherGlobalGatherAgg.withChildren(ImmutableList.of(
// anyLocalGatherGlobalAgg
// .withRequire(requireDistinctHash)
// .withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments()))
// ));
if (logicalAgg.getGroupByExpressions().isEmpty()) {
return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
.add(anyLocalGatherGlobalGatherAgg)
.add(anyLocalHashGlobalGatherDistinctAgg)
//.add(anyLocalHashGlobalGatherDistinctAgg)
.build();
} else {
RequireProperties requireGroupByHash = RequireProperties.of(
@ -646,8 +659,8 @@ public class AggregateStrategies implements ImplementationRuleFactory {
)
.withPartitionExpressions(logicalAgg.getGroupByExpressions());
return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
.add(anyLocalGatherGlobalGatherAgg)
.add(anyLocalHashGlobalGatherDistinctAgg)
// .add(anyLocalGatherGlobalGatherAgg)
// .add(anyLocalHashGlobalGatherDistinctAgg)
.add(anyLocalHashGlobalHashDistinctAgg)
.build();
}
@ -655,17 +668,17 @@ public class AggregateStrategies implements ImplementationRuleFactory {
/**
* sql: select name, count(value) from tbl group by name
*
* <p>
* before:
*
* <p>
* LogicalAggregate(groupBy=[name], output=[name, count(value)])
* |
* LogicalOlapScan(table=tbl)
*
* <p>
* after:
*
* <p>
* single node aggregate:
*
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, count(value)], mode=BUFFER_TO_RESULT)
* |
* PhysicalDistribute(distributionSpec=GATHER)
@ -673,9 +686,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
* PhysicalHashAggregate(groupBy=[name], output=[name, count(value)], mode=INPUT_TO_BUFFER)
* |
* LogicalOlapScan(table=tbl)
*
* <p>
* distribute node aggregate:
*
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, count(value)], mode=BUFFER_TO_RESULT)
* |
* PhysicalDistribute(distributionSpec=HASH(name))
@ -713,6 +726,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
AggregateParam bufferToResultParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT);
List<NamedExpression> globalAggOutput = ExpressionUtils.rewriteDownShortCircuit(
logicalAgg.getOutputExpressions(), outputChild -> {
if (!(outputChild instanceof AggregateFunction)) {
return outputChild;
}
Alias inputToBufferAlias = inputToBufferAliases.get(outputChild);
if (inputToBufferAlias == null) {
return outputChild;
@ -722,7 +738,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
});
RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER);
PhysicalHashAggregate<Plan> anyLocalGatherGlobalAgg = new PhysicalHashAggregate(
PhysicalHashAggregate<Plan> anyLocalGatherGlobalAgg = new PhysicalHashAggregate<>(
localAggGroupBy, globalAggOutput, Optional.of(partitionExpressions),
bufferToResultParam, false, anyLocalAgg.getLogicalProperties(),
requireGather, anyLocalAgg);
@ -746,17 +762,17 @@ public class AggregateStrategies implements ImplementationRuleFactory {
/**
* sql: select count(distinct id) from tbl group by name
*
* <p>
* before:
*
* <p>
* LogicalAggregate(groupBy=[name], output=[name, count(distinct id)])
* |
* LogicalOlapScan(table=tbl)
*
* <p>
* after:
*
* <p>
* single node aggregate:
*
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT)
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER)
@ -764,9 +780,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
* PhysicalDistribute(distributionSpec=GATHER)
* |
* LogicalOlapScan(table=tbl)
*
* <p>
* distribute node aggregate:
*
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT)
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER)
@ -781,7 +797,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
Set<AggregateFunction> aggregateFunctions = logicalAgg.getAggregateFunctions();
Set<Expression> distinctArguments = aggregateFunctions.stream()
.filter(aggregateExpression -> aggregateExpression.isDistinct())
.filter(AggregateFunction::isDistinct)
.flatMap(aggregateExpression -> aggregateExpression.getArguments().stream())
.collect(ImmutableSet.toImmutableSet());
@ -790,7 +806,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
.addAll(distinctArguments)
.build();
AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER, true);
AggregateParam inputToBufferParam = AggregateParam.LOCAL_BUFFER;
Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream()
.filter(aggregateFunction -> !aggregateFunction.isDistinct())
@ -822,10 +838,13 @@ public class AggregateStrategies implements ImplementationRuleFactory {
if (outputChild instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) outputChild;
if (aggregateFunction.isDistinct()) {
Preconditions.checkArgument(aggregateFunction.arity() == 1);
Set<Expression> aggChild = Sets.newHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1,
"cannot process more than one child in aggregate distinct function: "
+ aggregateFunction);
AggregateFunction nonDistinct = aggregateFunction
.withDistinctAndChildren(false, aggregateFunction.getArguments());
return new AggregateExpression(nonDistinct, AggregateParam.localResult());
.withDistinctAndChildren(false, ImmutableList.copyOf(aggChild));
return new AggregateExpression(nonDistinct, AggregateParam.LOCAL_RESULT);
} else {
Alias alias = nonDistinctAggFunctionToAliasPhase1.get(outputChild);
return new AggregateExpression(
@ -850,7 +869,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
.withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments()))
));
return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
.add(gatherLocalGatherGlobalAgg)
//.add(gatherLocalGatherGlobalAgg)
.add(hashLocalGatherGlobalAgg)
.build();
} else {
@ -863,7 +882,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
)
.withPartitionExpressions(logicalAgg.getGroupByExpressions());
return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
.add(gatherLocalGatherGlobalAgg)
// .add(gatherLocalGatherGlobalAgg)
.add(hashLocalHashGlobalAgg)
.build();
}
@ -871,16 +890,16 @@ public class AggregateStrategies implements ImplementationRuleFactory {
/**
* sql: select count(distinct id) from tbl group by name
*
* <p>
* before:
*
* <p>
* LogicalAggregate(groupBy=[name], output=[name, count(distinct id)])
* |
* LogicalOlapScan(table=tbl)
*
* <p>
* after:
* single node aggregate:
*
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT)
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER)
@ -890,9 +909,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER)
* |
* LogicalOlapScan(table=tbl)
*
* <p>
* distribute node aggregate:
*
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT)
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER)
@ -907,10 +926,12 @@ public class AggregateStrategies implements ImplementationRuleFactory {
// TODO: support one phase aggregate(group by columns + distinct columns) + two phase distinct aggregate
private List<PhysicalHashAggregate<? extends Plan>> threePhaseAggregateWithDistinct(
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) {
boolean couldBanned = couldConvertToMulti(logicalAgg);
Set<AggregateFunction> aggregateFunctions = logicalAgg.getAggregateFunctions();
Set<Expression> distinctArguments = aggregateFunctions.stream()
.filter(aggregateExpression -> aggregateExpression.isDistinct())
.filter(AggregateFunction::isDistinct)
.flatMap(aggregateExpression -> aggregateExpression.getArguments().stream())
.collect(ImmutableSet.toImmutableSet());
@ -919,7 +940,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
.addAll(distinctArguments)
.build();
AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER);
AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned);
Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream()
.filter(aggregateFunction -> !aggregateFunction.isDistinct())
@ -942,11 +963,11 @@ public class AggregateStrategies implements ImplementationRuleFactory {
maybeUsingStreamAgg, Optional.empty(), logicalAgg.getLogicalProperties(),
requireAny, logicalAgg.child());
AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER);
AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER, couldBanned);
Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase2 =
nonDistinctAggFunctionToAliasPhase1.entrySet()
.stream()
.collect(ImmutableMap.toImmutableMap(kv -> kv.getKey(), kv -> {
.collect(ImmutableMap.toImmutableMap(Entry::getKey, kv -> {
AggregateFunction originFunction = kv.getKey();
Alias localOutput = kv.getValue();
AggregateExpression globalAggExpr = new AggregateExpression(
@ -965,15 +986,19 @@ public class AggregateStrategies implements ImplementationRuleFactory {
bufferToBufferParam, false, logicalAgg.getLogicalProperties(),
requireGather, anyLocalAgg);
AggregateParam bufferToResultParam = new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT);
AggregateParam bufferToResultParam = new AggregateParam(
AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT, couldBanned);
List<NamedExpression> distinctOutput = ExpressionUtils.rewriteDownShortCircuit(
logicalAgg.getOutputExpressions(), expr -> {
if (expr instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) expr;
if (aggregateFunction.isDistinct()) {
Preconditions.checkArgument(aggregateFunction.arity() == 1);
Set<Expression> aggChild = Sets.newHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1,
"cannot process more than one child in aggregate distinct function: "
+ aggregateFunction);
AggregateFunction nonDistinct = aggregateFunction
.withDistinctAndChildren(false, aggregateFunction.getArguments());
.withDistinctAndChildren(false, ImmutableList.copyOf(aggChild));
return new AggregateExpression(nonDistinct,
bufferToResultParam, aggregateFunction.child(0));
} else {
@ -1017,8 +1042,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
)
.withPartitionExpressions(logicalAgg.getGroupByExpressions());
return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
.add(anyLocalGatherGlobalGatherDistinctAgg)
.add(anyLocalHashGlobalGatherDistinctAgg)
// TODO: this plan pattern is not good usually, we remove it temporary.
//.add(anyLocalGatherGlobalGatherDistinctAgg)
//.add(anyLocalHashGlobalGatherDistinctAgg)
.add(anyLocalHashGlobalHashDistinctAgg)
.build();
}
@ -1026,25 +1052,25 @@ public class AggregateStrategies implements ImplementationRuleFactory {
/**
* sql: select count(distinct id) from (...) group by name
*
* <p>
* before:
*
* <p>
* LogicalAggregate(groupBy=[name], output=[count(distinct id)])
* |
* any plan
*
* <p>
* after:
*
* <p>
* single node aggregate:
*
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[multi_distinct_count(id)])
* |
* PhysicalDistribute(distributionSpec=GATHER)
* |
* any plan
*
* <p>
* distribute node aggregate:
*
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[multi_distinct_count(id)])
* |
* any plan(**already distribute by name**)
@ -1052,7 +1078,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
*/
private List<PhysicalHashAggregate<? extends Plan>> onePhaseAggregateWithMultiDistinct(
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) {
AggregateParam inputToResultParam = AggregateParam.localResult();
AggregateParam inputToResultParam = AggregateParam.LOCAL_RESULT;
List<NamedExpression> newOutput = ExpressionUtils.rewriteDownShortCircuit(
logicalAgg.getOutputExpressions(), outputChild -> {
if (outputChild instanceof AggregateFunction) {
@ -1068,7 +1094,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
maybeUsingStreamAgg(connectContext, logicalAgg),
logicalAgg.getLogicalProperties(), requireGather, logicalAgg.child());
if (logicalAgg.getGroupByExpressions().isEmpty()) {
return ImmutableList.of(gatherLocalAgg);
// TODO: usually bad, disable it until we could do better cost computation.
// return ImmutableList.of(gatherLocalAgg);
return ImmutableList.of();
} else {
RequireProperties requireHash = RequireProperties.of(
PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE));
@ -1085,17 +1113,17 @@ public class AggregateStrategies implements ImplementationRuleFactory {
/**
* sql: select count(distinct id) from tbl group by name
*
* <p>
* before:
*
* <p>
* LogicalAggregate(groupBy=[name], output=[name, count(distinct id)])
* |
* LogicalOlapScan(table=tbl)
*
* <p>
* after:
*
* <p>
* single node aggregate:
*
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, multi_count_distinct(value)], mode=BUFFER_TO_RESULT)
* |
* PhysicalDistribute(distributionSpec=GATHER)
@ -1103,9 +1131,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
* PhysicalHashAggregate(groupBy=[name], output=[name, multi_count_distinct(value)], mode=INPUT_TO_BUFFER)
* |
* LogicalOlapScan(table=tbl)
*
* <p>
* distribute node aggregate:
*
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, multi_count_distinct(value)], mode=BUFFER_TO_RESULT)
* |
* PhysicalDistribute(distributionSpec=HASH(name))
@ -1157,17 +1185,16 @@ public class AggregateStrategies implements ImplementationRuleFactory {
RequireProperties.of(PhysicalProperties.GATHER), anyLocalAgg);
if (logicalAgg.getGroupByExpressions().isEmpty()) {
Collection<Expression> distinctArguments = logicalAgg.getDistinctArguments();
RequireProperties requireDistinctHash = RequireProperties.of(PhysicalProperties.createHash(
distinctArguments, ShuffleType.REQUIRE));
PhysicalHashAggregate<? extends Plan> hashLocalGatherGlobalAgg = anyLocalGatherGlobalAgg
.withChildren(ImmutableList.of(anyLocalAgg
.withRequire(requireDistinctHash)
.withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments()))
));
// Collection<Expression> distinctArguments = logicalAgg.getDistinctArguments();
// RequireProperties requireDistinctHash = RequireProperties.of(PhysicalProperties.createHash(
// distinctArguments, ShuffleType.REQUIRE));
// PhysicalHashAggregate<? extends Plan> hashLocalGatherGlobalAgg = anyLocalGatherGlobalAgg
// .withChildren(ImmutableList.of(anyLocalAgg
// .withRequire(requireDistinctHash)
// .withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments()))
// ));
return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
.add(anyLocalGatherGlobalAgg)
.add(hashLocalGatherGlobalAgg)
.build();
} else {
RequireProperties requireHash = RequireProperties.of(
@ -1176,7 +1203,8 @@ public class AggregateStrategies implements ImplementationRuleFactory {
.withRequire(requireHash)
.withPartitionExpressions(logicalAgg.getGroupByExpressions());
return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
.add(anyLocalGatherGlobalAgg)
// TODO: usually bad, disable it until we could do better cost computation.
// .add(anyLocalGatherGlobalAgg)
.add(anyLocalHashGlobalAgg)
.build();
}
@ -1215,7 +1243,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
/**
* countDistinctMultiExprToCountIf.
*
* <p>
* NOTE: this function will break the normalized output, e.g. from `count(distinct slot1, slot2)` to
* `count(if(slot1 is null, null, slot2))`. So if you invoke this method, and separate the
* phase of aggregate, please normalize to slot and create a bottom project like NormalizeAggregate.
@ -1268,15 +1296,10 @@ public class AggregateStrategies implements ImplementationRuleFactory {
return connectContext == null || connectContext.getSessionVariable().enablePushDownNoGroupAgg();
}
/**
 * Reads the single-distinct-column optimization switch from the current session.
 *
 * @return {@code true} when no {@code ConnectContext} is available, otherwise the value reported by
 *         the session variable
 */
private boolean enableSingleDistinctColumnOpt() {
    ConnectContext ctx = ConnectContext.get();
    if (ctx == null) {
        // No connection context: the option is treated as enabled.
        return true;
    }
    return ctx.getSessionVariable().enableSingleDistinctColumnOpt();
}
/**
* sql:
* select count(distinct name), sum(age) from student;
*
* <p>
* 4 phase plan
* DISTINCT_GLOBAL, BUFFER_TO_RESULT groupBy(), output[count(name), sum(age#5)], [GATHER]
* +--DISTINCT_LOCAL, INPUT_TO_BUFFER, groupBy(), output(count(name), partial_sum(age)), hash distribute by name
@ -1286,26 +1309,29 @@ public class AggregateStrategies implements ImplementationRuleFactory {
*/
private List<PhysicalHashAggregate<? extends Plan>> fourPhaseAggregateWithDistinct(
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) {
boolean couldBanned = couldConvertToMulti(logicalAgg);
Set<AggregateFunction> aggregateFunctions = logicalAgg.getAggregateFunctions();
Set<Expression> distinctArguments = aggregateFunctions.stream()
.filter(aggregateExpression -> aggregateExpression.isDistinct())
Set<NamedExpression> distinctArguments = aggregateFunctions.stream()
.filter(AggregateFunction::isDistinct)
.flatMap(aggregateExpression -> aggregateExpression.getArguments().stream())
.map(NamedExpression.class::cast)
.collect(ImmutableSet.toImmutableSet());
Set<NamedExpression> localAggGroupBySet = ImmutableSet.<NamedExpression>builder()
.addAll((List) logicalAgg.getGroupByExpressions())
.addAll((List<NamedExpression>) (List) logicalAgg.getGroupByExpressions())
.addAll(distinctArguments)
.build();
AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER, true);
AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned);
Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream()
.filter(aggregateFunction -> !aggregateFunction.isDistinct())
.collect(ImmutableMap.toImmutableMap(expr -> expr, expr -> {
AggregateExpression localAggExpr = new AggregateExpression(expr, inputToBufferParam);
return new Alias(localAggExpr, localAggExpr.toSql());
}));
}, (oldValue, newValue) -> newValue));
List<NamedExpression> localAggOutput = ImmutableList.<NamedExpression>builder()
.addAll(localAggGroupBySet)
@ -1321,11 +1347,11 @@ public class AggregateStrategies implements ImplementationRuleFactory {
maybeUsingStreamAgg, Optional.empty(), logicalAgg.getLogicalProperties(),
requireAny, logicalAgg.child());
AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER);
AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER, couldBanned);
Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase2 =
nonDistinctAggFunctionToAliasPhase1.entrySet()
.stream()
.collect(ImmutableMap.toImmutableMap(kv -> kv.getKey(), kv -> {
.collect(ImmutableMap.toImmutableMap(Entry::getKey, kv -> {
AggregateFunction originFunction = kv.getKey();
Alias localOutput = kv.getValue();
AggregateExpression globalAggExpr = new AggregateExpression(
@ -1350,7 +1376,8 @@ public class AggregateStrategies implements ImplementationRuleFactory {
requireDistinctHash, anyLocalAgg);
// phase 3
AggregateParam distinctLocalParam = new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER);
AggregateParam distinctLocalParam = new AggregateParam(
AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned);
Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase3 = new HashMap<>();
List<NamedExpression> localDistinctOutput = Lists.newArrayList();
for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) {
@ -1361,9 +1388,12 @@ public class AggregateStrategies implements ImplementationRuleFactory {
if (expr instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) expr;
if (aggregateFunction.isDistinct()) {
Preconditions.checkArgument(aggregateFunction.arity() == 1);
Set<Expression> aggChild = Sets.newHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1,
"cannot process more than one child in aggregate distinct function: "
+ aggregateFunction);
AggregateFunction nonDistinct = aggregateFunction
.withDistinctAndChildren(false, aggregateFunction.getArguments());
.withDistinctAndChildren(false, ImmutableList.copyOf(aggChild));
AggregateExpression nonDistinctAggExpr = new AggregateExpression(nonDistinct,
distinctLocalParam, aggregateFunction.child(0));
return nonDistinctAggExpr;
@ -1389,7 +1419,8 @@ public class AggregateStrategies implements ImplementationRuleFactory {
requireDistinctHash, anyLocalHashGlobalAgg);
//phase 4
AggregateParam distinctGlobalParam = new AggregateParam(AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT);
AggregateParam distinctGlobalParam = new AggregateParam(
AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT, couldBanned);
List<NamedExpression> globalDistinctOutput = Lists.newArrayList();
for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) {
NamedExpression outputExpr = logicalAgg.getOutputExpressions().get(i);
@ -1397,9 +1428,12 @@ public class AggregateStrategies implements ImplementationRuleFactory {
if (expr instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) expr;
if (aggregateFunction.isDistinct()) {
Preconditions.checkArgument(aggregateFunction.arity() == 1);
Set<Expression> aggChild = Sets.newHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1,
"cannot process more than one child in aggregate distinct function: "
+ aggregateFunction);
AggregateFunction nonDistinct = aggregateFunction
.withDistinctAndChildren(false, aggregateFunction.getArguments());
.withDistinctAndChildren(false, ImmutableList.copyOf(aggChild));
int idx = logicalAgg.getOutputExpressions().indexOf(outputExpr);
Alias localDistinctAlias = (Alias) (localDistinctOutput.get(idx));
return new AggregateExpression(nonDistinct,
@ -1424,4 +1458,11 @@ public class AggregateStrategies implements ImplementationRuleFactory {
.add(distinctGlobal)
.build();
}
/**
 * Checks the output expressions of {@code aggregate} for distinct aggregate functions that either take
 * more than one argument or are not one of {@code Count}, {@code Sum} or {@code GroupConcat}.
 *
 * <p>NOTE(review): presumably these three are the functions that have a multi-distinct counterpart;
 * confirm against the multi-distinct rewrite rules.
 *
 * @param aggregate the logical aggregate whose output expressions are inspected
 * @return {@code true} when no such function is found, {@code false} otherwise
 */
private boolean couldConvertToMulti(LogicalAggregate<? extends Plan> aggregate) {
    return ExpressionUtils.noneMatch(aggregate.getOutputExpressions(), node -> {
        if (!(node instanceof AggregateFunction)) {
            return false;
        }
        AggregateFunction function = (AggregateFunction) node;
        if (!function.isDistinct()) {
            return false;
        }
        if (function.arity() > 1) {
            return true;
        }
        boolean supported = node instanceof Count || node instanceof Sum || node instanceof GroupConcat;
        return !supported;
    });
}
}

View File

@ -25,43 +25,35 @@ import java.util.Objects;
/** AggregateParam. */
public class AggregateParam {
// Shared parameter presets for the LOCAL phase. They are final so that no caller can rebind
// these globally visible constants.
public static final AggregateParam LOCAL_RESULT = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_RESULT);
public static final AggregateParam LOCAL_BUFFER = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER);
public final AggPhase aggPhase;
public final AggMode aggMode;
// TODO remove this flag, and generate it in enforce and cost job
public boolean needColocateScan;
// TODO: this is a short-term plan to process count(distinct a, b) correctly
public final boolean canBeBanned;
/** AggregateParam */
public AggregateParam(AggPhase aggPhase, AggMode aggMode) {
this(aggPhase, aggMode, false);
this(aggPhase, aggMode, true);
}
/** AggregateParam */
public AggregateParam(AggPhase aggPhase, AggMode aggMode, boolean needColocateScan) {
public AggregateParam(AggPhase aggPhase, AggMode aggMode, boolean canBeBanned) {
this.aggMode = Objects.requireNonNull(aggMode, "aggMode cannot be null");
this.aggPhase = Objects.requireNonNull(aggPhase, "aggPhase cannot be null");
this.needColocateScan = needColocateScan;
}
public static AggregateParam localResult() {
return new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_RESULT, true);
this.canBeBanned = canBeBanned;
}
public AggregateParam withAggPhase(AggPhase aggPhase) {
return new AggregateParam(aggPhase, aggMode, needColocateScan);
return new AggregateParam(aggPhase, aggMode, canBeBanned);
}
public AggregateParam withAggPhase(AggMode aggMode) {
return new AggregateParam(aggPhase, aggMode, needColocateScan);
return new AggregateParam(aggPhase, aggMode, canBeBanned);
}
public AggregateParam withAppPhaseAndAppMode(AggPhase aggPhase, AggMode aggMode) {
return new AggregateParam(aggPhase, aggMode, needColocateScan);
}
public AggregateParam withNeedColocateScan(boolean needColocateScan) {
return new AggregateParam(aggPhase, aggMode, needColocateScan);
return new AggregateParam(aggPhase, aggMode, canBeBanned);
}
@Override
@ -87,7 +79,6 @@ public class AggregateParam {
return "AggregateParam{"
+ "aggPhase=" + aggPhase
+ ", aggMode=" + aggMode
+ ", needColocateScan=" + needColocateScan
+ '}';
}
}

View File

@ -35,7 +35,7 @@ import java.util.List;
/** MultiDistinctCount */
public class MultiDistinctCount extends AggregateFunction
implements AlwaysNotNullable, ExplicitlyCastableSignature {
implements AlwaysNotNullable, ExplicitlyCastableSignature, MultiDistinction {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(AnyDataType.INSTANCE)

View File

@ -36,7 +36,7 @@ import java.util.List;
/** MultiDistinctGroupConcat */
public class MultiDistinctGroupConcat extends NullableAggregateFunction
implements ExplicitlyCastableSignature {
implements ExplicitlyCastableSignature, MultiDistinction {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarcharType.SYSTEM_DEFAULT),

View File

@ -35,8 +35,8 @@ import com.google.common.collect.ImmutableList;
import java.util.List;
/** MultiDistinctSum */
public class MultiDistinctSum extends AggregateFunction
implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature, ComputePrecisionForSum {
public class MultiDistinctSum extends AggregateFunction implements UnaryExpression, AlwaysNotNullable,
ExplicitlyCastableSignature, ComputePrecisionForSum, MultiDistinction {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(BigIntType.INSTANCE),

View File

@ -0,0 +1,27 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.nereids.trees.TreeNode;
import org.apache.doris.nereids.trees.expressions.Expression;
/**
 * Marker interface for multi-distinct aggregate functions.
 *
 * <p>It declares no members of its own. {@code MultiDistinctCount}, {@code MultiDistinctGroupConcat}
 * and {@code MultiDistinctSum} implement it, presumably so that such functions can be recognised
 * by type.
 */
public interface MultiDistinction extends TreeNode<Expression> {
}

View File

@ -25,7 +25,7 @@ import org.apache.doris.nereids.trees.plans.UnaryPlan;
import org.apache.doris.nereids.trees.plans.logical.OutputPrunable;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Set;
@ -53,10 +53,10 @@ public interface Aggregate<CHILD_TYPE extends Plan> extends UnaryPlan<CHILD_TYPE
return ExpressionUtils.collect(getOutputExpressions(), AggregateFunction.class::isInstance);
}
default List<Expression> getDistinctArguments() {
default Set<Expression> getDistinctArguments() {
return getAggregateFunctions().stream()
.filter(AggregateFunction::isDistinct)
.flatMap(aggregateExpression -> aggregateExpression.getArguments().stream())
.collect(ImmutableList.toImmutableList());
.collect(ImmutableSet.toImmutableSet());
}
}

View File

@ -21,7 +21,7 @@ import org.apache.doris.catalog.Table;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
@ -30,7 +30,6 @@ import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.statistics.Statistics;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.List;
@ -78,11 +77,6 @@ public class PhysicalStorageLayerAggregate extends PhysicalRelation {
return visitor.visitPhysicalStorageLayerAggregate(this, context);
}
/**
 * Returns the expressions of this node, which is always an empty immutable list.
 */
@Override
public List<? extends Expression> getExpressions() {
return ImmutableList.of();
}
@Override
public boolean equals(Object o) {
if (this == o) {
@ -113,7 +107,7 @@ public class PhysicalStorageLayerAggregate extends PhysicalRelation {
}
public PhysicalStorageLayerAggregate withPhysicalOlapScan(PhysicalOlapScan physicalOlapScan) {
return new PhysicalStorageLayerAggregate(relation, aggOp);
return new PhysicalStorageLayerAggregate(physicalOlapScan, aggOp);
}
@Override
@ -142,8 +136,8 @@ public class PhysicalStorageLayerAggregate extends PhysicalRelation {
public enum PushDownAggOp {
COUNT, MIN_MAX, MIX;
public static Map<Class, PushDownAggOp> supportedFunctions() {
return ImmutableMap.<Class, PushDownAggOp>builder()
public static Map<Class<? extends AggregateFunction>, PushDownAggOp> supportedFunctions() {
return ImmutableMap.<Class<? extends AggregateFunction>, PushDownAggOp>builder()
.put(Count.class, PushDownAggOp.COUNT)
.put(Min.class, PushDownAggOp.MIN_MAX)
.put(Max.class, PushDownAggOp.MIN_MAX)

View File

@ -403,6 +403,11 @@ public class ExpressionUtils {
.anyMatch(expr -> expr.anyMatch(predicate));
}
/**
 * Returns {@code true} when none of the given expressions reports a match for {@code predicate}.
 *
 * @param expressions expressions to inspect; each one is queried through its own {@code anyMatch}
 * @param predicate test handed to {@code anyMatch} of every expression
 * @return {@code false} as soon as one expression reports a match, {@code true} otherwise
 *         (including for an empty list)
 */
public static boolean noneMatch(List<? extends Expression> expressions, Predicate<TreeNode<Expression>> predicate) {
    for (Expression expression : expressions) {
        if (expression.anyMatch(predicate)) {
            return false;
        }
    }
    return true;
}
/**
 * Tells whether any of the given expressions reports, through {@code anyMatch}, a node that is an
 * instance of {@code type}.
 *
 * @param expressions expressions to inspect
 * @param type class tested with {@link Class#isInstance(Object)}; declared as {@code Class<?>}
 *        instead of the raw {@code Class} so the method compiles without raw-type warnings
 * @return the result of {@code anyMatch(expressions, type::isInstance)}
 */
public static boolean containsType(List<? extends Expression> expressions, Class<?> type) {
    return anyMatch(expressions, type::isInstance);
}

View File

@ -223,6 +223,8 @@ public class AggregateStrategiesTest implements MemoPatternMatchSupported {
* </pre>
*/
@Test
@Disabled
@Developing("reopen it after we could choose agg phase by CBO")
public void distinctAggregateWithoutGroupByApply2PhaseRule() {
List<Expression> groupExpressionList = new ArrayList<>();
List<NamedExpression> outputExpressionList = Lists.newArrayList(new Alias(

View File

@ -4,8 +4,8 @@ PhysicalTopN
--PhysicalTopN
----PhysicalProject
------hashAgg[GLOBAL]
--------hashAgg[LOCAL]
----------PhysicalDistribute
--------PhysicalDistribute
----------hashAgg[LOCAL]
------------PhysicalProject
--------------hashJoin[INNER_JOIN](ws1.ws_ship_date_sk = date_dim.d_date_sk)
----------------PhysicalProject

View File

@ -14,8 +14,8 @@ CteAnchor[cteId= ( CTEId#3=] )
----PhysicalTopN
------PhysicalProject
--------hashAgg[GLOBAL]
----------hashAgg[LOCAL]
------------PhysicalDistribute
----------PhysicalDistribute
------------hashAgg[LOCAL]
--------------PhysicalProject
----------------hashJoin[INNER_JOIN](ws1.ws_ship_date_sk = date_dim.d_date_sk)
------------------PhysicalProject

View File

@ -4,8 +4,8 @@ PhysicalQuickSort
--PhysicalDistribute
----PhysicalQuickSort
------hashAgg[GLOBAL]
--------hashAgg[LOCAL]
----------PhysicalDistribute
--------PhysicalDistribute
----------hashAgg[LOCAL]
------------PhysicalProject
--------------hashJoin[LEFT_ANTI_JOIN](partsupp.ps_suppkey = supplier.s_suppkey)
----------------hashJoin[INNER_JOIN](part.p_partkey = partsupp.ps_partkey)

View File

@ -4,8 +4,8 @@ PhysicalQuickSort
--PhysicalDistribute
----PhysicalQuickSort
------hashAgg[GLOBAL]
--------hashAgg[LOCAL]
----------PhysicalDistribute
--------PhysicalDistribute
----------hashAgg[LOCAL]
------------PhysicalProject
--------------hashJoin[LEFT_ANTI_JOIN](partsupp.ps_suppkey = supplier.s_suppkey)
----------------hashJoin[INNER_JOIN](part.p_partkey = partsupp.ps_partkey)

View File

@ -24,15 +24,6 @@ suite("test_join", "nereids_p0") {
def tbName1 = "test"
def tbName2 = "baseall"
def tbName3 = "bigtable"
def empty_name = "empty"
qt_agg_sql1 """select /*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/ count(distinct k1, NULL) from test;"""
qt_agg_sql2 """select /*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/ count(distinct k1, NULL), avg(k2) from baseall;"""
qt_agg_sql3 """select /*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/ k1,count(distinct k2,k3),min(k4),count(*) from baseall group by k1 order by k1;"""
qt_agg_sql4 """select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/ count(distinct k1, NULL) from test;"""
qt_agg_sql5 """select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/ count(distinct k1, NULL), avg(k2) from baseall;"""
qt_agg_sql6 """select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/ k1,count(distinct k2,k3),min(k4),count(*) from baseall group by k1 order by k1;"""
order_sql """select j.*, d.* from ${tbName2} j full outer join ${tbName1} d on (j.k1=d.k1) order by j.k1, j.k2, j.k3, j.k4, d.k1, d.k2
limit 100"""

View File

@ -43,16 +43,16 @@ suite("agg_4_phase") {
(0, 0, "aa", 10), (1, 1, "bb",20), (2, 2, "cc", 30), (1, 1, "bb",20);
"""
def test_sql = """
select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT,TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/
count(distinct name), sum(age)
select
count(distinct id)
from agg_4_phase_tbl;
"""
explain{
sql(test_sql)
contains "6:VAGGREGATE (merge finalize)"
contains "5:VEXCHANGE"
contains "4:VAGGREGATE (update serialize)"
contains "3:VAGGREGATE (merge serialize)"
contains "5:VAGGREGATE (merge finalize)"
contains "4:VEXCHANGE"
contains "3:VAGGREGATE (update serialize)"
contains "2:VAGGREGATE (merge serialize)"
contains "1:VAGGREGATE (update serialize)"
}
qt_4phase (test_sql)

View File

@ -82,7 +82,6 @@ suite("aggregate_strategies") {
explain {
sql """
select
/*+SET_VAR(disable_nereids_rules='ONE_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI,TWO_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI,THREE_PHASE_AGGREGATE_WITH_DISTINCT, FOUR_PHASE_AGGREGATE_WITH_DISTINCT')*/
count(distinct id)
from $tableName
"""
@ -90,17 +89,17 @@ suite("aggregate_strategies") {
notContains "STREAMING"
}
// test multi_distinct
test {
sql """select
/*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/
count(distinct id)
count(distinct name)
from $tableName"""
result([[5L]])
}
// test four phase distinct
test {
sql """select
/*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT')*/
count(distinct id)
from $tableName"""
result([[5L]])

View File

@ -21,14 +21,14 @@ suite("group_concat") {
test {
sql """select /*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITHOUT_DISTINCT')*/
sql """select
group_concat(cast(number as string), ',' order by number)
from numbers('number'='10')"""
result([["0,1,2,3,4,5,6,7,8,9"]])
}
test {
sql """select /*+SET_VAR(disable_nereids_rules='ONE_PHASE_AGGREGATE_WITHOUT_DISTINCT')*/
sql """select
group_concat(cast(number as string), ',' order by number)
from numbers('number'='10')"""
result([["0,1,2,3,4,5,6,7,8,9"]])