[fix](Nereids) could not run multi group_concat distinct (#25851)
Nereids could not run multiple `group_concat(distinct ...)` functions when they had more than one parameter. This bug is not specific to group_concat, but group_concat is usually called with a literal as an extra parameter, so it brought the problem to light. The original logic assumed that only distinct aggregate functions with zero or one parameter could run in multi-distinct mode. That assumption is wrong: we can process any distinct aggregate function that has no more than one input slot. Consider this SQL:
```sql
SELECT group_concat(distinct c1, ','), group_concat(distinct c2, ',') FROM t GROUP BY c3
```
This commit is contained in:
@ -72,9 +72,13 @@ public class PhysicalProperties {
|
||||
this.orderSpec = orderSpec;
|
||||
}
|
||||
|
||||
/**
|
||||
* create hash info from orderedShuffledColumns, ignore non slot reference expression.
|
||||
*/
|
||||
public static PhysicalProperties createHash(
|
||||
Collection<? extends Expression> orderedShuffledColumns, ShuffleType shuffleType) {
|
||||
List<ExprId> partitionedSlots = orderedShuffledColumns.stream()
|
||||
.filter(SlotReference.class::isInstance)
|
||||
.map(SlotReference.class::cast)
|
||||
.map(SlotReference::getExprId)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
@ -140,8 +140,25 @@ public class CheckAnalysis implements AnalysisRuleFactory {
|
||||
|
||||
private void checkAggregate(LogicalAggregate<? extends Plan> aggregate) {
|
||||
Set<AggregateFunction> aggregateFunctions = aggregate.getAggregateFunctions();
|
||||
boolean distinctMultiColumns = aggregateFunctions.stream()
|
||||
.anyMatch(fun -> fun.isDistinct() && fun.arity() > 1);
|
||||
boolean distinctMultiColumns = false;
|
||||
for (AggregateFunction func : aggregateFunctions) {
|
||||
if (!func.isDistinct()) {
|
||||
continue;
|
||||
}
|
||||
if (func.arity() <= 1) {
|
||||
continue;
|
||||
}
|
||||
for (int i = 1; i < func.arity(); i++) {
|
||||
if (!func.child(i).getInputSlots().isEmpty()) {
|
||||
// think about group_concat(distinct col_1, ',')
|
||||
distinctMultiColumns = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (distinctMultiColumns) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
long distinctFunctionNum = aggregateFunctions.stream()
|
||||
.filter(AggregateFunction::isDistinct)
|
||||
.count();
|
||||
|
||||
@ -30,6 +30,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
|
||||
import org.apache.doris.nereids.trees.expressions.WindowExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
@ -150,7 +151,7 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
|
||||
for (AggregateFunction distinctAggFunc : distinctAggFuncs) {
|
||||
List<Expression> newChildren = Lists.newArrayList();
|
||||
for (Expression child : distinctAggFunc.children()) {
|
||||
if (child instanceof SlotReference) {
|
||||
if (child instanceof SlotReference || child instanceof Literal) {
|
||||
newChildren.add(child);
|
||||
} else {
|
||||
NamedExpression alias;
|
||||
|
||||
@ -972,13 +972,15 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) {
|
||||
Set<AggregateFunction> aggregateFunctions = logicalAgg.getAggregateFunctions();
|
||||
|
||||
Set<Expression> distinctArguments = aggregateFunctions.stream()
|
||||
Set<NamedExpression> distinctArguments = aggregateFunctions.stream()
|
||||
.filter(AggregateFunction::isDistinct)
|
||||
.flatMap(aggregateExpression -> aggregateExpression.getArguments().stream())
|
||||
.filter(NamedExpression.class::isInstance)
|
||||
.map(NamedExpression.class::cast)
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
|
||||
Set<NamedExpression> localAggGroupBy = ImmutableSet.<NamedExpression>builder()
|
||||
.addAll((List) logicalAgg.getGroupByExpressions())
|
||||
.addAll((List<NamedExpression>) (List) logicalAgg.getGroupByExpressions())
|
||||
.addAll(distinctArguments)
|
||||
.build();
|
||||
|
||||
@ -1106,13 +1108,15 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
|
||||
Set<AggregateFunction> aggregateFunctions = logicalAgg.getAggregateFunctions();
|
||||
|
||||
Set<Expression> distinctArguments = aggregateFunctions.stream()
|
||||
Set<NamedExpression> distinctArguments = aggregateFunctions.stream()
|
||||
.filter(AggregateFunction::isDistinct)
|
||||
.flatMap(aggregateExpression -> aggregateExpression.getArguments().stream())
|
||||
.filter(NamedExpression.class::isInstance)
|
||||
.map(NamedExpression.class::cast)
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
|
||||
Set<NamedExpression> localAggGroupBySet = ImmutableSet.<NamedExpression>builder()
|
||||
.addAll((List) logicalAgg.getGroupByExpressions())
|
||||
.addAll((List<NamedExpression>) (List) logicalAgg.getGroupByExpressions())
|
||||
.addAll(distinctArguments)
|
||||
.build();
|
||||
|
||||
@ -1492,6 +1496,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
Set<NamedExpression> distinctArguments = aggregateFunctions.stream()
|
||||
.filter(AggregateFunction::isDistinct)
|
||||
.flatMap(aggregateExpression -> aggregateExpression.getArguments().stream())
|
||||
.filter(NamedExpression.class::isInstance)
|
||||
.map(NamedExpression.class::cast)
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
|
||||
@ -1636,9 +1641,24 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
}
|
||||
|
||||
private boolean couldConvertToMulti(LogicalAggregate<? extends Plan> aggregate) {
|
||||
return ExpressionUtils.noneMatch(aggregate.getOutputExpressions(), expr ->
|
||||
expr instanceof AggregateFunction && ((AggregateFunction) expr).isDistinct()
|
||||
&& (expr.arity() > 1
|
||||
|| !(expr instanceof Count || expr instanceof Sum || expr instanceof GroupConcat)));
|
||||
Set<AggregateFunction> aggregateFunctions = aggregate.getAggregateFunctions();
|
||||
for (AggregateFunction func : aggregateFunctions) {
|
||||
if (!func.isDistinct()) {
|
||||
continue;
|
||||
}
|
||||
if (!(func instanceof Count || func instanceof Sum || func instanceof GroupConcat)) {
|
||||
return false;
|
||||
}
|
||||
if (func.arity() <= 1) {
|
||||
continue;
|
||||
}
|
||||
for (int i = 1; i < func.arity(); i++) {
|
||||
// think about group_concat(distinct col_1, ',')
|
||||
if (!func.child(i).getInputSlots().isEmpty()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user