[fix](nereids)support group_concat with distinct and order by (#38871)
## Proposed changes pick from master https://github.com/apache/doris/pull/38080 <!--Describe your changes.-->
This commit is contained in:
@ -624,7 +624,9 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
|
||||
|
||||
@Override
|
||||
public Expr visitStateCombinator(StateCombinator combinator, PlanTranslatorContext context) {
|
||||
List<Expr> arguments = combinator.getArguments().stream().map(arg -> arg.accept(this, context))
|
||||
List<Expr> arguments = combinator.getArguments().stream().map(arg -> arg instanceof OrderExpression
|
||||
? translateOrderExpression((OrderExpression) arg, context).getExpr()
|
||||
: arg.accept(this, context))
|
||||
.collect(Collectors.toList());
|
||||
return Function.convertToStateCombinator(
|
||||
new FunctionCallExpr(visitAggregateFunction(combinator.getNestedFunction(), context).getFn(),
|
||||
|
||||
@ -274,7 +274,6 @@ import org.apache.doris.nereids.trees.expressions.WindowExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.WindowFrame;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Function;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.Array;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRange;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRangeDayUnit;
|
||||
@ -2096,11 +2095,10 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
|
||||
return ParserUtils.withOrigin(ctx, () -> {
|
||||
String functionName = ctx.functionIdentifier().functionNameIdentifier().getText();
|
||||
boolean isDistinct = ctx.DISTINCT() != null;
|
||||
List<Expression> params = visit(ctx.expression(), Expression.class);
|
||||
List<Expression> params = Lists.newArrayList();
|
||||
params.addAll(visit(ctx.expression(), Expression.class));
|
||||
List<OrderKey> orderKeys = visit(ctx.sortItem(), OrderKey.class);
|
||||
if (!orderKeys.isEmpty()) {
|
||||
return parseFunctionWithOrderKeys(functionName, isDistinct, params, orderKeys, ctx);
|
||||
}
|
||||
params.addAll(orderKeys.stream().map(OrderExpression::new).collect(Collectors.toList()));
|
||||
|
||||
List<UnboundStar> unboundStars = ExpressionUtils.collectAll(params, UnboundStar.class::isInstance);
|
||||
if (!unboundStars.isEmpty()) {
|
||||
@ -3471,23 +3469,6 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
|
||||
return new StructField(ctx.identifier().getText(), typedVisit(ctx.dataType()), true, comment);
|
||||
}
|
||||
|
||||
private Expression parseFunctionWithOrderKeys(String functionName, boolean isDistinct,
|
||||
List<Expression> params, List<OrderKey> orderKeys, ParserRuleContext ctx) {
|
||||
if (functionName.equalsIgnoreCase("group_concat")) {
|
||||
OrderExpression[] orderExpressions = orderKeys.stream()
|
||||
.map(OrderExpression::new)
|
||||
.toArray(OrderExpression[]::new);
|
||||
if (params.size() == 1) {
|
||||
return new GroupConcat(isDistinct, params.get(0), orderExpressions);
|
||||
} else if (params.size() == 2) {
|
||||
return new GroupConcat(isDistinct, params.get(0), params.get(1), orderExpressions);
|
||||
} else {
|
||||
throw new ParseException("group_concat requires one or two parameters: " + params, ctx);
|
||||
}
|
||||
}
|
||||
throw new ParseException("Unsupported function with order expressions" + ctx.getText(), ctx);
|
||||
}
|
||||
|
||||
private String parseConstant(ConstantContext context) {
|
||||
Object constant = visit(context);
|
||||
if (constant instanceof Literal && ((Literal) constant).isStringLikeLiteral()) {
|
||||
|
||||
@ -159,12 +159,15 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Boolean, Void> {
|
||||
distinctChildColumns, ShuffleType.REQUIRE);
|
||||
if ((!groupByColumns.isEmpty() && distributionSpecHash.satisfy(groupByRequire))
|
||||
|| (groupByColumns.isEmpty() && distributionSpecHash.satisfy(distinctChildRequire))) {
|
||||
return false;
|
||||
if (!agg.mustUseMultiDistinctAgg()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
// if distinct without group by key, we prefer three or four stage distinct agg
|
||||
// because the second phase of multi-distinct only have one instance, and it is slow generally.
|
||||
if (agg.getOutputExpressions().size() == 1 && agg.getGroupByExpressions().isEmpty()) {
|
||||
if (agg.getOutputExpressions().size() == 1 && agg.getGroupByExpressions().isEmpty()
|
||||
&& !agg.mustUseMultiDistinctAgg()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@ -21,6 +21,7 @@ import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.OrderExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.WindowExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
|
||||
@ -148,7 +149,7 @@ public class CheckAnalysis implements AnalysisRuleFactory {
|
||||
continue;
|
||||
}
|
||||
for (int i = 1; i < func.arity(); i++) {
|
||||
if (!func.child(i).getInputSlots().isEmpty()) {
|
||||
if (!func.child(i).getInputSlots().isEmpty() && !(func.child(i) instanceof OrderExpression)) {
|
||||
// think about group_concat(distinct col_1, ',')
|
||||
distinctMultiColumns = true;
|
||||
break;
|
||||
|
||||
@ -403,6 +403,7 @@ public class ExpressionAnalyzer extends SubExprAnalyzer<ExpressionRewriteContext
|
||||
}
|
||||
|
||||
Pair<? extends Expression, ? extends BoundFunction> buildResult = builder.build(functionName, arguments);
|
||||
buildResult.second.checkOrderExprIsValid();
|
||||
Optional<SqlCacheContext> sqlCacheContext = Optional.empty();
|
||||
if (wantToParseSqlFromSqlCache) {
|
||||
StatementContext statementContext = context.cascadesContext.getStatementContext();
|
||||
|
||||
@ -26,12 +26,14 @@ import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.OrderExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotNotFromChildren;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
|
||||
import org.apache.doris.nereids.trees.expressions.WindowExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinction;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
@ -54,6 +56,7 @@ import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
/**
|
||||
* normalize aggregate's group keys and AggregateFunction's child to SlotReference
|
||||
@ -170,6 +173,7 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
|
||||
// should not push down literal under aggregate
|
||||
// e.g. group_concat(distinct xxx, ','), the ',' literal show stay in aggregate
|
||||
.filter(arg -> !(arg instanceof Literal))
|
||||
.flatMap(arg -> arg instanceof OrderExpression ? arg.getInputSlots().stream() : Stream.of(arg))
|
||||
.collect(
|
||||
Collectors.groupingBy(
|
||||
child -> !(child instanceof SlotReference),
|
||||
@ -255,7 +259,15 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
|
||||
normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs)
|
||||
);
|
||||
// create new agg node
|
||||
ImmutableList<NamedExpression> normalizedAggOutput = normalizedAggOutputBuilder.build();
|
||||
ImmutableList<NamedExpression> aggOutput = normalizedAggOutputBuilder.build();
|
||||
ImmutableList.Builder<NamedExpression> newAggOutputBuilder
|
||||
= ImmutableList.builderWithExpectedSize(aggOutput.size());
|
||||
for (NamedExpression output : aggOutput) {
|
||||
Expression rewrittenExpr = output.rewriteDownShortCircuit(
|
||||
e -> e instanceof MultiDistinction ? ((MultiDistinction) e).withMustUseMultiDistinctAgg(true) : e);
|
||||
newAggOutputBuilder.add((NamedExpression) rewrittenExpr);
|
||||
}
|
||||
ImmutableList<NamedExpression> normalizedAggOutput = newAggOutputBuilder.build();
|
||||
LogicalAggregate<?> newAggregate =
|
||||
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan);
|
||||
|
||||
|
||||
@ -40,6 +40,7 @@ import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.IsNull;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.OrderExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
|
||||
@ -296,6 +297,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build(
|
||||
basePattern
|
||||
.when(agg -> agg.getDistinctArguments().size() == 1)
|
||||
.whenNot(agg -> agg.mustUseMultiDistinctAgg())
|
||||
.thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
|
||||
),
|
||||
/*
|
||||
@ -319,6 +321,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
basePattern
|
||||
.when(agg -> agg.getDistinctArguments().size() == 1)
|
||||
.when(agg -> agg.getGroupByExpressions().isEmpty())
|
||||
.whenNot(agg -> agg.mustUseMultiDistinctAgg())
|
||||
.thenApplyMulti(ctx -> {
|
||||
Function<List<Expression>, RequireProperties> secondPhaseRequireDistinctHash =
|
||||
groupByAndDistinct -> RequireProperties.of(
|
||||
@ -1940,7 +1943,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
}
|
||||
for (int i = 1; i < func.arity(); i++) {
|
||||
// think about group_concat(distinct col_1, ',')
|
||||
if (!func.child(i).getInputSlots().isEmpty()) {
|
||||
if (!(func.child(i) instanceof OrderExpression) && !func.child(i).getInputSlots().isEmpty()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@ -124,6 +124,12 @@ public class WindowExpression extends Expression {
|
||||
.orElseGet(() -> new WindowExpression(function, partitionKeys, orderKeys));
|
||||
}
|
||||
|
||||
public WindowExpression withFunctionPartitionKeysOrderKeys(Expression function,
|
||||
List<Expression> partitionKeys, List<OrderExpression> orderKeys) {
|
||||
return windowFrame.map(frame -> new WindowExpression(function, partitionKeys, orderKeys, frame))
|
||||
.orElseGet(() -> new WindowExpression(function, partitionKeys, orderKeys));
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean nullable() {
|
||||
return function.nullable();
|
||||
|
||||
@ -127,6 +127,10 @@ public class AggCombinerFunctionBuilder extends FunctionBuilder {
|
||||
String nestedName = getNestedName(name);
|
||||
if (combinatorSuffix.equalsIgnoreCase(STATE)) {
|
||||
AggregateFunction nestedFunction = buildState(nestedName, arguments);
|
||||
// distinct will be passed as 1st boolean true arg. remove it
|
||||
if (!arguments.isEmpty() && arguments.get(0) instanceof Boolean && (Boolean) arguments.get(0)) {
|
||||
arguments = arguments.subList(1, arguments.size());
|
||||
}
|
||||
return Pair.of(new StateCombinator((List<Expression>) arguments, nestedFunction), nestedFunction);
|
||||
} else if (combinatorSuffix.equalsIgnoreCase(MERGE)) {
|
||||
AggregateFunction nestedFunction = buildMergeOrUnion(nestedName, arguments);
|
||||
|
||||
@ -18,8 +18,12 @@
|
||||
package org.apache.doris.nereids.trees.expressions.functions;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.exceptions.UnboundException;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.OrderExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctGroupConcat;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
@ -98,4 +102,17 @@ public abstract class BoundFunction extends Function implements ComputeSignature
|
||||
.collect(Collectors.joining(", "));
|
||||
return getName() + "(" + args + ")";
|
||||
}
|
||||
|
||||
/**
|
||||
* checkOrderExprIsValid.
|
||||
*/
|
||||
public void checkOrderExprIsValid() {
|
||||
for (Expression child : children) {
|
||||
if (child instanceof OrderExpression
|
||||
&& !(this instanceof GroupConcat || this instanceof MultiDistinctGroupConcat)) {
|
||||
throw new AnalysisException(
|
||||
String.format("%s doesn't support order by expression", getName()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -134,4 +134,8 @@ public abstract class AggregateFunction extends BoundFunction implements Expects
|
||||
public List<Expression> getDistinctArguments() {
|
||||
return distinct ? getArguments() : ImmutableList.of();
|
||||
}
|
||||
|
||||
public boolean mustUseMultiDistinctAgg() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@ -52,57 +52,30 @@ public class GroupConcat extends NullableAggregateFunction
|
||||
/**
|
||||
* constructor with 1 argument.
|
||||
*/
|
||||
public GroupConcat(boolean distinct, boolean alwaysNullable, Expression arg, OrderExpression... orders) {
|
||||
super("group_concat", distinct, alwaysNullable, ExpressionUtils.mergeArguments(arg, orders));
|
||||
this.nonOrderArguments = 1;
|
||||
checkArguments();
|
||||
public GroupConcat(boolean distinct, boolean alwaysNullable, Expression arg, Expression... others) {
|
||||
this(distinct, alwaysNullable, ExpressionUtils.mergeArguments(arg, others));
|
||||
}
|
||||
|
||||
/**
|
||||
* constructor with 1 argument.
|
||||
*/
|
||||
public GroupConcat(boolean distinct, Expression arg, OrderExpression... orders) {
|
||||
this(distinct, false, arg, orders);
|
||||
public GroupConcat(boolean distinct, Expression arg, Expression... others) {
|
||||
this(distinct, false, arg, others);
|
||||
}
|
||||
|
||||
/**
|
||||
* constructor with 1 argument, use for function search.
|
||||
*/
|
||||
public GroupConcat(Expression arg, OrderExpression... orders) {
|
||||
this(false, arg, orders);
|
||||
}
|
||||
|
||||
/**
|
||||
* constructor with 2 arguments.
|
||||
*/
|
||||
public GroupConcat(boolean distinct, boolean alwaysNullable,
|
||||
Expression arg0, Expression arg1, OrderExpression... orders) {
|
||||
super("group_concat", distinct, alwaysNullable, ExpressionUtils.mergeArguments(arg0, arg1, orders));
|
||||
this.nonOrderArguments = 2;
|
||||
checkArguments();
|
||||
}
|
||||
|
||||
/**
|
||||
* constructor with 2 arguments.
|
||||
*/
|
||||
public GroupConcat(boolean distinct, Expression arg0, Expression arg1, OrderExpression... orders) {
|
||||
this(distinct, false, arg0, arg1, orders);
|
||||
}
|
||||
|
||||
/**
|
||||
* constructor with 2 arguments, use for function search.
|
||||
*/
|
||||
public GroupConcat(Expression arg0, Expression arg1, OrderExpression... orders) {
|
||||
this(false, arg0, arg1, orders);
|
||||
public GroupConcat(Expression arg, Expression... others) {
|
||||
this(false, arg, others);
|
||||
}
|
||||
|
||||
/**
|
||||
* constructor for always nullable.
|
||||
*/
|
||||
public GroupConcat(boolean distinct, boolean alwaysNullable, int nonOrderArguments, List<Expression> args) {
|
||||
public GroupConcat(boolean distinct, boolean alwaysNullable, List<Expression> args) {
|
||||
super("group_concat", distinct, alwaysNullable, args);
|
||||
this.nonOrderArguments = nonOrderArguments;
|
||||
checkArguments();
|
||||
this.nonOrderArguments = findOrderExprIndex(children);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -139,7 +112,7 @@ public class GroupConcat extends NullableAggregateFunction
|
||||
|
||||
@Override
|
||||
public GroupConcat withAlwaysNullable(boolean alwaysNullable) {
|
||||
return new GroupConcat(distinct, alwaysNullable, nonOrderArguments, children);
|
||||
return new GroupConcat(distinct, alwaysNullable, children);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -147,30 +120,7 @@ public class GroupConcat extends NullableAggregateFunction
|
||||
*/
|
||||
@Override
|
||||
public GroupConcat withDistinctAndChildren(boolean distinct, List<Expression> children) {
|
||||
Preconditions.checkArgument(children().size() >= 1);
|
||||
boolean foundOrderExpr = false;
|
||||
int firstOrderExrIndex = 0;
|
||||
for (int i = 0; i < children.size(); i++) {
|
||||
Expression child = children.get(i);
|
||||
if (child instanceof OrderExpression) {
|
||||
foundOrderExpr = true;
|
||||
} else if (!foundOrderExpr) {
|
||||
firstOrderExrIndex++;
|
||||
} else {
|
||||
throw new AnalysisException("invalid group_concat parameters: " + children);
|
||||
}
|
||||
}
|
||||
|
||||
List<OrderExpression> orders = (List) children.subList(firstOrderExrIndex, children.size());
|
||||
if (firstOrderExrIndex == 1) {
|
||||
return new GroupConcat(distinct, alwaysNullable,
|
||||
children.get(0), orders.toArray(new OrderExpression[0]));
|
||||
} else if (firstOrderExrIndex == 2) {
|
||||
return new GroupConcat(distinct, alwaysNullable,
|
||||
children.get(0), children.get(1), orders.toArray(new OrderExpression[0]));
|
||||
} else {
|
||||
throw new AnalysisException("group_concat requires one or two parameters: " + children);
|
||||
}
|
||||
return new GroupConcat(distinct, alwaysNullable, children);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -186,15 +136,34 @@ public class GroupConcat extends NullableAggregateFunction
|
||||
public MultiDistinctGroupConcat convertToMultiDistinct() {
|
||||
Preconditions.checkArgument(distinct,
|
||||
"can't convert to multi_distinct_group_concat because there is no distinct args");
|
||||
return new MultiDistinctGroupConcat(alwaysNullable, nonOrderArguments, children);
|
||||
return new MultiDistinctGroupConcat(alwaysNullable, children);
|
||||
}
|
||||
|
||||
// TODO: because of current be's limitation, we have to thow exception for now
|
||||
// remove this after be support new method of multi distinct functions
|
||||
private void checkArguments() {
|
||||
if (isDistinct() && children().stream().anyMatch(expression -> expression instanceof OrderExpression)) {
|
||||
throw new AnalysisException(
|
||||
"group_concat don't support using distinct with order by together");
|
||||
@Override
|
||||
public boolean mustUseMultiDistinctAgg() {
|
||||
return distinct && children.stream().anyMatch(OrderExpression.class::isInstance);
|
||||
}
|
||||
|
||||
private int findOrderExprIndex(List<Expression> children) {
|
||||
Preconditions.checkArgument(children().size() >= 1, "children's size should >= 1");
|
||||
boolean foundOrderExpr = false;
|
||||
int firstOrderExrIndex = 0;
|
||||
for (int i = 0; i < children.size(); i++) {
|
||||
Expression child = children.get(i);
|
||||
if (child instanceof OrderExpression) {
|
||||
foundOrderExpr = true;
|
||||
} else if (!foundOrderExpr) {
|
||||
firstOrderExrIndex++;
|
||||
} else {
|
||||
throw new AnalysisException(
|
||||
"invalid multi_distinct_group_concat parameters: " + children);
|
||||
}
|
||||
}
|
||||
|
||||
if (firstOrderExrIndex > 2) {
|
||||
throw new AnalysisException(
|
||||
"multi_distinct_group_concat requires one or two parameters: " + children);
|
||||
}
|
||||
return firstOrderExrIndex;
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.trees.expressions.functions.agg;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.analyzer.Unbound;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
|
||||
@ -36,35 +37,35 @@ import java.util.List;
|
||||
/** MultiDistinctCount */
|
||||
public class MultiDistinctCount extends AggregateFunction
|
||||
implements AlwaysNotNullable, ExplicitlyCastableSignature, MultiDistinction {
|
||||
|
||||
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
|
||||
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(AnyDataType.INSTANCE_WITHOUT_INDEX)
|
||||
);
|
||||
private final boolean mustUseMultiDistinctAgg;
|
||||
|
||||
// MultiDistinctCount is created in AggregateStrategies phase
|
||||
// can't change getSignatures to use type coercion rule to add a cast expr
|
||||
// because AggregateStrategies phase is after type coercion
|
||||
public MultiDistinctCount(Expression arg0, Expression... varArgs) {
|
||||
super("multi_distinct_count", true, ExpressionUtils.mergeArguments(arg0, varArgs).stream()
|
||||
.map(arg -> arg.getDataType() instanceof DateLikeType ? new Cast(arg, BigIntType.INSTANCE) : arg)
|
||||
.collect(ImmutableList.toImmutableList()));
|
||||
this(false, arg0, varArgs);
|
||||
}
|
||||
|
||||
public MultiDistinctCount(boolean distinct, Expression arg0, Expression... varArgs) {
|
||||
super("multi_distinct_count", distinct, ExpressionUtils.mergeArguments(arg0, varArgs).stream()
|
||||
.map(arg -> arg.getDataType() instanceof DateLikeType ? new Cast(arg, BigIntType.INSTANCE) : arg)
|
||||
this(false, false, ExpressionUtils.mergeArguments(arg0, varArgs));
|
||||
}
|
||||
|
||||
private MultiDistinctCount(boolean mustUseMultiDistinctAgg, boolean distinct, List<Expression> children) {
|
||||
super("multi_distinct_count", false, children
|
||||
.stream()
|
||||
.map(arg -> !(arg instanceof Unbound) && arg.getDataType() instanceof DateLikeType
|
||||
? new Cast(arg, BigIntType.INSTANCE) : arg)
|
||||
.collect(ImmutableList.toImmutableList()));
|
||||
this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MultiDistinctCount withDistinctAndChildren(boolean distinct, List<Expression> children) {
|
||||
Preconditions.checkArgument(children.size() > 0);
|
||||
if (children.size() > 1) {
|
||||
return new MultiDistinctCount(distinct, children.get(0),
|
||||
children.subList(1, children.size()).toArray(new Expression[0]));
|
||||
} else {
|
||||
return new MultiDistinctCount(distinct, children.get(0));
|
||||
}
|
||||
return new MultiDistinctCount(mustUseMultiDistinctAgg, distinct, children);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -76,4 +77,14 @@ public class MultiDistinctCount extends AggregateFunction
|
||||
public List<FunctionSignature> getSignatures() {
|
||||
return SIGNATURES;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean mustUseMultiDistinctAgg() {
|
||||
return mustUseMultiDistinctAgg;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression withMustUseMultiDistinctAgg(boolean mustUseMultiDistinctAgg) {
|
||||
return new MultiDistinctCount(mustUseMultiDistinctAgg, false, children);
|
||||
}
|
||||
}
|
||||
|
||||
@ -37,7 +37,6 @@ import java.util.List;
|
||||
/** MultiDistinctGroupConcat */
|
||||
public class MultiDistinctGroupConcat extends NullableAggregateFunction
|
||||
implements ExplicitlyCastableSignature, MultiDistinction {
|
||||
|
||||
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarcharType.SYSTEM_DEFAULT),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).varArgs(VarcharType.SYSTEM_DEFAULT,
|
||||
@ -57,49 +56,31 @@ public class MultiDistinctGroupConcat extends NullableAggregateFunction
|
||||
FunctionSignature.ret(CharType.SYSTEM_DEFAULT).varArgs(CharType.SYSTEM_DEFAULT,
|
||||
CharType.SYSTEM_DEFAULT, AnyDataType.INSTANCE_WITHOUT_INDEX));
|
||||
|
||||
private final int nonOrderArguments;
|
||||
private final boolean mustUseMultiDistinctAgg;
|
||||
|
||||
/**
|
||||
* constructor with 1 argument.
|
||||
* constructor with 1 argument with other arguments.
|
||||
*/
|
||||
public MultiDistinctGroupConcat(boolean alwaysNullable, Expression arg,
|
||||
OrderExpression... orders) {
|
||||
super("multi_distinct_group_concat", true, alwaysNullable,
|
||||
ExpressionUtils.mergeArguments(arg, orders));
|
||||
this.nonOrderArguments = 1;
|
||||
public MultiDistinctGroupConcat(Expression arg, Expression... others) {
|
||||
this(false, arg, others);
|
||||
}
|
||||
|
||||
/**
|
||||
* constructor with 1 argument.
|
||||
* constructor with argument list.
|
||||
*/
|
||||
public MultiDistinctGroupConcat(Expression arg, OrderExpression... orders) {
|
||||
this(false, arg, orders);
|
||||
public MultiDistinctGroupConcat(boolean alwaysNullable, List<Expression> args) {
|
||||
this(false, alwaysNullable, args);
|
||||
}
|
||||
|
||||
/**
|
||||
* constructor with 2 arguments.
|
||||
*/
|
||||
public MultiDistinctGroupConcat(boolean alwaysNullable, Expression arg0,
|
||||
Expression arg1, OrderExpression... orders) {
|
||||
super("multi_distinct_group_concat", true, alwaysNullable,
|
||||
ExpressionUtils.mergeArguments(arg0, arg1, orders));
|
||||
this.nonOrderArguments = 2;
|
||||
private MultiDistinctGroupConcat(boolean alwaysNullable, Expression arg,
|
||||
Expression... others) {
|
||||
this(alwaysNullable, ExpressionUtils.mergeArguments(arg, others));
|
||||
}
|
||||
|
||||
/**
|
||||
* constructor with 2 arguments.
|
||||
*/
|
||||
public MultiDistinctGroupConcat(Expression arg0, Expression arg1, OrderExpression... orders) {
|
||||
this(false, arg0, arg1, orders);
|
||||
}
|
||||
|
||||
/**
|
||||
* constructor for always nullable.
|
||||
*/
|
||||
public MultiDistinctGroupConcat(boolean alwaysNullable, int nonOrderArguments,
|
||||
List<Expression> args) {
|
||||
super("multi_distinct_group_concat", true, alwaysNullable, args);
|
||||
this.nonOrderArguments = nonOrderArguments;
|
||||
private MultiDistinctGroupConcat(boolean mustUseMultiDistinctAgg, boolean alwaysNullable, List<Expression> args) {
|
||||
super("multi_distinct_group_concat", false, alwaysNullable, args);
|
||||
checkArguments(children);
|
||||
this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -110,7 +91,7 @@ public class MultiDistinctGroupConcat extends NullableAggregateFunction
|
||||
|
||||
@Override
|
||||
public MultiDistinctGroupConcat withAlwaysNullable(boolean alwaysNullable) {
|
||||
return new MultiDistinctGroupConcat(alwaysNullable, nonOrderArguments, children);
|
||||
return new MultiDistinctGroupConcat(mustUseMultiDistinctAgg, alwaysNullable, children);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -118,7 +99,22 @@ public class MultiDistinctGroupConcat extends NullableAggregateFunction
|
||||
*/
|
||||
@Override
|
||||
public MultiDistinctGroupConcat withDistinctAndChildren(boolean distinct, List<Expression> children) {
|
||||
Preconditions.checkArgument(children().size() >= 1);
|
||||
checkArguments(children);
|
||||
return new MultiDistinctGroupConcat(mustUseMultiDistinctAgg, alwaysNullable, children);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitMultiDistinctGroupConcat(this, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<FunctionSignature> getSignatures() {
|
||||
return SIGNATURES;
|
||||
}
|
||||
|
||||
private void checkArguments(List<Expression> children) {
|
||||
Preconditions.checkArgument(children().size() >= 1, "children's size should >= 1");
|
||||
boolean foundOrderExpr = false;
|
||||
int firstOrderExrIndex = 0;
|
||||
for (int i = 0; i < children.size(); i++) {
|
||||
@ -133,26 +129,19 @@ public class MultiDistinctGroupConcat extends NullableAggregateFunction
|
||||
}
|
||||
}
|
||||
|
||||
List<OrderExpression> orders = (List) children.subList(firstOrderExrIndex, children.size());
|
||||
if (firstOrderExrIndex == 1) {
|
||||
return new MultiDistinctGroupConcat(alwaysNullable, children.get(0),
|
||||
orders.toArray(new OrderExpression[0]));
|
||||
} else if (firstOrderExrIndex == 2) {
|
||||
return new MultiDistinctGroupConcat(alwaysNullable, children.get(0),
|
||||
children.get(1), orders.toArray(new OrderExpression[0]));
|
||||
} else {
|
||||
if (firstOrderExrIndex > 2) {
|
||||
throw new AnalysisException(
|
||||
"multi_distinct_group_concat requires one or two parameters: " + children);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitMultiDistinctGroupConcat(this, context);
|
||||
public boolean mustUseMultiDistinctAgg() {
|
||||
return mustUseMultiDistinctAgg || children.stream().anyMatch(OrderExpression.class::isInstance);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<FunctionSignature> getSignatures() {
|
||||
return SIGNATURES;
|
||||
public Expression withMustUseMultiDistinctAgg(boolean mustUseMultiDistinctAgg) {
|
||||
return new MultiDistinctGroupConcat(mustUseMultiDistinctAgg, alwaysNullable, children);
|
||||
}
|
||||
}
|
||||
|
||||
@ -43,16 +43,24 @@ public class MultiDistinctSum extends NullableAggregateFunction implements Unary
|
||||
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(LargeIntType.INSTANCE)
|
||||
);
|
||||
|
||||
private final boolean mustUseMultiDistinctAgg;
|
||||
|
||||
public MultiDistinctSum(Expression arg0) {
|
||||
super("multi_distinct_sum", true, false, arg0);
|
||||
this(false, arg0);
|
||||
}
|
||||
|
||||
public MultiDistinctSum(boolean distinct, Expression arg0) {
|
||||
super("multi_distinct_sum", true, false, arg0);
|
||||
this(false, false, arg0);
|
||||
}
|
||||
|
||||
public MultiDistinctSum(boolean distinct, boolean alwaysNullable, Expression arg0) {
|
||||
super("multi_distinct_sum", true, alwaysNullable, arg0);
|
||||
this(false, false, alwaysNullable, arg0);
|
||||
}
|
||||
|
||||
private MultiDistinctSum(boolean mustUseMultiDistinctAgg, boolean distinct,
|
||||
boolean alwaysNullable, Expression arg0) {
|
||||
super("multi_distinct_sum", false, alwaysNullable, arg0);
|
||||
this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -74,17 +82,27 @@ public class MultiDistinctSum extends NullableAggregateFunction implements Unary
|
||||
|
||||
@Override
|
||||
public NullableAggregateFunction withAlwaysNullable(boolean alwaysNullable) {
|
||||
return new MultiDistinctSum(distinct, alwaysNullable, children.get(0));
|
||||
return new MultiDistinctSum(mustUseMultiDistinctAgg, distinct, alwaysNullable, children.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public MultiDistinctSum withDistinctAndChildren(boolean distinct, List<Expression> children) {
|
||||
Preconditions.checkArgument(children.size() == 1);
|
||||
return new MultiDistinctSum(distinct, children.get(0));
|
||||
return new MultiDistinctSum(mustUseMultiDistinctAgg, distinct, alwaysNullable, children.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitMultiDistinctSum(this, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean mustUseMultiDistinctAgg() {
|
||||
return mustUseMultiDistinctAgg;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression withMustUseMultiDistinctAgg(boolean mustUseMultiDistinctAgg) {
|
||||
return new MultiDistinctSum(mustUseMultiDistinctAgg, false, alwaysNullable, children.get(0));
|
||||
}
|
||||
}
|
||||
|
||||
@ -44,12 +44,19 @@ public class MultiDistinctSum0 extends AggregateFunction implements UnaryExpress
|
||||
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(LargeIntType.INSTANCE)
|
||||
);
|
||||
|
||||
private final boolean mustUseMultiDistinctAgg;
|
||||
|
||||
public MultiDistinctSum0(Expression arg0) {
|
||||
super("multi_distinct_sum0", true, arg0);
|
||||
this(false, arg0);
|
||||
}
|
||||
|
||||
public MultiDistinctSum0(boolean distinct, Expression arg0) {
|
||||
super("multi_distinct_sum0", true, arg0);
|
||||
this(false, false, arg0);
|
||||
}
|
||||
|
||||
private MultiDistinctSum0(boolean mustUseMultiDistinctAgg, boolean distinct, Expression arg0) {
|
||||
super("multi_distinct_sum0", false, arg0);
|
||||
this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -72,11 +79,21 @@ public class MultiDistinctSum0 extends AggregateFunction implements UnaryExpress
|
||||
@Override
|
||||
public MultiDistinctSum0 withDistinctAndChildren(boolean distinct, List<Expression> children) {
|
||||
Preconditions.checkArgument(children.size() == 1);
|
||||
return new MultiDistinctSum0(distinct, children.get(0));
|
||||
return new MultiDistinctSum0(mustUseMultiDistinctAgg, distinct, children.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitMultiDistinctSum0(this, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean mustUseMultiDistinctAgg() {
|
||||
return mustUseMultiDistinctAgg;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression withMustUseMultiDistinctAgg(boolean mustUseMultiDistinctAgg) {
|
||||
return new MultiDistinctSum0(mustUseMultiDistinctAgg, false, children.get(0));
|
||||
}
|
||||
}
|
||||
|
||||
@ -24,4 +24,5 @@ import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
* base class of multi-distinct agg function
|
||||
*/
|
||||
public interface MultiDistinction extends TreeNode<Expression> {
|
||||
Expression withMustUseMultiDistinctAgg(boolean mustUseMultiDistinctAgg);
|
||||
}
|
||||
|
||||
@ -21,7 +21,9 @@ import org.apache.doris.catalog.Env;
|
||||
import org.apache.doris.catalog.FunctionRegistry;
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.OrderExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
|
||||
@ -55,6 +57,12 @@ public class StateCombinator extends ScalarFunction
|
||||
*/
|
||||
public StateCombinator(List<Expression> arguments, AggregateFunction nested) {
|
||||
super(nested.getName() + AggCombinerFunctionBuilder.STATE_SUFFIX, arguments);
|
||||
for (Expression arg : arguments) {
|
||||
if (arg instanceof OrderExpression) {
|
||||
throw new AnalysisException(String
|
||||
.format("%s_state doesn't support order by expression", nested.getName()));
|
||||
}
|
||||
}
|
||||
|
||||
this.nested = Objects.requireNonNull(nested, "nested can not be null");
|
||||
this.returnType = new AggStateType(nested.getName(), arguments.stream().map(arg -> {
|
||||
|
||||
@ -90,4 +90,14 @@ public interface Aggregate<CHILD_TYPE extends Plan> extends UnaryPlan<CHILD_TYPE
|
||||
}
|
||||
return hasDistinctArguments.get();
|
||||
}
|
||||
|
||||
/** mustUseMultiDistinctAgg */
|
||||
default boolean mustUseMultiDistinctAgg() {
|
||||
for (AggregateFunction aggregateFunction : getAggregateFunctions()) {
|
||||
if (aggregateFunction.mustUseMultiDistinctAgg()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user