diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/PhysicalProperties.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/PhysicalProperties.java index 3c89f87340..beed2617a3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/PhysicalProperties.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/PhysicalProperties.java @@ -72,9 +72,13 @@ public class PhysicalProperties { this.orderSpec = orderSpec; } + /** + * Create hash info from orderedShuffledColumns, ignoring any expression that is not a SlotReference. + */ public static PhysicalProperties createHash( Collection orderedShuffledColumns, ShuffleType shuffleType) { List partitionedSlots = orderedShuffledColumns.stream() + .filter(SlotReference.class::isInstance) .map(SlotReference.class::cast) .map(SlotReference::getExprId) .collect(Collectors.toList()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java index fc6ddfcf79..49de801b78 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java @@ -140,8 +140,25 @@ public class CheckAnalysis implements AnalysisRuleFactory { private void checkAggregate(LogicalAggregate aggregate) { Set aggregateFunctions = aggregate.getAggregateFunctions(); - boolean distinctMultiColumns = aggregateFunctions.stream() - .anyMatch(fun -> fun.isDistinct() && fun.arity() > 1); + boolean distinctMultiColumns = false; + for (AggregateFunction func : aggregateFunctions) { + if (!func.isDistinct()) { + continue; + } + if (func.arity() <= 1) { + continue; + } + for (int i = 1; i < func.arity(); i++) { + if (!func.child(i).getInputSlots().isEmpty()) { + // an argument that reads input slots is a real extra distinct column, unlike the ',' in group_concat(distinct col_1, ',') + distinctMultiColumns = true; + break; + } + } + if 
(distinctMultiColumns) { + break; + } + } long distinctFunctionNum = aggregateFunctions.stream() .filter(AggregateFunction::isDistinct) .count(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java index c287f2dffe..0acd366f1c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java @@ -30,6 +30,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.SubqueryExpr; import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; @@ -150,7 +151,7 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali for (AggregateFunction distinctAggFunc : distinctAggFuncs) { List newChildren = Lists.newArrayList(); for (Expression child : distinctAggFunc.children()) { - if (child instanceof SlotReference) { + if (child instanceof SlotReference || child instanceof Literal) { newChildren.add(child); } else { NamedExpression alias; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java index 511312ddd8..a020656c68 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java +++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java @@ -972,13 +972,15 @@ public class AggregateStrategies implements ImplementationRuleFactory { LogicalAggregate logicalAgg, ConnectContext connectContext) { Set aggregateFunctions = logicalAgg.getAggregateFunctions(); - Set distinctArguments = aggregateFunctions.stream() + Set distinctArguments = aggregateFunctions.stream() .filter(AggregateFunction::isDistinct) .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream()) + .filter(NamedExpression.class::isInstance) + .map(NamedExpression.class::cast) .collect(ImmutableSet.toImmutableSet()); Set localAggGroupBy = ImmutableSet.builder() - .addAll((List) logicalAgg.getGroupByExpressions()) + .addAll((List) (List) logicalAgg.getGroupByExpressions()) .addAll(distinctArguments) .build(); @@ -1106,13 +1108,15 @@ public class AggregateStrategies implements ImplementationRuleFactory { Set aggregateFunctions = logicalAgg.getAggregateFunctions(); - Set distinctArguments = aggregateFunctions.stream() + Set distinctArguments = aggregateFunctions.stream() .filter(AggregateFunction::isDistinct) .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream()) + .filter(NamedExpression.class::isInstance) + .map(NamedExpression.class::cast) .collect(ImmutableSet.toImmutableSet()); Set localAggGroupBySet = ImmutableSet.builder() - .addAll((List) logicalAgg.getGroupByExpressions()) + .addAll((List) (List) logicalAgg.getGroupByExpressions()) .addAll(distinctArguments) .build(); @@ -1492,6 +1496,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { Set distinctArguments = aggregateFunctions.stream() .filter(AggregateFunction::isDistinct) .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream()) + .filter(NamedExpression.class::isInstance) .map(NamedExpression.class::cast) .collect(ImmutableSet.toImmutableSet()); @@ -1636,9 +1641,24 @@ public class AggregateStrategies 
implements ImplementationRuleFactory { } private boolean couldConvertToMulti(LogicalAggregate aggregate) { - return ExpressionUtils.noneMatch(aggregate.getOutputExpressions(), expr -> - expr instanceof AggregateFunction && ((AggregateFunction) expr).isDistinct() - && (expr.arity() > 1 - || !(expr instanceof Count || expr instanceof Sum || expr instanceof GroupConcat))); + Set aggregateFunctions = aggregate.getAggregateFunctions(); + for (AggregateFunction func : aggregateFunctions) { + if (!func.isDistinct()) { + continue; + } + if (!(func instanceof Count || func instanceof Sum || func instanceof GroupConcat)) { + return false; + } + if (func.arity() <= 1) { + continue; + } + for (int i = 1; i < func.arity(); i++) { + // a trailing argument with no input slots is allowed, e.g. the ',' in group_concat(distinct col_1, ',') + if (!func.child(i).getInputSlots().isEmpty()) { + return false; + } + } + } + return true; } }