[fix](Nereids) wrong result of group_concat with order by or null args (#16081)

1. signatures without order element are wrong
2. signature with one arg is miss
3. group_concat should be NullableAggregateFunction
4. fold constant on fe should not fold NullableAggregateFunction with null arg

TODO
1. reorder rewrite rules, and then only forbid fold constant on NullableAggregateFunction with alwaysNullable == true
This commit is contained in:
morrySnow
2023-01-19 11:22:30 +08:00
committed by GitHub
parent e846e8c0fd
commit abdf56bfa5
8 changed files with 130 additions and 35 deletions

View File

@ -44,6 +44,7 @@ import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Array;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ConnectionId;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentUser;
@ -79,7 +80,9 @@ public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule {
}
expr = rewriteChildren(expr, ctx);
if (expr instanceof PropagateNullable && argsHasNullLiteral(expr)) {
if (expr instanceof PropagateNullable
&& !(expr instanceof NullableAggregateFunction)
&& argsHasNullLiteral(expr)) {
return new NullLiteral(expr.getDataType());
}
return expr.accept(this, ctx);

View File

@ -22,7 +22,6 @@ import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.VarcharType;
@ -33,16 +32,17 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.stream.Collectors;
/**
* AggregateFunction 'group_concat'. This class is generated by GenerateFunction.
*/
public class GroupConcat extends AggregateFunction
implements ExplicitlyCastableSignature, PropagateNullable {
public class GroupConcat extends NullableAggregateFunction
implements ExplicitlyCastableSignature {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).varArgs(VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.varArgs(VarcharType.SYSTEM_DEFAULT, AnyDataType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
@ -54,8 +54,8 @@ public class GroupConcat extends AggregateFunction
/**
* constructor with 1 argument.
*/
public GroupConcat(Expression arg, OrderExpression... orders) {
super("group_concat", ExpressionUtils.mergeArguments(arg, orders));
public GroupConcat(boolean distinct, boolean alwaysNullable, Expression arg, OrderExpression... orders) {
super("group_concat", distinct, alwaysNullable, ExpressionUtils.mergeArguments(arg, orders));
this.nonOrderArguments = 1;
}
@ -63,15 +63,22 @@ public class GroupConcat extends AggregateFunction
* constructor with 1 argument.
*/
public GroupConcat(boolean distinct, Expression arg, OrderExpression... orders) {
super("group_concat", distinct, ExpressionUtils.mergeArguments(arg, orders));
this.nonOrderArguments = 1;
this(distinct, false, arg, orders);
}
/**
* constructor with 1 argument, use for function search.
*/
public GroupConcat(Expression arg, OrderExpression... orders) {
this(false, arg, orders);
}
/**
* constructor with 2 arguments.
*/
public GroupConcat(Expression arg0, Expression arg1, OrderExpression... orders) {
super("group_concat", ExpressionUtils.mergeArguments(arg0, arg1, orders));
public GroupConcat(boolean distinct, boolean alwaysNullable,
Expression arg0, Expression arg1, OrderExpression... orders) {
super("group_concat", distinct, alwaysNullable, ExpressionUtils.mergeArguments(arg0, arg1, orders));
this.nonOrderArguments = 2;
}
@ -79,8 +86,22 @@ public class GroupConcat extends AggregateFunction
* constructor with 2 arguments.
*/
public GroupConcat(boolean distinct, Expression arg0, Expression arg1, OrderExpression... orders) {
super("group_concat", distinct, ExpressionUtils.mergeArguments(arg0, arg1, orders));
this.nonOrderArguments = 2;
this(distinct, false, arg0, arg1, orders);
}
/**
* constructor with 2 arguments, use for function search.
*/
public GroupConcat(Expression arg0, Expression arg1, OrderExpression... orders) {
this(false, arg0, arg1, orders);
}
/**
* constructor for always nullable.
*/
public GroupConcat(boolean distinct, boolean alwaysNullable, int nonOrderArguments, List<Expression> args) {
super("group_concat", distinct, alwaysNullable, args);
this.nonOrderArguments = nonOrderArguments;
}
@Override
@ -100,15 +121,17 @@ public class GroupConcat extends AggregateFunction
}
}
@Override
public GroupConcat withAlwaysNullable(boolean alwaysNullable) {
return new GroupConcat(distinct, alwaysNullable, nonOrderArguments, children);
}
/**
* withDistinctAndChildren.
*/
@Override
public GroupConcat withDistinctAndChildren(boolean distinct, List<Expression> children) {
Preconditions.checkArgument(children().size() >= 1);
// in type coercion, the orderExpression will be the child of cast, we should push down cast.
children = children.stream().map(ExpressionUtils::pushDownCastInOrderExpression).collect(Collectors.toList());
boolean foundOrderExpr = false;
int firstOrderExrIndex = 0;
for (int i = 0; i < children.size(); i++) {
@ -124,9 +147,11 @@ public class GroupConcat extends AggregateFunction
List<OrderExpression> orders = (List) children.subList(firstOrderExrIndex, children.size());
if (firstOrderExrIndex == 1) {
return new GroupConcat(distinct, children.get(0), orders.toArray(new OrderExpression[0]));
return new GroupConcat(distinct, alwaysNullable,
children.get(0), orders.toArray(new OrderExpression[0]));
} else if (firstOrderExrIndex == 2) {
return new GroupConcat(distinct, children.get(0), children.get(1), orders.toArray(new OrderExpression[0]));
return new GroupConcat(distinct, alwaysNullable,
children.get(0), children.get(1), orders.toArray(new OrderExpression[0]));
} else {
throw new AnalysisException("group_concat requires one or two parameters: " + children);
}

View File

@ -21,6 +21,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import java.util.List;
import java.util.Objects;
/**
@ -35,17 +36,33 @@ public abstract class NullableAggregateFunction extends AggregateFunction implem
this.alwaysNullable = false;
}
protected NullableAggregateFunction(String name, List<Expression> expressions) {
super(name, false, expressions);
this.alwaysNullable = false;
}
protected NullableAggregateFunction(String name, boolean distinct, Expression ...expressions) {
super(name, distinct, expressions);
this.alwaysNullable = false;
}
protected NullableAggregateFunction(String name, boolean distinct, List<Expression> expressions) {
super(name, distinct, expressions);
this.alwaysNullable = false;
}
protected NullableAggregateFunction(String name, boolean distinct, boolean alwaysNullable,
Expression ...expressions) {
super(name, distinct, expressions);
this.alwaysNullable = alwaysNullable;
}
protected NullableAggregateFunction(String name, boolean distinct, boolean alwaysNullable,
List<Expression> expressions) {
super(name, distinct, expressions);
this.alwaysNullable = alwaysNullable;
}
@Override
public boolean nullable() {
return alwaysNullable ? AlwaysNullable.super.nullable() : PropagateNullable.super.nullable();

View File

@ -135,11 +135,11 @@ public interface AggregateFunctionVisitor<R, C> {
}
default R visitGroupBitmapXor(GroupBitmapXor groupBitmapXor, C context) {
return visitAggregateFunction(groupBitmapXor, context);
return visitNullableAggregateFunction(groupBitmapXor, context);
}
default R visitGroupConcat(GroupConcat groupConcat, C context) {
return visitAggregateFunction(groupConcat, context);
return visitNullableAggregateFunction(groupConcat, context);
}
default R visitHistogram(Histogram histogram, C context) {

View File

@ -28,7 +28,6 @@ import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
@ -467,17 +466,4 @@ public class ExpressionUtils {
.flatMap(expr -> expr.getInputSlots().stream())
.collect(ImmutableSet.toImmutableSet());
}
/**
* cast push down in order expression
*/
public static Expression pushDownCastInOrderExpression(Expression expression) {
if (expression instanceof Cast
&& ((Cast) expression).child() instanceof OrderExpression) {
Cast cast = ((Cast) expression);
OrderExpression order = ((OrderExpression) cast.child());
return order.withChildren(cast.withChildren(order.getOrderKey().getExpr()));
}
return expression;
}
}

View File

@ -45,3 +45,43 @@ false
25699 1989
2147483647 255:1991:32767:32767
-- !select --
\N \N
103 255
1001 1986, 1989
1002 1989, 32767
3021 1991, 1992, 32767
5014 1985, 1991
25699 1989
2147483647 255, 1991, 32767, 32767
-- !select --
\N \N
103 255
1001 1986:1989
1002 1989:32767
3021 1991:1992:32767
5014 1985:1991
25699 1989
2147483647 255:1991:32767:32767
-- !select --
\N \N
103 255
1001 1989, 1986
1002 1989, 32767
3021 1991, 32767, 1992
5014 1985, 1991
25699 1989
2147483647 255, 1991, 32767, 32767
-- !select --
\N \N
103 255
1001 1989:1986
1002 1989:32767
3021 1991:32767:1992
5014 1985:1991
25699 1989
2147483647 255:1991:32767:32767

View File

@ -43,4 +43,11 @@ suite("group_concat") {
sql "select group_concat(cast(number as string), ' : ') from numbers('number'='10')"
result([["0 : 1 : 2 : 3 : 4 : 5 : 6 : 7 : 8 : 9"]])
}
test {
sql "select group_concat(cast(number as string), NULL) from numbers('number'='10')"
result([[null]])
}
}

View File

@ -38,4 +38,21 @@ suite("test_group_concat") {
SELECT abs(k3), group_concat(distinct cast(abs(k2) as char), ":" order by abs(k1), k2) FROM test_query_db.baseall group by abs(k3) order by abs(k3);
"""
sql "set enable_nereids_planner=true"
sql "set enable_vectorized_engine=true"
sql "set enable_fallback_to_original_planner=false"
qt_select """
SELECT abs(k3), group_concat(cast(abs(k2) as varchar) order by abs(k2), k1) FROM test_query_db.baseall group by abs(k3) order by abs(k3)
"""
qt_select """
SELECT abs(k3), group_concat(cast(abs(k2) as varchar), ":" order by abs(k2), k1) FROM test_query_db.baseall group by abs(k3) order by abs(k3)
"""
qt_select """
SELECT abs(k3), group_concat(distinct cast(abs(k2) as char) order by abs(k1), k2) FROM test_query_db.baseall group by abs(k3) order by abs(k3);
"""
qt_select """
SELECT abs(k3), group_concat(distinct cast(abs(k2) as char), ":" order by abs(k1), k2) FROM test_query_db.baseall group by abs(k3) order by abs(k3);
"""
}