[fix](Nereids) wrong result of group_concat with order by or null args (#16081)

1. signatures without order element are wrong
2. signature with one arg is miss
3. group_concat should be NullableAggregateFunction
4. fold constant on fe should not fold NullableAggregateFunction with null arg

TODO
1. reorder rewrite rules, and then only forbid fold constant on NullableAggregateFunction with alwaysNullable == true
This commit is contained in:
morrySnow
2023-01-19 11:22:30 +08:00
committed by GitHub
parent e846e8c0fd
commit abdf56bfa5
8 changed files with 130 additions and 35 deletions

View File

@ -44,6 +44,7 @@ import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Array;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ConnectionId;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentUser;
@ -79,7 +80,9 @@ public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule {
}
expr = rewriteChildren(expr, ctx);
if (expr instanceof PropagateNullable && argsHasNullLiteral(expr)) {
if (expr instanceof PropagateNullable
&& !(expr instanceof NullableAggregateFunction)
&& argsHasNullLiteral(expr)) {
return new NullLiteral(expr.getDataType());
}
return expr.accept(this, ctx);

View File

@ -22,7 +22,6 @@ import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.VarcharType;
@ -33,16 +32,17 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.stream.Collectors;
/**
* AggregateFunction 'group_concat'. This class is generated by GenerateFunction.
*/
public class GroupConcat extends AggregateFunction
implements ExplicitlyCastableSignature, PropagateNullable {
public class GroupConcat extends NullableAggregateFunction
implements ExplicitlyCastableSignature {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).varArgs(VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.varArgs(VarcharType.SYSTEM_DEFAULT, AnyDataType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
@ -54,8 +54,8 @@ public class GroupConcat extends AggregateFunction
/**
* constructor with 1 argument.
*/
public GroupConcat(Expression arg, OrderExpression... orders) {
super("group_concat", ExpressionUtils.mergeArguments(arg, orders));
public GroupConcat(boolean distinct, boolean alwaysNullable, Expression arg, OrderExpression... orders) {
super("group_concat", distinct, alwaysNullable, ExpressionUtils.mergeArguments(arg, orders));
this.nonOrderArguments = 1;
}
@ -63,15 +63,22 @@ public class GroupConcat extends AggregateFunction
* constructor with 1 argument.
*/
public GroupConcat(boolean distinct, Expression arg, OrderExpression... orders) {
super("group_concat", distinct, ExpressionUtils.mergeArguments(arg, orders));
this.nonOrderArguments = 1;
this(distinct, false, arg, orders);
}
/**
* constructor with 1 argument, use for function search.
*/
public GroupConcat(Expression arg, OrderExpression... orders) {
this(false, arg, orders);
}
/**
* constructor with 2 arguments.
*/
public GroupConcat(Expression arg0, Expression arg1, OrderExpression... orders) {
super("group_concat", ExpressionUtils.mergeArguments(arg0, arg1, orders));
public GroupConcat(boolean distinct, boolean alwaysNullable,
Expression arg0, Expression arg1, OrderExpression... orders) {
super("group_concat", distinct, alwaysNullable, ExpressionUtils.mergeArguments(arg0, arg1, orders));
this.nonOrderArguments = 2;
}
@ -79,8 +86,22 @@ public class GroupConcat extends AggregateFunction
* constructor with 2 arguments.
*/
public GroupConcat(boolean distinct, Expression arg0, Expression arg1, OrderExpression... orders) {
super("group_concat", distinct, ExpressionUtils.mergeArguments(arg0, arg1, orders));
this.nonOrderArguments = 2;
this(distinct, false, arg0, arg1, orders);
}
/**
* constructor with 2 arguments, use for function search.
*/
public GroupConcat(Expression arg0, Expression arg1, OrderExpression... orders) {
this(false, arg0, arg1, orders);
}
/**
* constructor for always nullable.
*/
public GroupConcat(boolean distinct, boolean alwaysNullable, int nonOrderArguments, List<Expression> args) {
super("group_concat", distinct, alwaysNullable, args);
this.nonOrderArguments = nonOrderArguments;
}
@Override
@ -100,15 +121,17 @@ public class GroupConcat extends AggregateFunction
}
}
@Override
public GroupConcat withAlwaysNullable(boolean alwaysNullable) {
return new GroupConcat(distinct, alwaysNullable, nonOrderArguments, children);
}
/**
* withDistinctAndChildren.
*/
@Override
public GroupConcat withDistinctAndChildren(boolean distinct, List<Expression> children) {
Preconditions.checkArgument(children().size() >= 1);
// in type coercion, the orderExpression will be the child of cast, we should push down cast.
children = children.stream().map(ExpressionUtils::pushDownCastInOrderExpression).collect(Collectors.toList());
boolean foundOrderExpr = false;
int firstOrderExrIndex = 0;
for (int i = 0; i < children.size(); i++) {
@ -124,9 +147,11 @@ public class GroupConcat extends AggregateFunction
List<OrderExpression> orders = (List) children.subList(firstOrderExrIndex, children.size());
if (firstOrderExrIndex == 1) {
return new GroupConcat(distinct, children.get(0), orders.toArray(new OrderExpression[0]));
return new GroupConcat(distinct, alwaysNullable,
children.get(0), orders.toArray(new OrderExpression[0]));
} else if (firstOrderExrIndex == 2) {
return new GroupConcat(distinct, children.get(0), children.get(1), orders.toArray(new OrderExpression[0]));
return new GroupConcat(distinct, alwaysNullable,
children.get(0), children.get(1), orders.toArray(new OrderExpression[0]));
} else {
throw new AnalysisException("group_concat requires one or two parameters: " + children);
}

View File

@ -21,6 +21,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import java.util.List;
import java.util.Objects;
/**
@ -35,17 +36,33 @@ public abstract class NullableAggregateFunction extends AggregateFunction implem
this.alwaysNullable = false;
}
protected NullableAggregateFunction(String name, List<Expression> expressions) {
super(name, false, expressions);
this.alwaysNullable = false;
}
protected NullableAggregateFunction(String name, boolean distinct, Expression ...expressions) {
super(name, distinct, expressions);
this.alwaysNullable = false;
}
protected NullableAggregateFunction(String name, boolean distinct, List<Expression> expressions) {
super(name, distinct, expressions);
this.alwaysNullable = false;
}
protected NullableAggregateFunction(String name, boolean distinct, boolean alwaysNullable,
Expression ...expressions) {
super(name, distinct, expressions);
this.alwaysNullable = alwaysNullable;
}
protected NullableAggregateFunction(String name, boolean distinct, boolean alwaysNullable,
List<Expression> expressions) {
super(name, distinct, expressions);
this.alwaysNullable = alwaysNullable;
}
@Override
public boolean nullable() {
return alwaysNullable ? AlwaysNullable.super.nullable() : PropagateNullable.super.nullable();

View File

@ -135,11 +135,11 @@ public interface AggregateFunctionVisitor<R, C> {
}
default R visitGroupBitmapXor(GroupBitmapXor groupBitmapXor, C context) {
return visitAggregateFunction(groupBitmapXor, context);
return visitNullableAggregateFunction(groupBitmapXor, context);
}
default R visitGroupConcat(GroupConcat groupConcat, C context) {
return visitAggregateFunction(groupConcat, context);
return visitNullableAggregateFunction(groupConcat, context);
}
default R visitHistogram(Histogram histogram, C context) {

View File

@ -28,7 +28,6 @@ import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
@ -467,17 +466,4 @@ public class ExpressionUtils {
.flatMap(expr -> expr.getInputSlots().stream())
.collect(ImmutableSet.toImmutableSet());
}
/**
* cast push down in order expression
*/
public static Expression pushDownCastInOrderExpression(Expression expression) {
if (expression instanceof Cast
&& ((Cast) expression).child() instanceof OrderExpression) {
Cast cast = ((Cast) expression);
OrderExpression order = ((OrderExpression) cast.child());
return order.withChildren(cast.withChildren(order.getOrderKey().getExpr()));
}
return expression;
}
}