[fix](nereids)support group_concat with distinct and order by (#38871)

## Proposed changes

pick from master https://github.com/apache/doris/pull/38080

<!--Describe your changes.-->
This commit is contained in:
starocean999
2024-08-05 18:23:55 +08:00
committed by GitHub
parent bf1c7a1c15
commit 40567b5d69
22 changed files with 326 additions and 190 deletions

View File

@ -624,7 +624,9 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
@Override
public Expr visitStateCombinator(StateCombinator combinator, PlanTranslatorContext context) {
List<Expr> arguments = combinator.getArguments().stream().map(arg -> arg.accept(this, context))
List<Expr> arguments = combinator.getArguments().stream().map(arg -> arg instanceof OrderExpression
? translateOrderExpression((OrderExpression) arg, context).getExpr()
: arg.accept(this, context))
.collect(Collectors.toList());
return Function.convertToStateCombinator(
new FunctionCallExpr(visitAggregateFunction(combinator.getNestedFunction(), context).getFn(),

View File

@ -274,7 +274,6 @@ import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.WindowFrame;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Array;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRange;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRangeDayUnit;
@ -2096,11 +2095,10 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
return ParserUtils.withOrigin(ctx, () -> {
String functionName = ctx.functionIdentifier().functionNameIdentifier().getText();
boolean isDistinct = ctx.DISTINCT() != null;
List<Expression> params = visit(ctx.expression(), Expression.class);
List<Expression> params = Lists.newArrayList();
params.addAll(visit(ctx.expression(), Expression.class));
List<OrderKey> orderKeys = visit(ctx.sortItem(), OrderKey.class);
if (!orderKeys.isEmpty()) {
return parseFunctionWithOrderKeys(functionName, isDistinct, params, orderKeys, ctx);
}
params.addAll(orderKeys.stream().map(OrderExpression::new).collect(Collectors.toList()));
List<UnboundStar> unboundStars = ExpressionUtils.collectAll(params, UnboundStar.class::isInstance);
if (!unboundStars.isEmpty()) {
@ -3471,23 +3469,6 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
return new StructField(ctx.identifier().getText(), typedVisit(ctx.dataType()), true, comment);
}
private Expression parseFunctionWithOrderKeys(String functionName, boolean isDistinct,
List<Expression> params, List<OrderKey> orderKeys, ParserRuleContext ctx) {
if (functionName.equalsIgnoreCase("group_concat")) {
OrderExpression[] orderExpressions = orderKeys.stream()
.map(OrderExpression::new)
.toArray(OrderExpression[]::new);
if (params.size() == 1) {
return new GroupConcat(isDistinct, params.get(0), orderExpressions);
} else if (params.size() == 2) {
return new GroupConcat(isDistinct, params.get(0), params.get(1), orderExpressions);
} else {
throw new ParseException("group_concat requires one or two parameters: " + params, ctx);
}
}
throw new ParseException("Unsupported function with order expressions" + ctx.getText(), ctx);
}
private String parseConstant(ConstantContext context) {
Object constant = visit(context);
if (constant instanceof Literal && ((Literal) constant).isStringLikeLiteral()) {

View File

@ -159,12 +159,15 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Boolean, Void> {
distinctChildColumns, ShuffleType.REQUIRE);
if ((!groupByColumns.isEmpty() && distributionSpecHash.satisfy(groupByRequire))
|| (groupByColumns.isEmpty() && distributionSpecHash.satisfy(distinctChildRequire))) {
return false;
if (!agg.mustUseMultiDistinctAgg()) {
return false;
}
}
}
// if distinct without group by key, we prefer three or four stage distinct agg
// because the second phase of multi-distinct only have one instance, and it is slow generally.
if (agg.getOutputExpressions().size() == 1 && agg.getGroupByExpressions().isEmpty()) {
if (agg.getOutputExpressions().size() == 1 && agg.getGroupByExpressions().isEmpty()
&& !agg.mustUseMultiDistinctAgg()) {
return false;
}
}

View File

@ -21,6 +21,7 @@ import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
@ -148,7 +149,7 @@ public class CheckAnalysis implements AnalysisRuleFactory {
continue;
}
for (int i = 1; i < func.arity(); i++) {
if (!func.child(i).getInputSlots().isEmpty()) {
if (!func.child(i).getInputSlots().isEmpty() && !(func.child(i) instanceof OrderExpression)) {
// think about group_concat(distinct col_1, ',')
distinctMultiColumns = true;
break;

View File

@ -403,6 +403,7 @@ public class ExpressionAnalyzer extends SubExprAnalyzer<ExpressionRewriteContext
}
Pair<? extends Expression, ? extends BoundFunction> buildResult = builder.build(functionName, arguments);
buildResult.second.checkOrderExprIsValid();
Optional<SqlCacheContext> sqlCacheContext = Optional.empty();
if (wantToParseSqlFromSqlCache) {
StatementContext statementContext = context.cascadesContext.getStatementContext();

View File

@ -26,12 +26,14 @@ import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotNotFromChildren;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinction;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
@ -54,6 +56,7 @@ import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* normalize aggregate's group keys and AggregateFunction's child to SlotReference
@ -170,6 +173,7 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
// should not push down literal under aggregate
// e.g. group_concat(distinct xxx, ','), the ',' literal show stay in aggregate
.filter(arg -> !(arg instanceof Literal))
.flatMap(arg -> arg instanceof OrderExpression ? arg.getInputSlots().stream() : Stream.of(arg))
.collect(
Collectors.groupingBy(
child -> !(child instanceof SlotReference),
@ -255,7 +259,15 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs)
);
// create new agg node
ImmutableList<NamedExpression> normalizedAggOutput = normalizedAggOutputBuilder.build();
ImmutableList<NamedExpression> aggOutput = normalizedAggOutputBuilder.build();
ImmutableList.Builder<NamedExpression> newAggOutputBuilder
= ImmutableList.builderWithExpectedSize(aggOutput.size());
for (NamedExpression output : aggOutput) {
Expression rewrittenExpr = output.rewriteDownShortCircuit(
e -> e instanceof MultiDistinction ? ((MultiDistinction) e).withMustUseMultiDistinctAgg(true) : e);
newAggOutputBuilder.add((NamedExpression) rewrittenExpr);
}
ImmutableList<NamedExpression> normalizedAggOutput = newAggOutputBuilder.build();
LogicalAggregate<?> newAggregate =
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan);

View File

@ -40,6 +40,7 @@ import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
@ -296,6 +297,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.whenNot(agg -> agg.mustUseMultiDistinctAgg())
.thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
),
/*
@ -319,6 +321,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.when(agg -> agg.getGroupByExpressions().isEmpty())
.whenNot(agg -> agg.mustUseMultiDistinctAgg())
.thenApplyMulti(ctx -> {
Function<List<Expression>, RequireProperties> secondPhaseRequireDistinctHash =
groupByAndDistinct -> RequireProperties.of(
@ -1940,7 +1943,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
}
for (int i = 1; i < func.arity(); i++) {
// think about group_concat(distinct col_1, ',')
if (!func.child(i).getInputSlots().isEmpty()) {
if (!(func.child(i) instanceof OrderExpression) && !func.child(i).getInputSlots().isEmpty()) {
return false;
}
}

View File

@ -124,6 +124,12 @@ public class WindowExpression extends Expression {
.orElseGet(() -> new WindowExpression(function, partitionKeys, orderKeys));
}
public WindowExpression withFunctionPartitionKeysOrderKeys(Expression function,
List<Expression> partitionKeys, List<OrderExpression> orderKeys) {
return windowFrame.map(frame -> new WindowExpression(function, partitionKeys, orderKeys, frame))
.orElseGet(() -> new WindowExpression(function, partitionKeys, orderKeys));
}
@Override
public boolean nullable() {
return function.nullable();

View File

@ -127,6 +127,10 @@ public class AggCombinerFunctionBuilder extends FunctionBuilder {
String nestedName = getNestedName(name);
if (combinatorSuffix.equalsIgnoreCase(STATE)) {
AggregateFunction nestedFunction = buildState(nestedName, arguments);
// distinct will be passed as 1st boolean true arg. remove it
if (!arguments.isEmpty() && arguments.get(0) instanceof Boolean && (Boolean) arguments.get(0)) {
arguments = arguments.subList(1, arguments.size());
}
return Pair.of(new StateCombinator((List<Expression>) arguments, nestedFunction), nestedFunction);
} else if (combinatorSuffix.equalsIgnoreCase(MERGE)) {
AggregateFunction nestedFunction = buildMergeOrUnion(nestedName, arguments);

View File

@ -18,8 +18,12 @@
package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctGroupConcat;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.util.Utils;
@ -98,4 +102,17 @@ public abstract class BoundFunction extends Function implements ComputeSignature
.collect(Collectors.joining(", "));
return getName() + "(" + args + ")";
}
/**
* checkOrderExprIsValid.
*/
public void checkOrderExprIsValid() {
for (Expression child : children) {
if (child instanceof OrderExpression
&& !(this instanceof GroupConcat || this instanceof MultiDistinctGroupConcat)) {
throw new AnalysisException(
String.format("%s doesn't support order by expression", getName()));
}
}
}
}

View File

@ -134,4 +134,8 @@ public abstract class AggregateFunction extends BoundFunction implements Expects
public List<Expression> getDistinctArguments() {
return distinct ? getArguments() : ImmutableList.of();
}
public boolean mustUseMultiDistinctAgg() {
return false;
}
}

View File

@ -52,57 +52,30 @@ public class GroupConcat extends NullableAggregateFunction
/**
* constructor with 1 argument.
*/
public GroupConcat(boolean distinct, boolean alwaysNullable, Expression arg, OrderExpression... orders) {
super("group_concat", distinct, alwaysNullable, ExpressionUtils.mergeArguments(arg, orders));
this.nonOrderArguments = 1;
checkArguments();
public GroupConcat(boolean distinct, boolean alwaysNullable, Expression arg, Expression... others) {
this(distinct, alwaysNullable, ExpressionUtils.mergeArguments(arg, others));
}
/**
* constructor with 1 argument.
*/
public GroupConcat(boolean distinct, Expression arg, OrderExpression... orders) {
this(distinct, false, arg, orders);
public GroupConcat(boolean distinct, Expression arg, Expression... others) {
this(distinct, false, arg, others);
}
/**
* constructor with 1 argument, use for function search.
*/
public GroupConcat(Expression arg, OrderExpression... orders) {
this(false, arg, orders);
}
/**
* constructor with 2 arguments.
*/
public GroupConcat(boolean distinct, boolean alwaysNullable,
Expression arg0, Expression arg1, OrderExpression... orders) {
super("group_concat", distinct, alwaysNullable, ExpressionUtils.mergeArguments(arg0, arg1, orders));
this.nonOrderArguments = 2;
checkArguments();
}
/**
* constructor with 2 arguments.
*/
public GroupConcat(boolean distinct, Expression arg0, Expression arg1, OrderExpression... orders) {
this(distinct, false, arg0, arg1, orders);
}
/**
* constructor with 2 arguments, use for function search.
*/
public GroupConcat(Expression arg0, Expression arg1, OrderExpression... orders) {
this(false, arg0, arg1, orders);
public GroupConcat(Expression arg, Expression... others) {
this(false, arg, others);
}
/**
* constructor for always nullable.
*/
public GroupConcat(boolean distinct, boolean alwaysNullable, int nonOrderArguments, List<Expression> args) {
public GroupConcat(boolean distinct, boolean alwaysNullable, List<Expression> args) {
super("group_concat", distinct, alwaysNullable, args);
this.nonOrderArguments = nonOrderArguments;
checkArguments();
this.nonOrderArguments = findOrderExprIndex(children);
}
@Override
@ -139,7 +112,7 @@ public class GroupConcat extends NullableAggregateFunction
@Override
public GroupConcat withAlwaysNullable(boolean alwaysNullable) {
return new GroupConcat(distinct, alwaysNullable, nonOrderArguments, children);
return new GroupConcat(distinct, alwaysNullable, children);
}
/**
@ -147,30 +120,7 @@ public class GroupConcat extends NullableAggregateFunction
*/
@Override
public GroupConcat withDistinctAndChildren(boolean distinct, List<Expression> children) {
Preconditions.checkArgument(children().size() >= 1);
boolean foundOrderExpr = false;
int firstOrderExrIndex = 0;
for (int i = 0; i < children.size(); i++) {
Expression child = children.get(i);
if (child instanceof OrderExpression) {
foundOrderExpr = true;
} else if (!foundOrderExpr) {
firstOrderExrIndex++;
} else {
throw new AnalysisException("invalid group_concat parameters: " + children);
}
}
List<OrderExpression> orders = (List) children.subList(firstOrderExrIndex, children.size());
if (firstOrderExrIndex == 1) {
return new GroupConcat(distinct, alwaysNullable,
children.get(0), orders.toArray(new OrderExpression[0]));
} else if (firstOrderExrIndex == 2) {
return new GroupConcat(distinct, alwaysNullable,
children.get(0), children.get(1), orders.toArray(new OrderExpression[0]));
} else {
throw new AnalysisException("group_concat requires one or two parameters: " + children);
}
return new GroupConcat(distinct, alwaysNullable, children);
}
@Override
@ -186,15 +136,34 @@ public class GroupConcat extends NullableAggregateFunction
public MultiDistinctGroupConcat convertToMultiDistinct() {
Preconditions.checkArgument(distinct,
"can't convert to multi_distinct_group_concat because there is no distinct args");
return new MultiDistinctGroupConcat(alwaysNullable, nonOrderArguments, children);
return new MultiDistinctGroupConcat(alwaysNullable, children);
}
// TODO: because of current be's limitation, we have to thow exception for now
// remove this after be support new method of multi distinct functions
private void checkArguments() {
if (isDistinct() && children().stream().anyMatch(expression -> expression instanceof OrderExpression)) {
throw new AnalysisException(
"group_concat don't support using distinct with order by together");
@Override
public boolean mustUseMultiDistinctAgg() {
return distinct && children.stream().anyMatch(OrderExpression.class::isInstance);
}
private int findOrderExprIndex(List<Expression> children) {
Preconditions.checkArgument(children().size() >= 1, "children's size should >= 1");
boolean foundOrderExpr = false;
int firstOrderExrIndex = 0;
for (int i = 0; i < children.size(); i++) {
Expression child = children.get(i);
if (child instanceof OrderExpression) {
foundOrderExpr = true;
} else if (!foundOrderExpr) {
firstOrderExrIndex++;
} else {
throw new AnalysisException(
"invalid multi_distinct_group_concat parameters: " + children);
}
}
if (firstOrderExrIndex > 2) {
throw new AnalysisException(
"multi_distinct_group_concat requires one or two parameters: " + children);
}
return firstOrderExrIndex;
}
}

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.analyzer.Unbound;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
@ -36,35 +37,35 @@ import java.util.List;
/** MultiDistinctCount */
public class MultiDistinctCount extends AggregateFunction
implements AlwaysNotNullable, ExplicitlyCastableSignature, MultiDistinction {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(AnyDataType.INSTANCE_WITHOUT_INDEX)
);
private final boolean mustUseMultiDistinctAgg;
// MultiDistinctCount is created in AggregateStrategies phase
// can't change getSignatures to use type coercion rule to add a cast expr
// because AggregateStrategies phase is after type coercion
public MultiDistinctCount(Expression arg0, Expression... varArgs) {
super("multi_distinct_count", true, ExpressionUtils.mergeArguments(arg0, varArgs).stream()
.map(arg -> arg.getDataType() instanceof DateLikeType ? new Cast(arg, BigIntType.INSTANCE) : arg)
.collect(ImmutableList.toImmutableList()));
this(false, arg0, varArgs);
}
public MultiDistinctCount(boolean distinct, Expression arg0, Expression... varArgs) {
super("multi_distinct_count", distinct, ExpressionUtils.mergeArguments(arg0, varArgs).stream()
.map(arg -> arg.getDataType() instanceof DateLikeType ? new Cast(arg, BigIntType.INSTANCE) : arg)
this(false, false, ExpressionUtils.mergeArguments(arg0, varArgs));
}
private MultiDistinctCount(boolean mustUseMultiDistinctAgg, boolean distinct, List<Expression> children) {
super("multi_distinct_count", false, children
.stream()
.map(arg -> !(arg instanceof Unbound) && arg.getDataType() instanceof DateLikeType
? new Cast(arg, BigIntType.INSTANCE) : arg)
.collect(ImmutableList.toImmutableList()));
this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg;
}
@Override
public MultiDistinctCount withDistinctAndChildren(boolean distinct, List<Expression> children) {
Preconditions.checkArgument(children.size() > 0);
if (children.size() > 1) {
return new MultiDistinctCount(distinct, children.get(0),
children.subList(1, children.size()).toArray(new Expression[0]));
} else {
return new MultiDistinctCount(distinct, children.get(0));
}
return new MultiDistinctCount(mustUseMultiDistinctAgg, distinct, children);
}
@Override
@ -76,4 +77,14 @@ public class MultiDistinctCount extends AggregateFunction
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}
@Override
public boolean mustUseMultiDistinctAgg() {
return mustUseMultiDistinctAgg;
}
@Override
public Expression withMustUseMultiDistinctAgg(boolean mustUseMultiDistinctAgg) {
return new MultiDistinctCount(mustUseMultiDistinctAgg, false, children);
}
}

View File

@ -37,7 +37,6 @@ import java.util.List;
/** MultiDistinctGroupConcat */
public class MultiDistinctGroupConcat extends NullableAggregateFunction
implements ExplicitlyCastableSignature, MultiDistinction {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).varArgs(VarcharType.SYSTEM_DEFAULT,
@ -57,49 +56,31 @@ public class MultiDistinctGroupConcat extends NullableAggregateFunction
FunctionSignature.ret(CharType.SYSTEM_DEFAULT).varArgs(CharType.SYSTEM_DEFAULT,
CharType.SYSTEM_DEFAULT, AnyDataType.INSTANCE_WITHOUT_INDEX));
private final int nonOrderArguments;
private final boolean mustUseMultiDistinctAgg;
/**
* constructor with 1 argument.
* constructor with 1 argument with other arguments.
*/
public MultiDistinctGroupConcat(boolean alwaysNullable, Expression arg,
OrderExpression... orders) {
super("multi_distinct_group_concat", true, alwaysNullable,
ExpressionUtils.mergeArguments(arg, orders));
this.nonOrderArguments = 1;
public MultiDistinctGroupConcat(Expression arg, Expression... others) {
this(false, arg, others);
}
/**
* constructor with 1 argument.
* constructor with argument list.
*/
public MultiDistinctGroupConcat(Expression arg, OrderExpression... orders) {
this(false, arg, orders);
public MultiDistinctGroupConcat(boolean alwaysNullable, List<Expression> args) {
this(false, alwaysNullable, args);
}
/**
* constructor with 2 arguments.
*/
public MultiDistinctGroupConcat(boolean alwaysNullable, Expression arg0,
Expression arg1, OrderExpression... orders) {
super("multi_distinct_group_concat", true, alwaysNullable,
ExpressionUtils.mergeArguments(arg0, arg1, orders));
this.nonOrderArguments = 2;
private MultiDistinctGroupConcat(boolean alwaysNullable, Expression arg,
Expression... others) {
this(alwaysNullable, ExpressionUtils.mergeArguments(arg, others));
}
/**
* constructor with 2 arguments.
*/
public MultiDistinctGroupConcat(Expression arg0, Expression arg1, OrderExpression... orders) {
this(false, arg0, arg1, orders);
}
/**
* constructor for always nullable.
*/
public MultiDistinctGroupConcat(boolean alwaysNullable, int nonOrderArguments,
List<Expression> args) {
super("multi_distinct_group_concat", true, alwaysNullable, args);
this.nonOrderArguments = nonOrderArguments;
private MultiDistinctGroupConcat(boolean mustUseMultiDistinctAgg, boolean alwaysNullable, List<Expression> args) {
super("multi_distinct_group_concat", false, alwaysNullable, args);
checkArguments(children);
this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg;
}
@Override
@ -110,7 +91,7 @@ public class MultiDistinctGroupConcat extends NullableAggregateFunction
@Override
public MultiDistinctGroupConcat withAlwaysNullable(boolean alwaysNullable) {
return new MultiDistinctGroupConcat(alwaysNullable, nonOrderArguments, children);
return new MultiDistinctGroupConcat(mustUseMultiDistinctAgg, alwaysNullable, children);
}
/**
@ -118,7 +99,22 @@ public class MultiDistinctGroupConcat extends NullableAggregateFunction
*/
@Override
public MultiDistinctGroupConcat withDistinctAndChildren(boolean distinct, List<Expression> children) {
Preconditions.checkArgument(children().size() >= 1);
checkArguments(children);
return new MultiDistinctGroupConcat(mustUseMultiDistinctAgg, alwaysNullable, children);
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitMultiDistinctGroupConcat(this, context);
}
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}
private void checkArguments(List<Expression> children) {
Preconditions.checkArgument(children().size() >= 1, "children's size should >= 1");
boolean foundOrderExpr = false;
int firstOrderExrIndex = 0;
for (int i = 0; i < children.size(); i++) {
@ -133,26 +129,19 @@ public class MultiDistinctGroupConcat extends NullableAggregateFunction
}
}
List<OrderExpression> orders = (List) children.subList(firstOrderExrIndex, children.size());
if (firstOrderExrIndex == 1) {
return new MultiDistinctGroupConcat(alwaysNullable, children.get(0),
orders.toArray(new OrderExpression[0]));
} else if (firstOrderExrIndex == 2) {
return new MultiDistinctGroupConcat(alwaysNullable, children.get(0),
children.get(1), orders.toArray(new OrderExpression[0]));
} else {
if (firstOrderExrIndex > 2) {
throw new AnalysisException(
"multi_distinct_group_concat requires one or two parameters: " + children);
}
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitMultiDistinctGroupConcat(this, context);
public boolean mustUseMultiDistinctAgg() {
return mustUseMultiDistinctAgg || children.stream().anyMatch(OrderExpression.class::isInstance);
}
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
public Expression withMustUseMultiDistinctAgg(boolean mustUseMultiDistinctAgg) {
return new MultiDistinctGroupConcat(mustUseMultiDistinctAgg, alwaysNullable, children);
}
}

View File

@ -43,16 +43,24 @@ public class MultiDistinctSum extends NullableAggregateFunction implements Unary
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(LargeIntType.INSTANCE)
);
private final boolean mustUseMultiDistinctAgg;
public MultiDistinctSum(Expression arg0) {
super("multi_distinct_sum", true, false, arg0);
this(false, arg0);
}
public MultiDistinctSum(boolean distinct, Expression arg0) {
super("multi_distinct_sum", true, false, arg0);
this(false, false, arg0);
}
public MultiDistinctSum(boolean distinct, boolean alwaysNullable, Expression arg0) {
super("multi_distinct_sum", true, alwaysNullable, arg0);
this(false, false, alwaysNullable, arg0);
}
private MultiDistinctSum(boolean mustUseMultiDistinctAgg, boolean distinct,
boolean alwaysNullable, Expression arg0) {
super("multi_distinct_sum", false, alwaysNullable, arg0);
this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg;
}
@Override
@ -74,17 +82,27 @@ public class MultiDistinctSum extends NullableAggregateFunction implements Unary
@Override
public NullableAggregateFunction withAlwaysNullable(boolean alwaysNullable) {
return new MultiDistinctSum(distinct, alwaysNullable, children.get(0));
return new MultiDistinctSum(mustUseMultiDistinctAgg, distinct, alwaysNullable, children.get(0));
}
@Override
public MultiDistinctSum withDistinctAndChildren(boolean distinct, List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new MultiDistinctSum(distinct, children.get(0));
return new MultiDistinctSum(mustUseMultiDistinctAgg, distinct, alwaysNullable, children.get(0));
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitMultiDistinctSum(this, context);
}
@Override
public boolean mustUseMultiDistinctAgg() {
return mustUseMultiDistinctAgg;
}
@Override
public Expression withMustUseMultiDistinctAgg(boolean mustUseMultiDistinctAgg) {
return new MultiDistinctSum(mustUseMultiDistinctAgg, false, alwaysNullable, children.get(0));
}
}

View File

@ -44,12 +44,19 @@ public class MultiDistinctSum0 extends AggregateFunction implements UnaryExpress
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(LargeIntType.INSTANCE)
);
private final boolean mustUseMultiDistinctAgg;
public MultiDistinctSum0(Expression arg0) {
super("multi_distinct_sum0", true, arg0);
this(false, arg0);
}
public MultiDistinctSum0(boolean distinct, Expression arg0) {
super("multi_distinct_sum0", true, arg0);
this(false, false, arg0);
}
private MultiDistinctSum0(boolean mustUseMultiDistinctAgg, boolean distinct, Expression arg0) {
super("multi_distinct_sum0", false, arg0);
this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg;
}
@Override
@ -72,11 +79,21 @@ public class MultiDistinctSum0 extends AggregateFunction implements UnaryExpress
@Override
public MultiDistinctSum0 withDistinctAndChildren(boolean distinct, List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new MultiDistinctSum0(distinct, children.get(0));
return new MultiDistinctSum0(mustUseMultiDistinctAgg, distinct, children.get(0));
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitMultiDistinctSum0(this, context);
}
@Override
public boolean mustUseMultiDistinctAgg() {
return mustUseMultiDistinctAgg;
}
@Override
public Expression withMustUseMultiDistinctAgg(boolean mustUseMultiDistinctAgg) {
return new MultiDistinctSum0(mustUseMultiDistinctAgg, false, children.get(0));
}
}

View File

@ -24,4 +24,5 @@ import org.apache.doris.nereids.trees.expressions.Expression;
* base class of multi-distinct agg function
*/
public interface MultiDistinction extends TreeNode<Expression> {
Expression withMustUseMultiDistinctAgg(boolean mustUseMultiDistinctAgg);
}

View File

@ -21,7 +21,9 @@ import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
@ -55,6 +57,12 @@ public class StateCombinator extends ScalarFunction
*/
public StateCombinator(List<Expression> arguments, AggregateFunction nested) {
super(nested.getName() + AggCombinerFunctionBuilder.STATE_SUFFIX, arguments);
for (Expression arg : arguments) {
if (arg instanceof OrderExpression) {
throw new AnalysisException(String
.format("%s_state doesn't support order by expression", nested.getName()));
}
}
this.nested = Objects.requireNonNull(nested, "nested can not be null");
this.returnType = new AggStateType(nested.getName(), arguments.stream().map(arg -> {

View File

@ -90,4 +90,14 @@ public interface Aggregate<CHILD_TYPE extends Plan> extends UnaryPlan<CHILD_TYPE
}
return hasDistinctArguments.get();
}
/** mustUseMultiDistinctAgg */
default boolean mustUseMultiDistinctAgg() {
for (AggregateFunction aggregateFunction : getAggregateFunctions()) {
if (aggregateFunction.mustUseMultiDistinctAgg()) {
return true;
}
}
return false;
}
}