From 165957658a8f6908944d1e431048ca0ca018e233 Mon Sep 17 00:00:00 2001 From: morrySnow <101034200+morrySnow@users.noreply.github.com> Date: Fri, 27 Oct 2023 16:42:00 +0800 Subject: [PATCH] [fix](Nereids) could not run multi group_concat distinct (#25851) Nereids could not run multiple group_concat distinct functions when they have more than one parameter. This bug is not specific to group_concat, but group_concat is usually called with a literal as an extra parameter, so it brought the problem to light. The original logic assumed that only distinct aggregate functions with zero or one parameter could run in multi-distinct mode. That assumption is wrong: we can process every distinct aggregate function that has no more than one input slot. For example, consider this SQL: ```sql SELECT group_concat(distinct c1, ','), group_concat(distinct c2, ',') FROM t GROUP BY c3 ``` --- .../properties/PhysicalProperties.java | 4 +++ .../nereids/rules/analysis/CheckAnalysis.java | 21 +++++++++-- .../rules/analysis/NormalizeAggregate.java | 3 +- .../implementation/AggregateStrategies.java | 36 ++++++++++++++----- 4 files changed, 53 insertions(+), 11 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/PhysicalProperties.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/PhysicalProperties.java index 3c89f87340..beed2617a3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/PhysicalProperties.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/PhysicalProperties.java @@ -72,9 +72,13 @@ public class PhysicalProperties { this.orderSpec = orderSpec; } + /** + * create hash info from orderedShuffledColumns, ignore non slot reference expression. 
+ */ public static PhysicalProperties createHash( Collection orderedShuffledColumns, ShuffleType shuffleType) { List partitionedSlots = orderedShuffledColumns.stream() + .filter(SlotReference.class::isInstance) .map(SlotReference.class::cast) .map(SlotReference::getExprId) .collect(Collectors.toList()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java index fc6ddfcf79..49de801b78 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java @@ -140,8 +140,25 @@ public class CheckAnalysis implements AnalysisRuleFactory { private void checkAggregate(LogicalAggregate aggregate) { Set aggregateFunctions = aggregate.getAggregateFunctions(); - boolean distinctMultiColumns = aggregateFunctions.stream() - .anyMatch(fun -> fun.isDistinct() && fun.arity() > 1); + boolean distinctMultiColumns = false; + for (AggregateFunction func : aggregateFunctions) { + if (!func.isDistinct()) { + continue; + } + if (func.arity() <= 1) { + continue; + } + for (int i = 1; i < func.arity(); i++) { + if (!func.child(i).getInputSlots().isEmpty()) { + // think about group_concat(distinct col_1, ',') + distinctMultiColumns = true; + break; + } + } + if (distinctMultiColumns) { + break; + } + } long distinctFunctionNum = aggregateFunctions.stream() .filter(AggregateFunction::isDistinct) .count(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java index c287f2dffe..0acd366f1c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java @@ -30,6 +30,7 @@ import 
org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.SubqueryExpr; import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; @@ -150,7 +151,7 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali for (AggregateFunction distinctAggFunc : distinctAggFuncs) { List newChildren = Lists.newArrayList(); for (Expression child : distinctAggFunc.children()) { - if (child instanceof SlotReference) { + if (child instanceof SlotReference || child instanceof Literal) { newChildren.add(child); } else { NamedExpression alias; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java index 511312ddd8..a020656c68 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java @@ -972,13 +972,15 @@ public class AggregateStrategies implements ImplementationRuleFactory { LogicalAggregate logicalAgg, ConnectContext connectContext) { Set aggregateFunctions = logicalAgg.getAggregateFunctions(); - Set distinctArguments = aggregateFunctions.stream() + Set distinctArguments = aggregateFunctions.stream() .filter(AggregateFunction::isDistinct) .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream()) + .filter(NamedExpression.class::isInstance) + .map(NamedExpression.class::cast) .collect(ImmutableSet.toImmutableSet()); Set localAggGroupBy = 
ImmutableSet.builder() - .addAll((List) logicalAgg.getGroupByExpressions()) + .addAll((List) (List) logicalAgg.getGroupByExpressions()) .addAll(distinctArguments) .build(); @@ -1106,13 +1108,15 @@ public class AggregateStrategies implements ImplementationRuleFactory { Set aggregateFunctions = logicalAgg.getAggregateFunctions(); - Set distinctArguments = aggregateFunctions.stream() + Set distinctArguments = aggregateFunctions.stream() .filter(AggregateFunction::isDistinct) .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream()) + .filter(NamedExpression.class::isInstance) + .map(NamedExpression.class::cast) .collect(ImmutableSet.toImmutableSet()); Set localAggGroupBySet = ImmutableSet.builder() - .addAll((List) logicalAgg.getGroupByExpressions()) + .addAll((List) (List) logicalAgg.getGroupByExpressions()) .addAll(distinctArguments) .build(); @@ -1492,6 +1496,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { Set distinctArguments = aggregateFunctions.stream() .filter(AggregateFunction::isDistinct) .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream()) + .filter(NamedExpression.class::isInstance) .map(NamedExpression.class::cast) .collect(ImmutableSet.toImmutableSet()); @@ -1636,9 +1641,24 @@ public class AggregateStrategies implements ImplementationRuleFactory { } private boolean couldConvertToMulti(LogicalAggregate aggregate) { - return ExpressionUtils.noneMatch(aggregate.getOutputExpressions(), expr -> - expr instanceof AggregateFunction && ((AggregateFunction) expr).isDistinct() - && (expr.arity() > 1 - || !(expr instanceof Count || expr instanceof Sum || expr instanceof GroupConcat))); + Set aggregateFunctions = aggregate.getAggregateFunctions(); + for (AggregateFunction func : aggregateFunctions) { + if (!func.isDistinct()) { + continue; + } + if (!(func instanceof Count || func instanceof Sum || func instanceof GroupConcat)) { + return false; + } + if (func.arity() <= 1) { + continue; + } + 
for (int i = 1; i < func.arity(); i++) { + // think about group_concat(distinct col_1, ',') + if (!func.child(i).getInputSlots().isEmpty()) { + return false; + } + } + } + return true; } }