diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index 650bcc43a9..9844179633 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -1742,7 +1742,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor { Expression function = e.child(0).child(0); if (function instanceof AggregateFunction) { - AggregateParam param = AggregateParam.localResult(); + AggregateParam param = AggregateParam.LOCAL_RESULT; function = new AggregateExpression((AggregateFunction) function, param); } return ExpressionTranslator.translate(function, context); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java index 811c26569a..54e4c4780a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java @@ -23,7 +23,12 @@ import org.apache.doris.nereids.cost.CostCalculator; import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType; +import org.apache.doris.nereids.trees.expressions.AggregateExpression; +import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinction; import org.apache.doris.nereids.trees.plans.AggMode; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; @@ -43,6 +48,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; /** * ensure child add enough distribute. update children properties if we do regular @@ -88,12 +94,55 @@ public class ChildrenPropertiesRegulator extends PlanVisitor { @Override public Boolean visitPhysicalHashAggregate(PhysicalHashAggregate agg, Void context) { - // forbid one phase agg on distribute - if (agg.getAggMode() == AggMode.INPUT_TO_RESULT + if (!agg.getAggregateParam().canBeBanned) { + return true; + } + // forbid one phase agg on distribute and three or four stage distinct agg inter by distribute + if ((agg.getAggMode() == AggMode.INPUT_TO_RESULT || agg.getAggMode() == AggMode.BUFFER_TO_BUFFER) && children.get(0).getPlan() instanceof PhysicalDistribute) { // this means one stage gather agg, usually bad pattern return false; } + // forbid TWO_PHASE_AGGREGATE_WITH_DISTINCT after shuffle + // TODO: this is forbid good plan after cte reuse by mistake + if (agg.getAggMode() == AggMode.INPUT_TO_BUFFER + && requiredProperties.get(0).getDistributionSpec() instanceof DistributionSpecHash + && children.get(0).getPlan() instanceof PhysicalDistribute) { + return false; + } + // forbid multi distinct opt that bad than multi-stage version when multi-stage can be executed in one fragment + if (agg.getAggMode() == AggMode.INPUT_TO_BUFFER || agg.getAggMode() == AggMode.INPUT_TO_RESULT) { + List multiDistinctions = agg.getOutputExpressions().stream() + .filter(Alias.class::isInstance) + .map(a -> ((Alias) a).child()) + .filter(AggregateExpression.class::isInstance) + .map(a -> ((AggregateExpression) a).getFunction()) + .filter(MultiDistinction.class::isInstance) + .map(MultiDistinction.class::cast) + .collect(Collectors.toList()); + if (multiDistinctions.size() == 1) { + Expression distinctChild = multiDistinctions.get(0).child(0); + DistributionSpec childDistribution = childrenProperties.get(0).getDistributionSpec(); + if (distinctChild instanceof SlotReference && childDistribution instanceof DistributionSpecHash) { + SlotReference slotReference = (SlotReference) distinctChild; + DistributionSpecHash distributionSpecHash = (DistributionSpecHash) childDistribution; + List groupByColumns = agg.getGroupByExpressions().stream() + .map(SlotReference.class::cast) + .map(SlotReference::getExprId) + .collect(Collectors.toList()); + DistributionSpecHash groupByRequire = new DistributionSpecHash( + groupByColumns, ShuffleType.REQUIRE); + List distinctChildColumns = Lists.newArrayList(slotReference.getExprId()); + distinctChildColumns.add(slotReference.getExprId()); + DistributionSpecHash distinctChildRequire = new DistributionSpecHash( + distinctChildColumns, ShuffleType.REQUIRE); + if ((!groupByColumns.isEmpty() && distributionSpecHash.satisfy(groupByRequire)) + || (groupByColumns.isEmpty() && distributionSpecHash.satisfy(distinctChildRequire))) { + return false; + } + } + } + } // process must shuffle visit(agg, context); // process agg diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java index 551e73532c..6df6b2f817 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java @@ -70,11 +70,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; +import com.google.common.collect.Sets; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -123,41 +125,43 @@ public class AggregateStrategies implements ImplementationRuleFactory { .when(agg -> agg.getDistinctArguments().size() == 0) .thenApplyMulti(ctx -> twoPhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext)) ), - RuleType.TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build( - basePattern - .when(this::containsCountDistinctMultiExpr) - .thenApplyMulti(ctx -> twoPhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext)) - ), + // RuleType.TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build( + // basePattern + // .when(this::containsCountDistinctMultiExpr) + // .thenApplyMulti(ctx -> twoPhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext)) + // ), RuleType.THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build( basePattern .when(this::containsCountDistinctMultiExpr) .thenApplyMulti(ctx -> threePhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext)) ), - RuleType.TWO_PHASE_AGGREGATE_WITH_DISTINCT.build( - basePattern - .when(agg -> agg.getDistinctArguments().size() == 1) - .thenApplyMulti(ctx -> twoPhaseAggregateWithDistinct(ctx.root, ctx.connectContext)) - ), RuleType.ONE_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI.build( basePattern - .when(agg -> agg.getDistinctArguments().size() == 1 && enableSingleDistinctColumnOpt()) + .when(agg -> agg.getDistinctArguments().size() == 1 && couldConvertToMulti(agg)) .thenApplyMulti(ctx -> onePhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext)) ), RuleType.TWO_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI.build( basePattern - .when(agg -> agg.getDistinctArguments().size() == 1 && enableSingleDistinctColumnOpt()) + .when(agg -> agg.getDistinctArguments().size() == 1 && couldConvertToMulti(agg)) .thenApplyMulti(ctx -> twoPhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext)) ), - RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build( - basePattern - .when(agg -> agg.getDistinctArguments().size() == 1) - .thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext)) - ), RuleType.TWO_PHASE_AGGREGATE_WITH_MULTI_DISTINCT.build( basePattern - .when(agg -> agg.getDistinctArguments().size() > 1 && !containsCountDistinctMultiExpr(agg)) + .when(agg -> agg.getDistinctArguments().size() > 1 + && !containsCountDistinctMultiExpr(agg) + && couldConvertToMulti(agg)) .thenApplyMulti(ctx -> twoPhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext)) ), + // RuleType.TWO_PHASE_AGGREGATE_WITH_DISTINCT.build( + // basePattern + // .when(agg -> agg.getDistinctArguments().size() == 1) + // .thenApplyMulti(ctx -> twoPhaseAggregateWithDistinct(ctx.root, ctx.connectContext)) + // ), + RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build( + basePattern + .when(agg -> agg.getDistinctArguments().size() == 1) + .thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext)) + ), RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT.build( basePattern .when(agg -> agg.getDistinctArguments().size() == 1) @@ -169,15 +173,15 @@ public class AggregateStrategies implements ImplementationRuleFactory { /** * sql: select count(*) from tbl - * + *

* before: - * + *

* LogicalAggregate(groupBy=[], output=[count(*)]) * | * LogicalOlapScan(table=tbl) - * + *

* after: - * + *

* LogicalAggregate(groupBy=[], output=[count(*)]) * | * PhysicalStorageLayerAggregate(pushAggOp=COUNT, table=PhysicalOlapScan(table=tbl)) @@ -205,7 +209,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { .map(AggregateFunction::getClass) .collect(Collectors.toSet()); - Map supportedAgg = PushDownAggOp.supportedFunctions(); + Map, PushDownAggOp> supportedAgg = PushDownAggOp.supportedFunctions(); if (!supportedAgg.keySet().containsAll(functionClasses)) { return canNotPush; } @@ -292,7 +296,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { || colType == PrimitiveType.STRING) { return canNotPush; } - if (colType.isCharFamily() && mergeOp != PushDownAggOp.COUNT && column.getType().getLength() > 512) { + if (colType.isCharFamily() && column.getType().getLength() > 512) { return canNotPush; } } @@ -324,25 +328,25 @@ public class AggregateStrategies implements ImplementationRuleFactory { /** * sql: select count(*) from tbl group by id - * + *

* before: - * + *

* LogicalAggregate(groupBy=[id], output=[count(*)]) * | * LogicalOlapScan(table=tbl) - * + *

* after: - * + *

* single node aggregate: - * + *

* PhysicalHashAggregate(groupBy=[id], output=[count(*)]) * | * PhysicalDistribute(distributionSpec=GATHER) * | * LogicalOlapScan(table=tbl) - * + *

* distribute node aggregate: - * + *

* PhysicalHashAggregate(groupBy=[id], output=[count(*)]) * | * LogicalOlapScan(table=tbl, **already distribute by id**) @@ -351,7 +355,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { private List> onePhaseAggregateWithoutDistinct( LogicalAggregate logicalAgg, ConnectContext connectContext) { RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER); - AggregateParam inputToResultParam = AggregateParam.localResult(); + AggregateParam inputToResultParam = AggregateParam.LOCAL_RESULT; List newOutput = ExpressionUtils.rewriteDownShortCircuit( logicalAgg.getOutputExpressions(), outputChild -> { if (outputChild instanceof AggregateFunction) { @@ -366,7 +370,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { requireGather, logicalAgg.child()); if (logicalAgg.getGroupByExpressions().isEmpty()) { - return ImmutableList.of(gatherLocalAgg); + // TODO: usually bad, disable it until we could do better cost computation. + // return ImmutableList.of(gatherLocalAgg); + return ImmutableList.of(); } else { RequireProperties requireHash = RequireProperties.of( PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE)); @@ -383,17 +389,17 @@ public class AggregateStrategies implements ImplementationRuleFactory { /** * sql: select count(distinct id, name) from tbl group by name - * + *

* before: - * + *

* LogicalAggregate(groupBy=[name], output=[count(distinct id, name)]) * | * LogicalOlapScan(table=tbl) - * + *

* after: - * + *

* single node aggregate: - * + *

* PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, null, name))]) * | * PhysicalHashAggregate(groupBy=[name, id], output=[name, id]) @@ -401,9 +407,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { * PhysicalDistribute(distributionSpec=GATHER) * | * LogicalOlapScan(table=tbl) - * + *

* distribute node aggregate: - * + *

* PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, null, name))]) * | * PhysicalHashAggregate(groupBy=[name, id], output=[name, id]) @@ -415,7 +421,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { */ private List> twoPhaseAggregateWithCountDistinctMulti( LogicalAggregate logicalAgg, CascadesContext cascadesContext) { - AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER); + AggregateParam inputToBufferParam = AggregateParam.LOCAL_BUFFER; Collection countDistinctArguments = logicalAgg.getDistinctArguments(); List localAggGroupBy = ImmutableList.copyOf(ImmutableSet.builder() @@ -487,7 +493,8 @@ public class AggregateStrategies implements ImplementationRuleFactory { .withRequireTree(requireHash.withChildren(requireHash)) .withPartitionExpressions(logicalAgg.getGroupByExpressions()); return ImmutableList.>builder() - .add(gatherLocalGatherDistinctAgg) + // TODO: usually bad, disable it until we could do better cost computation. + //.add(gatherLocalGatherDistinctAgg) .add(hashLocalHashGlobalAgg) .build(); } @@ -495,17 +502,17 @@ public class AggregateStrategies implements ImplementationRuleFactory { /** * sql: select count(distinct id, name) from tbl group by name - * + *

* before: - * + *

* LogicalAggregate(groupBy=[name], output=[count(distinct id, name)]) * | * LogicalOlapScan(table=tbl) - * + *

* after: - * + *

* single node aggregate: - * + *

* PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, null, name))]) * | * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER) @@ -515,9 +522,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER) * | * LogicalOlapScan(table=tbl) - * + *

* distribute node aggregate: - * + *

* PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, null, name))]) * | * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER) @@ -566,11 +573,17 @@ public class AggregateStrategies implements ImplementationRuleFactory { List globalAggGroupBy = localAggGroupBy; - AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER); + boolean hasCountDistinctMulti = logicalAgg.getAggregateFunctions().stream() + .filter(AggregateFunction::isDistinct) + .filter(Count.class::isInstance) + .anyMatch(c -> c.arity() > 1); + AggregateParam bufferToBufferParam = new AggregateParam( + AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER, !hasCountDistinctMulti); + Map nonDistinctAggFunctionToAliasPhase2 = nonDistinctAggFunctionToAliasPhase1.entrySet() .stream() - .collect(ImmutableMap.toImmutableMap(kv -> kv.getKey(), kv -> { + .collect(ImmutableMap.toImmutableMap(Entry::getKey, kv -> { AggregateFunction originFunction = kv.getKey(); Alias localOutputAlias = kv.getValue(); AggregateExpression globalAggExpr = new AggregateExpression( @@ -596,7 +609,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { logicalAgg, cascadesContext).first; AggregateParam distinctInputToResultParam - = new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT); + = new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT, !hasCountDistinctMulti); AggregateParam globalBufferToResultParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT); List distinctOutput = ExpressionUtils.rewriteDownShortCircuit( @@ -621,19 +634,19 @@ public class AggregateStrategies implements ImplementationRuleFactory { logicalAgg.getLogicalProperties(), requireGather, anyLocalGatherGlobalAgg ); - RequireProperties requireDistinctHash = RequireProperties.of( - PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE)); - PhysicalHashAggregate anyLocalHashGlobalGatherDistinctAgg - = anyLocalGatherGlobalGatherAgg.withChildren(ImmutableList.of( - anyLocalGatherGlobalAgg - .withRequire(requireDistinctHash) - .withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments())) - )); + // RequireProperties requireDistinctHash = RequireProperties.of( + // PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE)); + // PhysicalHashAggregate anyLocalHashGlobalGatherDistinctAgg + // = anyLocalGatherGlobalGatherAgg.withChildren(ImmutableList.of( + // anyLocalGatherGlobalAgg + // .withRequire(requireDistinctHash) + // .withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments())) + // )); if (logicalAgg.getGroupByExpressions().isEmpty()) { return ImmutableList.>builder() .add(anyLocalGatherGlobalGatherAgg) - .add(anyLocalHashGlobalGatherDistinctAgg) + //.add(anyLocalHashGlobalGatherDistinctAgg) .build(); } else { RequireProperties requireGroupByHash = RequireProperties.of( @@ -646,8 +659,8 @@ public class AggregateStrategies implements ImplementationRuleFactory { ) .withPartitionExpressions(logicalAgg.getGroupByExpressions()); return ImmutableList.>builder() - .add(anyLocalGatherGlobalGatherAgg) - .add(anyLocalHashGlobalGatherDistinctAgg) + // .add(anyLocalGatherGlobalGatherAgg) + // .add(anyLocalHashGlobalGatherDistinctAgg) .add(anyLocalHashGlobalHashDistinctAgg) .build(); } @@ -655,17 +668,17 @@ public class AggregateStrategies implements ImplementationRuleFactory { /** * sql: select name, count(value) from tbl group by name - * + *

* before: - * + *

* LogicalAggregate(groupBy=[name], output=[name, count(value)]) * | * LogicalOlapScan(table=tbl) - * + *

* after: - * + *

* single node aggregate: - * + *

* PhysicalHashAggregate(groupBy=[name], output=[name, count(value)], mode=BUFFER_TO_RESULT) * | * PhysicalDistribute(distributionSpec=GATHER) @@ -673,9 +686,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { * PhysicalHashAggregate(groupBy=[name], output=[name, count(value)], mode=INPUT_TO_BUFFER) * | * LogicalOlapScan(table=tbl) - * + *

* distribute node aggregate: - * + *

* PhysicalHashAggregate(groupBy=[name], output=[name, count(value)], mode=BUFFER_TO_RESULT) * | * PhysicalDistribute(distributionSpec=HASH(name)) @@ -713,6 +726,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { AggregateParam bufferToResultParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT); List globalAggOutput = ExpressionUtils.rewriteDownShortCircuit( logicalAgg.getOutputExpressions(), outputChild -> { + if (!(outputChild instanceof AggregateFunction)) { + return outputChild; + } Alias inputToBufferAlias = inputToBufferAliases.get(outputChild); if (inputToBufferAlias == null) { return outputChild; @@ -722,7 +738,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { }); RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER); - PhysicalHashAggregate anyLocalGatherGlobalAgg = new PhysicalHashAggregate( + PhysicalHashAggregate anyLocalGatherGlobalAgg = new PhysicalHashAggregate<>( localAggGroupBy, globalAggOutput, Optional.of(partitionExpressions), bufferToResultParam, false, anyLocalAgg.getLogicalProperties(), requireGather, anyLocalAgg); @@ -746,17 +762,17 @@ public class AggregateStrategies implements ImplementationRuleFactory { /** * sql: select count(distinct id) from tbl group by name - * + *

* before: - * + *

* LogicalAggregate(groupBy=[name], output=[name, count(distinct id)]) * | * LogicalOlapScan(table=tbl) - * + *

* after: - * + *

* single node aggregate: - * + *

* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT) * | * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER) @@ -764,9 +780,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { * PhysicalDistribute(distributionSpec=GATHER) * | * LogicalOlapScan(table=tbl) - * + *

* distribute node aggregate: - * + *

* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT) * | * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER) @@ -781,7 +797,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { Set aggregateFunctions = logicalAgg.getAggregateFunctions(); Set distinctArguments = aggregateFunctions.stream() - .filter(aggregateExpression -> aggregateExpression.isDistinct()) + .filter(AggregateFunction::isDistinct) .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream()) .collect(ImmutableSet.toImmutableSet()); @@ -790,7 +806,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { .addAll(distinctArguments) .build(); - AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER, true); + AggregateParam inputToBufferParam = AggregateParam.LOCAL_BUFFER; Map nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream() .filter(aggregateFunction -> !aggregateFunction.isDistinct()) @@ -822,10 +838,13 @@ public class AggregateStrategies implements ImplementationRuleFactory { if (outputChild instanceof AggregateFunction) { AggregateFunction aggregateFunction = (AggregateFunction) outputChild; if (aggregateFunction.isDistinct()) { - Preconditions.checkArgument(aggregateFunction.arity() == 1); + Set aggChild = Sets.newHashSet(aggregateFunction.children()); + Preconditions.checkArgument(aggChild.size() == 1, + "cannot process more than one child in aggregate distinct function: " + + aggregateFunction); AggregateFunction nonDistinct = aggregateFunction - .withDistinctAndChildren(false, aggregateFunction.getArguments()); - return new AggregateExpression(nonDistinct, AggregateParam.localResult()); + .withDistinctAndChildren(false, ImmutableList.copyOf(aggChild)); + return new AggregateExpression(nonDistinct, AggregateParam.LOCAL_RESULT); } else { Alias alias = nonDistinctAggFunctionToAliasPhase1.get(outputChild); return new AggregateExpression( @@ -850,7 +869,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { .withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments())) )); return ImmutableList.>builder() - .add(gatherLocalGatherGlobalAgg) + //.add(gatherLocalGatherGlobalAgg) .add(hashLocalGatherGlobalAgg) .build(); } else { @@ -863,7 +882,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { ) .withPartitionExpressions(logicalAgg.getGroupByExpressions()); return ImmutableList.>builder() - .add(gatherLocalGatherGlobalAgg) + // .add(gatherLocalGatherGlobalAgg) .add(hashLocalHashGlobalAgg) .build(); } @@ -871,16 +890,16 @@ public class AggregateStrategies implements ImplementationRuleFactory { /** * sql: select count(distinct id) from tbl group by name - * + *

* before: - * + *

* LogicalAggregate(groupBy=[name], output=[name, count(distinct id)]) * | * LogicalOlapScan(table=tbl) - * + *

* after: * single node aggregate: - * + *

* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT) * | * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER) @@ -890,9 +909,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER) * | * LogicalOlapScan(table=tbl) - * + *

* distribute node aggregate: - * + *

* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT) * | * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER) @@ -907,10 +926,12 @@ public class AggregateStrategies implements ImplementationRuleFactory { // TODO: support one phase aggregate(group by columns + distinct columns) + two phase distinct aggregate private List> threePhaseAggregateWithDistinct( LogicalAggregate logicalAgg, ConnectContext connectContext) { + boolean couldBanned = couldConvertToMulti(logicalAgg); + Set aggregateFunctions = logicalAgg.getAggregateFunctions(); Set distinctArguments = aggregateFunctions.stream() - .filter(aggregateExpression -> aggregateExpression.isDistinct()) + .filter(AggregateFunction::isDistinct) .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream()) .collect(ImmutableSet.toImmutableSet()); @@ -919,7 +940,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { .addAll(distinctArguments) .build(); - AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER); + AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned); Map nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream() .filter(aggregateFunction -> !aggregateFunction.isDistinct()) @@ -942,11 +963,11 @@ public class AggregateStrategies implements ImplementationRuleFactory { maybeUsingStreamAgg, Optional.empty(), logicalAgg.getLogicalProperties(), requireAny, logicalAgg.child()); - AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER); + AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER, couldBanned); Map nonDistinctAggFunctionToAliasPhase2 = nonDistinctAggFunctionToAliasPhase1.entrySet() .stream() - .collect(ImmutableMap.toImmutableMap(kv -> kv.getKey(), kv -> { + .collect(ImmutableMap.toImmutableMap(Entry::getKey, kv -> { AggregateFunction originFunction = kv.getKey(); Alias localOutput = kv.getValue(); AggregateExpression globalAggExpr = new AggregateExpression( @@ -965,15 +986,19 @@ public class AggregateStrategies implements ImplementationRuleFactory { bufferToBufferParam, false, logicalAgg.getLogicalProperties(), requireGather, anyLocalAgg); - AggregateParam bufferToResultParam = new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT); + AggregateParam bufferToResultParam = new AggregateParam( + AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT, couldBanned); List distinctOutput = ExpressionUtils.rewriteDownShortCircuit( logicalAgg.getOutputExpressions(), expr -> { if (expr instanceof AggregateFunction) { AggregateFunction aggregateFunction = (AggregateFunction) expr; if (aggregateFunction.isDistinct()) { - Preconditions.checkArgument(aggregateFunction.arity() == 1); + Set aggChild = Sets.newHashSet(aggregateFunction.children()); + Preconditions.checkArgument(aggChild.size() == 1, + "cannot process more than one child in aggregate distinct function: " + + aggregateFunction); AggregateFunction nonDistinct = aggregateFunction - .withDistinctAndChildren(false, aggregateFunction.getArguments()); + .withDistinctAndChildren(false, ImmutableList.copyOf(aggChild)); return new AggregateExpression(nonDistinct, bufferToResultParam, aggregateFunction.child(0)); } else { @@ -1017,8 +1042,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { ) .withPartitionExpressions(logicalAgg.getGroupByExpressions()); return ImmutableList.>builder() - .add(anyLocalGatherGlobalGatherDistinctAgg) - .add(anyLocalHashGlobalGatherDistinctAgg) + // TODO: this plan pattern is not good usually, we remove it temporary. + //.add(anyLocalGatherGlobalGatherDistinctAgg) + //.add(anyLocalHashGlobalGatherDistinctAgg) .add(anyLocalHashGlobalHashDistinctAgg) .build(); } @@ -1026,25 +1052,25 @@ public class AggregateStrategies implements ImplementationRuleFactory { /** * sql: select count(distinct id) from (...) group by name - * + *

* before: - * + *

* LogicalAggregate(groupBy=[name], output=[count(distinct id)]) * | * any plan - * + *

* after: - * + *

* single node aggregate: - * + *

* PhysicalHashAggregate(groupBy=[name], output=[multi_distinct_count(id)]) * | * PhysicalDistribute(distributionSpec=GATHER) * | * any plan - * + *

* distribute node aggregate: - * + *

* PhysicalHashAggregate(groupBy=[name], output=[multi_distinct_count(id)]) * | * any plan(**already distribute by name**) @@ -1052,7 +1078,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { */ private List> onePhaseAggregateWithMultiDistinct( LogicalAggregate logicalAgg, ConnectContext connectContext) { - AggregateParam inputToResultParam = AggregateParam.localResult(); + AggregateParam inputToResultParam = AggregateParam.LOCAL_RESULT; List newOutput = ExpressionUtils.rewriteDownShortCircuit( logicalAgg.getOutputExpressions(), outputChild -> { if (outputChild instanceof AggregateFunction) { @@ -1068,7 +1094,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { maybeUsingStreamAgg(connectContext, logicalAgg), logicalAgg.getLogicalProperties(), requireGather, logicalAgg.child()); if (logicalAgg.getGroupByExpressions().isEmpty()) { - return ImmutableList.of(gatherLocalAgg); + // TODO: usually bad, disable it until we could do better cost computation. + // return ImmutableList.of(gatherLocalAgg); + return ImmutableList.of(); } else { RequireProperties requireHash = RequireProperties.of( PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE)); @@ -1085,17 +1113,17 @@ public class AggregateStrategies implements ImplementationRuleFactory { /** * sql: select count(distinct id) from tbl group by name - * + *

* before: - * + *

* LogicalAggregate(groupBy=[name], output=[name, count(distinct id)]) * | * LogicalOlapScan(table=tbl) - * + *

* after: - * + *

* single node aggregate: - * + *

* PhysicalHashAggregate(groupBy=[name], output=[name, multi_count_distinct(value)], mode=BUFFER_TO_RESULT) * | * PhysicalDistribute(distributionSpec=GATHER) @@ -1103,9 +1131,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { * PhysicalHashAggregate(groupBy=[name], output=[name, multi_count_distinct(value)], mode=INPUT_TO_BUFFER) * | * LogicalOlapScan(table=tbl) - * + *

* distribute node aggregate: - * + *

* PhysicalHashAggregate(groupBy=[name], output=[name, multi_count_distinct(value)], mode=BUFFER_TO_RESULT) * | * PhysicalDistribute(distributionSpec=HASH(name)) @@ -1157,17 +1185,16 @@ public class AggregateStrategies implements ImplementationRuleFactory { RequireProperties.of(PhysicalProperties.GATHER), anyLocalAgg); if (logicalAgg.getGroupByExpressions().isEmpty()) { - Collection distinctArguments = logicalAgg.getDistinctArguments(); - RequireProperties requireDistinctHash = RequireProperties.of(PhysicalProperties.createHash( - distinctArguments, ShuffleType.REQUIRE)); - PhysicalHashAggregate hashLocalGatherGlobalAgg = anyLocalGatherGlobalAgg - .withChildren(ImmutableList.of(anyLocalAgg - .withRequire(requireDistinctHash) - .withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments())) - )); + // Collection distinctArguments = logicalAgg.getDistinctArguments(); + // RequireProperties requireDistinctHash = RequireProperties.of(PhysicalProperties.createHash( + // distinctArguments, ShuffleType.REQUIRE)); + // PhysicalHashAggregate hashLocalGatherGlobalAgg = anyLocalGatherGlobalAgg + // .withChildren(ImmutableList.of(anyLocalAgg + // .withRequire(requireDistinctHash) + // .withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments())) + // )); return ImmutableList.>builder() .add(anyLocalGatherGlobalAgg) - .add(hashLocalGatherGlobalAgg) .build(); } else { RequireProperties requireHash = RequireProperties.of( @@ -1176,7 +1203,8 @@ public class AggregateStrategies implements ImplementationRuleFactory { .withRequire(requireHash) .withPartitionExpressions(logicalAgg.getGroupByExpressions()); return ImmutableList.>builder() - .add(anyLocalGatherGlobalAgg) + // TODO: usually bad, disable it until we could do better cost computation. + // .add(anyLocalGatherGlobalAgg) .add(anyLocalHashGlobalAgg) .build(); } @@ -1215,7 +1243,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { /** * countDistinctMultiExprToCountIf. - * + *

* NOTE: this function will break the normalized output, e.g. from `count(distinct slot1, slot2)` to * `count(if(slot1 is null, null, slot2))`. So if you invoke this method, and separate the * phase of aggregate, please normalize to slot and create a bottom project like NormalizeAggregate. @@ -1268,15 +1296,10 @@ public class AggregateStrategies implements ImplementationRuleFactory { return connectContext == null || connectContext.getSessionVariable().enablePushDownNoGroupAgg(); } - private boolean enableSingleDistinctColumnOpt() { - ConnectContext connectContext = ConnectContext.get(); - return connectContext == null || connectContext.getSessionVariable().enableSingleDistinctColumnOpt(); - } - /** * sql: * select count(distinct name), sum(age) from student; - * + *

* 4 phase plan * DISTINCT_GLOBAL, BUFFER_TO_RESULT groupBy(), output[count(name), sum(age#5)], [GATHER] * +--DISTINCT_LOCAL, INPUT_TO_BUFFER, groupBy()), output(count(name), partial_sum(age)), hash distribute by name @@ -1286,26 +1309,29 @@ public class AggregateStrategies implements ImplementationRuleFactory { */ private List> fourPhaseAggregateWithDistinct( LogicalAggregate logicalAgg, ConnectContext connectContext) { + boolean couldBanned = couldConvertToMulti(logicalAgg); + Set aggregateFunctions = logicalAgg.getAggregateFunctions(); - Set distinctArguments = aggregateFunctions.stream() - .filter(aggregateExpression -> aggregateExpression.isDistinct()) + Set distinctArguments = aggregateFunctions.stream() + .filter(AggregateFunction::isDistinct) .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream()) + .map(NamedExpression.class::cast) .collect(ImmutableSet.toImmutableSet()); Set localAggGroupBySet = ImmutableSet.builder() - .addAll((List) logicalAgg.getGroupByExpressions()) + .addAll((List) (List) logicalAgg.getGroupByExpressions()) .addAll(distinctArguments) .build(); - AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER, true); + AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned); Map nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream() .filter(aggregateFunction -> !aggregateFunction.isDistinct()) .collect(ImmutableMap.toImmutableMap(expr -> expr, expr -> { AggregateExpression localAggExpr = new AggregateExpression(expr, inputToBufferParam); return new Alias(localAggExpr, localAggExpr.toSql()); - })); + }, (oldValue, newValue) -> newValue)); List localAggOutput = ImmutableList.builder() .addAll(localAggGroupBySet) @@ -1321,11 +1347,11 @@ public class AggregateStrategies implements ImplementationRuleFactory { maybeUsingStreamAgg, Optional.empty(), logicalAgg.getLogicalProperties(), requireAny, logicalAgg.child()); - AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER); + AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER, couldBanned); Map nonDistinctAggFunctionToAliasPhase2 = nonDistinctAggFunctionToAliasPhase1.entrySet() .stream() - .collect(ImmutableMap.toImmutableMap(kv -> kv.getKey(), kv -> { + .collect(ImmutableMap.toImmutableMap(Entry::getKey, kv -> { AggregateFunction originFunction = kv.getKey(); Alias localOutput = kv.getValue(); AggregateExpression globalAggExpr = new AggregateExpression( @@ -1350,7 +1376,8 @@ public class AggregateStrategies implements ImplementationRuleFactory { requireDistinctHash, anyLocalAgg); // phase 3 - AggregateParam distinctLocalParam = new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER); + AggregateParam distinctLocalParam = new AggregateParam( + AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned); Map nonDistinctAggFunctionToAliasPhase3 = new HashMap<>(); List localDistinctOutput = Lists.newArrayList(); for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) { @@ -1361,9 +1388,12 @@ public class AggregateStrategies implements ImplementationRuleFactory { if (expr instanceof AggregateFunction) { AggregateFunction aggregateFunction = (AggregateFunction) expr; if (aggregateFunction.isDistinct()) { - Preconditions.checkArgument(aggregateFunction.arity() == 1); + Set aggChild = Sets.newHashSet(aggregateFunction.children()); + Preconditions.checkArgument(aggChild.size() == 1, + "cannot process more than one child in aggregate distinct function: " + + aggregateFunction); AggregateFunction nonDistinct = aggregateFunction - .withDistinctAndChildren(false, aggregateFunction.getArguments()); + .withDistinctAndChildren(false, ImmutableList.copyOf(aggChild)); AggregateExpression nonDistinctAggExpr = new AggregateExpression(nonDistinct, distinctLocalParam, aggregateFunction.child(0)); return nonDistinctAggExpr; @@ -1389,7 +1419,8 @@ public class AggregateStrategies implements ImplementationRuleFactory { requireDistinctHash, anyLocalHashGlobalAgg); //phase 4 - AggregateParam distinctGlobalParam = new AggregateParam(AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT); + AggregateParam distinctGlobalParam = new AggregateParam( + AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT, couldBanned); List globalDistinctOutput = Lists.newArrayList(); for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) { NamedExpression outputExpr = logicalAgg.getOutputExpressions().get(i); @@ -1397,9 +1428,12 @@ public class AggregateStrategies implements ImplementationRuleFactory { if (expr instanceof AggregateFunction) { AggregateFunction aggregateFunction = (AggregateFunction) expr; if (aggregateFunction.isDistinct()) { - Preconditions.checkArgument(aggregateFunction.arity() == 1); + Set aggChild = Sets.newHashSet(aggregateFunction.children()); + Preconditions.checkArgument(aggChild.size() == 1, + "cannot process more than one child in aggregate distinct function: " + + aggregateFunction); AggregateFunction nonDistinct = aggregateFunction - .withDistinctAndChildren(false, aggregateFunction.getArguments()); + .withDistinctAndChildren(false, ImmutableList.copyOf(aggChild)); int idx = logicalAgg.getOutputExpressions().indexOf(outputExpr); Alias localDistinctAlias = (Alias) (localDistinctOutput.get(idx)); return new AggregateExpression(nonDistinct, @@ -1424,4 +1458,11 @@ public class AggregateStrategies implements ImplementationRuleFactory { .add(distinctGlobal) .build(); } + + private boolean couldConvertToMulti(LogicalAggregate aggregate) { + return ExpressionUtils.noneMatch(aggregate.getOutputExpressions(), expr -> + expr instanceof AggregateFunction && ((AggregateFunction) expr).isDistinct() + && (expr.arity() > 1 + || !(expr instanceof Count || expr instanceof Sum || expr instanceof GroupConcat))); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java index 2ff8eb262f..89e9c1ea7d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java @@ -25,43 +25,35 @@ import java.util.Objects; /** AggregateParam. */ public class AggregateParam { + public static AggregateParam LOCAL_RESULT = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_RESULT); + public static AggregateParam LOCAL_BUFFER = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER); + public final AggPhase aggPhase; - public final AggMode aggMode; - - // TODO remove this flag, and generate it in enforce and cost job - public boolean needColocateScan; + // TODO: this is a short-term plan to process count(distinct a, b) correctly + public final boolean canBeBanned; /** AggregateParam */ public AggregateParam(AggPhase aggPhase, AggMode aggMode) { - this(aggPhase, aggMode, false); + this(aggPhase, aggMode, true); } - /** AggregateParam */ - public AggregateParam(AggPhase aggPhase, AggMode aggMode, boolean needColocateScan) { + public AggregateParam(AggPhase aggPhase, AggMode aggMode, boolean canBeBanned) { this.aggMode = Objects.requireNonNull(aggMode, "aggMode cannot be null"); this.aggPhase = Objects.requireNonNull(aggPhase, "aggPhase cannot be null"); - this.needColocateScan = needColocateScan; - } - - public static AggregateParam localResult() { - return new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_RESULT, true); + this.canBeBanned = canBeBanned; } public AggregateParam withAggPhase(AggPhase aggPhase) { - return new AggregateParam(aggPhase, aggMode, needColocateScan); + return new AggregateParam(aggPhase, aggMode, canBeBanned); } public AggregateParam withAggPhase(AggMode aggMode) { - return new AggregateParam(aggPhase, aggMode, needColocateScan); + return new AggregateParam(aggPhase, aggMode, canBeBanned); } public AggregateParam withAppPhaseAndAppMode(AggPhase aggPhase, AggMode aggMode) { - return new AggregateParam(aggPhase, aggMode, needColocateScan); - } - - public AggregateParam withNeedColocateScan(boolean needColocateScan) { - return new AggregateParam(aggPhase, aggMode, needColocateScan); + return new AggregateParam(aggPhase, aggMode, canBeBanned); } @Override @@ -87,7 +79,6 @@ public class AggregateParam { return "AggregateParam{" + "aggPhase=" + aggPhase + ", aggMode=" + aggMode - + ", needColocateScan=" + needColocateScan + '}'; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java index 72a26d288e..b9e7c1fdb1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java @@ -35,7 +35,7 @@ import java.util.List; /** MultiDistinctCount */ public class MultiDistinctCount extends AggregateFunction - implements AlwaysNotNullable, ExplicitlyCastableSignature { + implements AlwaysNotNullable, ExplicitlyCastableSignature, MultiDistinction { public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(BigIntType.INSTANCE).varArgs(AnyDataType.INSTANCE) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctGroupConcat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctGroupConcat.java index 737e895906..5f5bb6815a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctGroupConcat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctGroupConcat.java @@ -36,7 +36,7 @@ import java.util.List; /** MultiDistinctGroupConcat */ public class MultiDistinctGroupConcat extends NullableAggregateFunction - implements ExplicitlyCastableSignature { + implements ExplicitlyCastableSignature, MultiDistinction { public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarcharType.SYSTEM_DEFAULT), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java index a378dc0960..8441e02828 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java @@ -35,8 +35,8 @@ import com.google.common.collect.ImmutableList; import java.util.List; /** MultiDistinctSum */ -public class MultiDistinctSum extends AggregateFunction - implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature, ComputePrecisionForSum { +public class MultiDistinctSum extends AggregateFunction implements UnaryExpression, AlwaysNotNullable, + ExplicitlyCastableSignature, ComputePrecisionForSum, MultiDistinction { public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(BigIntType.INSTANCE).varArgs(BigIntType.INSTANCE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinction.java new file mode 100644 index 0000000000..ab8842f730 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinction.java @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.agg; + +import org.apache.doris.nereids.trees.TreeNode; +import org.apache.doris.nereids.trees.expressions.Expression; + +/** + * base class of multi-distinct agg function + */ +public interface MultiDistinction extends TreeNode { +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java index 6731bde58a..8361e230be 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java @@ -25,7 +25,7 @@ import org.apache.doris.nereids.trees.plans.UnaryPlan; import org.apache.doris.nereids.trees.plans.logical.OutputPrunable; import org.apache.doris.nereids.util.ExpressionUtils; -import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import java.util.List; import java.util.Set; @@ -53,10 +53,10 @@ public interface Aggregate extends UnaryPlan getDistinctArguments() { + default Set getDistinctArguments() { return getAggregateFunctions().stream() .filter(AggregateFunction::isDistinct) .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream()) - .collect(ImmutableList.toImmutableList()); + .collect(ImmutableSet.toImmutableSet()); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalStorageLayerAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalStorageLayerAggregate.java index ed57eebcb0..094f5d75cd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalStorageLayerAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalStorageLayerAggregate.java @@ -21,7 +21,7 @@ import org.apache.doris.catalog.Table; import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.properties.PhysicalProperties; -import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.Min; @@ -30,7 +30,6 @@ import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; import org.apache.doris.nereids.util.Utils; import org.apache.doris.statistics.Statistics; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.List; @@ -78,11 +77,6 @@ public class PhysicalStorageLayerAggregate extends PhysicalRelation { return visitor.visitPhysicalStorageLayerAggregate(this, context); } - @Override - public List getExpressions() { - return ImmutableList.of(); - } - @Override public boolean equals(Object o) { if (this == o) { @@ -113,7 +107,7 @@ public class PhysicalStorageLayerAggregate extends PhysicalRelation { } public PhysicalStorageLayerAggregate withPhysicalOlapScan(PhysicalOlapScan physicalOlapScan) { - return new PhysicalStorageLayerAggregate(relation, aggOp); + return new PhysicalStorageLayerAggregate(physicalOlapScan, aggOp); } @Override @@ -142,8 +136,8 @@ public class PhysicalStorageLayerAggregate extends PhysicalRelation { public enum PushDownAggOp { COUNT, MIN_MAX, MIX; - public static Map supportedFunctions() { - return ImmutableMap.builder() + public static Map, PushDownAggOp> supportedFunctions() { + return ImmutableMap., PushDownAggOp>builder() .put(Count.class, PushDownAggOp.COUNT) .put(Min.class, PushDownAggOp.MIN_MAX) .put(Max.class, PushDownAggOp.MIN_MAX) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index e1fbefad61..a3a3ca1b80 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -403,6 +403,11 @@ public class ExpressionUtils { .anyMatch(expr -> expr.anyMatch(predicate)); } + public static boolean noneMatch(List expressions, Predicate> predicate) { + return expressions.stream() + .noneMatch(expr -> expr.anyMatch(predicate)); + } + public static boolean containsType(List expressions, Class type) { return anyMatch(expressions, type::isInstance); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java index 417605bfa5..e8419b2458 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java @@ -223,6 +223,8 @@ public class AggregateStrategiesTest implements MemoPatternMatchSupported { * */ @Test + @Disabled + @Developing("reopen it after we could choose agg phase by CBO") public void distinctAggregateWithoutGroupByApply2PhaseRule() { List groupExpressionList = new ArrayList<>(); List outputExpressionList = Lists.newArrayList(new Alias( diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query94.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query94.out index e12f8482f5..c26035693e 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query94.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query94.out @@ -4,8 +4,8 @@ PhysicalTopN --PhysicalTopN ----PhysicalProject ------hashAgg[GLOBAL] ---------hashAgg[LOCAL] -----------PhysicalDistribute +--------PhysicalDistribute +----------hashAgg[LOCAL] ------------PhysicalProject --------------hashJoin[INNER_JOIN](ws1.ws_ship_date_sk = date_dim.d_date_sk) ----------------PhysicalProject diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query95.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query95.out index 6ddca5c0c2..014535c50d 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query95.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query95.out @@ -14,8 +14,8 @@ CteAnchor[cteId= ( CTEId#3=] ) ----PhysicalTopN ------PhysicalProject --------hashAgg[GLOBAL] -----------hashAgg[LOCAL] -------------PhysicalDistribute +----------PhysicalDistribute +------------hashAgg[LOCAL] --------------PhysicalProject ----------------hashJoin[INNER_JOIN](ws1.ws_ship_date_sk = date_dim.d_date_sk) ------------------PhysicalProject diff --git a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q16.out b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q16.out index 515a72a29d..d72afcf57d 100644 --- a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q16.out +++ b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q16.out @@ -4,8 +4,8 @@ PhysicalQuickSort --PhysicalDistribute ----PhysicalQuickSort ------hashAgg[GLOBAL] ---------hashAgg[LOCAL] -----------PhysicalDistribute +--------PhysicalDistribute +----------hashAgg[LOCAL] ------------PhysicalProject --------------hashJoin[LEFT_ANTI_JOIN](partsupp.ps_suppkey = supplier.s_suppkey) ----------------hashJoin[INNER_JOIN](part.p_partkey = partsupp.ps_partkey) diff --git a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q16.out b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q16.out index 515a72a29d..d72afcf57d 100644 --- a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q16.out +++ b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q16.out @@ -4,8 +4,8 @@ PhysicalQuickSort --PhysicalDistribute ----PhysicalQuickSort ------hashAgg[GLOBAL] ---------hashAgg[LOCAL] -----------PhysicalDistribute +--------PhysicalDistribute +----------hashAgg[LOCAL] ------------PhysicalProject --------------hashJoin[LEFT_ANTI_JOIN](partsupp.ps_suppkey = supplier.s_suppkey) ----------------hashJoin[INNER_JOIN](part.p_partkey = partsupp.ps_partkey) diff --git a/regression-test/suites/nereids_p0/join/test_join.groovy b/regression-test/suites/nereids_p0/join/test_join.groovy index ce643f081a..584b80384a 100644 --- a/regression-test/suites/nereids_p0/join/test_join.groovy +++ b/regression-test/suites/nereids_p0/join/test_join.groovy @@ -24,15 +24,6 @@ suite("test_join", "nereids_p0") { def tbName1 = "test" def tbName2 = "baseall" def tbName3 = "bigtable" - def empty_name = "empty" - - qt_agg_sql1 """select /*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/ count(distinct k1, NULL) from test;""" - qt_agg_sql2 """select /*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/ count(distinct k1, NULL), avg(k2) from baseall;""" - qt_agg_sql3 """select /*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/ k1,count(distinct k2,k3),min(k4),count(*) from baseall group by k1 order by k1;""" - - qt_agg_sql4 """select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/ count(distinct k1, NULL) from test;""" - qt_agg_sql5 """select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/ count(distinct k1, NULL), avg(k2) from baseall;""" - qt_agg_sql6 """select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/ k1,count(distinct k2,k3),min(k4),count(*) from baseall group by k1 order by k1;""" order_sql """select j.*, d.* from ${tbName2} j full outer join ${tbName1} d on (j.k1=d.k1) order by j.k1, j.k2, j.k3, j.k4, d.k1, d.k2 limit 100""" diff --git a/regression-test/suites/nereids_syntax_p0/agg_4_phase.groovy b/regression-test/suites/nereids_syntax_p0/agg_4_phase.groovy index d2d48e3e08..a672f8dee3 100644 --- a/regression-test/suites/nereids_syntax_p0/agg_4_phase.groovy +++ b/regression-test/suites/nereids_syntax_p0/agg_4_phase.groovy @@ -43,16 +43,16 @@ suite("agg_4_phase") { (0, 0, "aa", 10), (1, 1, "bb",20), (2, 2, "cc", 30), (1, 1, "bb",20); """ def test_sql = """ - select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT,TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/ - count(distinct name), sum(age) + select + count(distinct id) from agg_4_phase_tbl; """ explain{ sql(test_sql) - contains "6:VAGGREGATE (merge finalize)" - contains "5:VEXCHANGE" - contains "4:VAGGREGATE (update serialize)" - contains "3:VAGGREGATE (merge serialize)" + contains "5:VAGGREGATE (merge finalize)" + contains "4:VEXCHANGE" + contains "3:VAGGREGATE (update serialize)" + contains "2:VAGGREGATE (merge serialize)" contains "1:VAGGREGATE (update serialize)" } qt_4phase (test_sql) diff --git a/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy b/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy index 39742e8d21..ea63b5b789 100644 --- a/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy +++ b/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy @@ -82,7 +82,6 @@ suite("aggregate_strategies") { explain { sql """ select - /*+SET_VAR(disable_nereids_rules='ONE_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI,TWO_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI,THREE_PHASE_AGGREGATE_WITH_DISTINCT, FOUR_PHASE_AGGREGATE_WITH_DISTINCT')*/ count(distinct id) from $tableName """ @@ -90,17 +89,17 @@ suite("aggregate_strategies") { notContains "STREAMING" } + // test multi_distinct test { sql """select - /*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/ - count(distinct id) + count(distinct name) from $tableName""" result([[5L]]) } + // test four phase distinct test { sql """select - /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT')*/ count(distinct id) from $tableName""" result([[5L]]) diff --git a/regression-test/suites/nereids_syntax_p0/group_concat.groovy b/regression-test/suites/nereids_syntax_p0/group_concat.groovy index fe2062d66d..60f52c2ba0 100644 --- a/regression-test/suites/nereids_syntax_p0/group_concat.groovy +++ b/regression-test/suites/nereids_syntax_p0/group_concat.groovy @@ -21,14 +21,14 @@ suite("group_concat") { test { - sql """select /*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITHOUT_DISTINCT')*/ + sql """select group_concat(cast(number as string), ',' order by number) from numbers('number'='10')""" result([["0,1,2,3,4,5,6,7,8,9"]]) } test { - sql """select /*+SET_VAR(disable_nereids_rules='ONE_PHASE_AGGREGATE_WITHOUT_DISTINCT')*/ + sql """select group_concat(cast(number as string), ',' order by number) from numbers('number'='10')""" result([["0,1,2,3,4,5,6,7,8,9"]])