[fix](Nereids) fix group_concat(distinct) failed (#31873)

This commit is contained in:
924060929
2024-03-06 21:54:35 +08:00
committed by yiguolei
parent ad2f7fc316
commit 561709451c
6 changed files with 81 additions and 19 deletions

View File

@ -1108,7 +1108,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
* <p>
* single node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT)
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id))], mode=BUFFER_TO_RESULT)
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER)
* |
@ -1118,12 +1118,10 @@ public class AggregateStrategies implements ImplementationRuleFactory {
* <p>
* distribute node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT)
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id))], mode=BUFFER_TO_RESULT)
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER)
* |
* PhysicalDistribute(distributionSpec=HASH(name))
* |
* LogicalOlapScan(table=tbl, **if distribute by name**)
*
*/
@ -1175,8 +1173,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
if (outputChild instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) outputChild;
if (aggregateFunction.isDistinct()) {
Set<Expression> aggChild = Sets.newHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1,
Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1
|| aggregateFunction.getDistinctArguments().size() == 1,
"cannot process more than one child in aggregate distinct function: "
+ aggregateFunction);
AggregateFunction nonDistinct = aggregateFunction
@ -1236,7 +1235,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
* after:
* single node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT)
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id))], mode=BUFFER_TO_RESULT)
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER)
* |
@ -1248,7 +1247,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
* <p>
* distribute node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT)
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id))], mode=BUFFER_TO_RESULT)
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER)
* |
@ -1331,14 +1330,14 @@ public class AggregateStrategies implements ImplementationRuleFactory {
if (expr instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) expr;
if (aggregateFunction.isDistinct()) {
Set<Expression> aggChild = Sets.newHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1,
Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1
|| aggregateFunction.getDistinctArguments().size() == 1,
"cannot process more than one child in aggregate distinct function: "
+ aggregateFunction);
AggregateFunction nonDistinct = aggregateFunction
.withDistinctAndChildren(false, ImmutableList.copyOf(aggChild));
return new AggregateExpression(nonDistinct,
bufferToResultParam, aggregateFunction.child(0));
return new AggregateExpression(nonDistinct, bufferToResultParam, aggregateFunction);
} else {
Alias alias = nonDistinctAggFunctionToAliasPhase2.get(expr);
return new AggregateExpression(aggregateFunction,
@ -1727,8 +1726,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
if (expr instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) expr;
if (aggregateFunction.isDistinct()) {
Set<Expression> aggChild = Sets.newHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1,
Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1
|| aggregateFunction.getDistinctArguments().size() == 1,
"cannot process more than one child in aggregate distinct function: "
+ aggregateFunction);
AggregateFunction nonDistinct = aggregateFunction
@ -1767,8 +1767,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
if (expr instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) expr;
if (aggregateFunction.isDistinct()) {
Set<Expression> aggChild = Sets.newHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1,
Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1
|| aggregateFunction.getDistinctArguments().size() == 1,
"cannot process more than one child in aggregate distinct function: "
+ aggregateFunction);
AggregateFunction nonDistinct = aggregateFunction

View File

@ -124,4 +124,7 @@ public abstract class AggregateFunction extends BoundFunction implements Expects
return getName() + "(" + (distinct ? "DISTINCT " : "") + args + ")";
}
/**
 * Returns the arguments that the DISTINCT modifier applies to.
 *
 * @return every argument of this function when it is distinct; an empty
 *         immutable list otherwise
 */
public List<Expression> getDistinctArguments() {
    if (!distinct) {
        // Not a DISTINCT aggregate: nothing has to be de-duplicated.
        return ImmutableList.of();
    }
    return getArguments();
}
}

View File

@ -111,6 +111,15 @@ public class GroupConcat extends NullableAggregateFunction
.anyMatch(expression -> !(expression instanceof OrderExpression) && expression.nullable());
}
/**
 * Returns the arguments that the DISTINCT modifier applies to.
 *
 * <p>Only the first argument is reported when this function is distinct.
 * NOTE(review): presumably the remaining arguments (e.g. order expressions)
 * are not part of the de-duplication key -- confirm against the callers.
 *
 * @return a one-element list holding argument 0 when distinct; an empty
 *         immutable list otherwise
 */
@Override
public List<Expression> getDistinctArguments() {
    return distinct ? ImmutableList.of(getArgument(0)) : ImmutableList.of();
}
@Override
public void checkLegalityBeforeTypeCoercion() {
DataType typeOrArg0 = getArgumentType(0);

View File

@ -56,7 +56,7 @@ public interface Aggregate<CHILD_TYPE extends Plan> extends UnaryPlan<CHILD_TYPE
/**
 * Collects the distinct arguments of every DISTINCT aggregate function of this
 * aggregate into one immutable set.
 *
 * <p>NOTE(review): the two consecutive {@code flatMap} lines below look like the
 * pre-change and post-change versions of the same step shown together by the
 * diff view; presumably only the second one (which delegates to
 * {@code AggregateFunction#getDistinctArguments()}) exists in the committed
 * file -- confirm against the repository.
 *
 * @return the de-duplicated expressions produced by the pipeline below
 */
default Set<Expression> getDistinctArguments() {
return getAggregateFunctions().stream()
// keep only the aggregate functions flagged as DISTINCT
.filter(AggregateFunction::isDistinct)
.flatMap(aggregateExpression -> aggregateExpression.getArguments().stream())
// let each function decide which of its arguments are the distinct ones
.flatMap(aggregateFunction -> aggregateFunction.getDistinctArguments().stream())
.collect(ImmutableSet.toImmutableSet());
}
}