diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRuleOnFE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRuleOnFE.java index 9bca83f27f..afd9a029ba 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRuleOnFE.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRuleOnFE.java @@ -44,6 +44,7 @@ import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.scalar.Array; import org.apache.doris.nereids.trees.expressions.functions.scalar.ConnectionId; import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentUser; @@ -79,7 +80,9 @@ public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule { } expr = rewriteChildren(expr, ctx); - if (expr instanceof PropagateNullable && argsHasNullLiteral(expr)) { + if (expr instanceof PropagateNullable + && !(expr instanceof NullableAggregateFunction) + && argsHasNullLiteral(expr)) { return new NullLiteral(expr.getDataType()); } return expr.accept(this, ctx); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java index 08c6f64d0f..d691d2108a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java @@ -22,7 +22,6 @@ import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.OrderExpression; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; -import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.VarcharType; @@ -33,16 +32,17 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import java.util.List; -import java.util.stream.Collectors; /** * AggregateFunction 'group_concat'. This class is generated by GenerateFunction. */ -public class GroupConcat extends AggregateFunction - implements ExplicitlyCastableSignature, PropagateNullable { +public class GroupConcat extends NullableAggregateFunction + implements ExplicitlyCastableSignature { public static final List SIGNATURES = ImmutableList.of( - FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).varArgs(VarcharType.SYSTEM_DEFAULT), + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarcharType.SYSTEM_DEFAULT), + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT) + .args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT), FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT) .varArgs(VarcharType.SYSTEM_DEFAULT, AnyDataType.INSTANCE), FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT) @@ -54,8 +54,8 @@ public class GroupConcat extends AggregateFunction /** * constructor with 1 argument. */ - public GroupConcat(Expression arg, OrderExpression... orders) { - super("group_concat", ExpressionUtils.mergeArguments(arg, orders)); + public GroupConcat(boolean distinct, boolean alwaysNullable, Expression arg, OrderExpression... orders) { + super("group_concat", distinct, alwaysNullable, ExpressionUtils.mergeArguments(arg, orders)); this.nonOrderArguments = 1; } @@ -63,15 +63,22 @@ public class GroupConcat extends AggregateFunction * constructor with 1 argument. */ public GroupConcat(boolean distinct, Expression arg, OrderExpression... orders) { - super("group_concat", distinct, ExpressionUtils.mergeArguments(arg, orders)); - this.nonOrderArguments = 1; + this(distinct, false, arg, orders); + } + + /** + * constructor with 1 argument, use for function search. + */ + public GroupConcat(Expression arg, OrderExpression... orders) { + this(false, arg, orders); } /** * constructor with 2 arguments. */ - public GroupConcat(Expression arg0, Expression arg1, OrderExpression... orders) { - super("group_concat", ExpressionUtils.mergeArguments(arg0, arg1, orders)); + public GroupConcat(boolean distinct, boolean alwaysNullable, + Expression arg0, Expression arg1, OrderExpression... orders) { + super("group_concat", distinct, alwaysNullable, ExpressionUtils.mergeArguments(arg0, arg1, orders)); this.nonOrderArguments = 2; } @@ -79,8 +86,22 @@ public class GroupConcat extends AggregateFunction * constructor with 2 arguments. */ public GroupConcat(boolean distinct, Expression arg0, Expression arg1, OrderExpression... orders) { - super("group_concat", distinct, ExpressionUtils.mergeArguments(arg0, arg1, orders)); - this.nonOrderArguments = 2; + this(distinct, false, arg0, arg1, orders); + } + + /** + * constructor with 2 arguments, use for function search. + */ + public GroupConcat(Expression arg0, Expression arg1, OrderExpression... orders) { + this(false, arg0, arg1, orders); + } + + /** + * constructor for always nullable. + */ + public GroupConcat(boolean distinct, boolean alwaysNullable, int nonOrderArguments, List args) { + super("group_concat", distinct, alwaysNullable, args); + this.nonOrderArguments = nonOrderArguments; } @Override @@ -100,15 +121,17 @@ public class GroupConcat extends AggregateFunction } } + @Override + public GroupConcat withAlwaysNullable(boolean alwaysNullable) { + return new GroupConcat(distinct, alwaysNullable, nonOrderArguments, children); + } + /** * withDistinctAndChildren. */ @Override public GroupConcat withDistinctAndChildren(boolean distinct, List children) { Preconditions.checkArgument(children().size() >= 1); - - // in type coercion, the orderExpression will be the child of cast, we should push down cast. - children = children.stream().map(ExpressionUtils::pushDownCastInOrderExpression).collect(Collectors.toList()); boolean foundOrderExpr = false; int firstOrderExrIndex = 0; for (int i = 0; i < children.size(); i++) { @@ -124,9 +147,11 @@ public class GroupConcat extends AggregateFunction List orders = (List) children.subList(firstOrderExrIndex, children.size()); if (firstOrderExrIndex == 1) { - return new GroupConcat(distinct, children.get(0), orders.toArray(new OrderExpression[0])); + return new GroupConcat(distinct, alwaysNullable, + children.get(0), orders.toArray(new OrderExpression[0])); } else if (firstOrderExrIndex == 2) { - return new GroupConcat(distinct, children.get(0), children.get(1), orders.toArray(new OrderExpression[0])); + return new GroupConcat(distinct, alwaysNullable, + children.get(0), children.get(1), orders.toArray(new OrderExpression[0])); } else { throw new AnalysisException("group_concat requires one or two parameters: " + children); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/NullableAggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/NullableAggregateFunction.java index 21484cacaa..bdb2b539a0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/NullableAggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/NullableAggregateFunction.java @@ -21,6 +21,7 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; +import java.util.List; import java.util.Objects; /** @@ -35,17 +36,33 @@ public abstract class NullableAggregateFunction extends AggregateFunction implem this.alwaysNullable = false; } + protected NullableAggregateFunction(String name, List expressions) { + super(name, false, expressions); + this.alwaysNullable = false; + } + protected NullableAggregateFunction(String name, boolean distinct, Expression ...expressions) { super(name, distinct, expressions); this.alwaysNullable = false; } + protected NullableAggregateFunction(String name, boolean distinct, List expressions) { + super(name, distinct, expressions); + this.alwaysNullable = false; + } + protected NullableAggregateFunction(String name, boolean distinct, boolean alwaysNullable, Expression ...expressions) { super(name, distinct, expressions); this.alwaysNullable = alwaysNullable; } + protected NullableAggregateFunction(String name, boolean distinct, boolean alwaysNullable, + List expressions) { + super(name, distinct, expressions); + this.alwaysNullable = alwaysNullable; + } + @Override public boolean nullable() { return alwaysNullable ? AlwaysNullable.super.nullable() : PropagateNullable.super.nullable(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java index 0ee8f2e45a..9c1b46e292 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java @@ -135,11 +135,11 @@ public interface AggregateFunctionVisitor { } default R visitGroupBitmapXor(GroupBitmapXor groupBitmapXor, C context) { - return visitAggregateFunction(groupBitmapXor, context); + return visitNullableAggregateFunction(groupBitmapXor, context); } default R visitGroupConcat(GroupConcat groupConcat, C context) { - return visitAggregateFunction(groupConcat, context); + return visitNullableAggregateFunction(groupConcat, context); } default R visitHistogram(Histogram histogram, C context) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 3490ec284f..7a7ed01f30 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -28,7 +28,6 @@ import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.Not; import org.apache.doris.nereids.trees.expressions.Or; -import org.apache.doris.nereids.trees.expressions.OrderExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; @@ -467,17 +466,4 @@ public class ExpressionUtils { .flatMap(expr -> expr.getInputSlots().stream()) .collect(ImmutableSet.toImmutableSet()); } - - /** - * cast push down in order expression - */ - public static Expression pushDownCastInOrderExpression(Expression expression) { - if (expression instanceof Cast - && ((Cast) expression).child() instanceof OrderExpression) { - Cast cast = ((Cast) expression); - OrderExpression order = ((OrderExpression) cast.child()); - return order.withChildren(cast.withChildren(order.getOrderKey().getExpr())); - } - return expression; - } } diff --git a/regression-test/data/query_p0/group_concat/test_group_concat.out b/regression-test/data/query_p0/group_concat/test_group_concat.out index 8a742347a8..9c6ff6c8b4 100644 --- a/regression-test/data/query_p0/group_concat/test_group_concat.out +++ b/regression-test/data/query_p0/group_concat/test_group_concat.out @@ -45,3 +45,43 @@ false 25699 1989 2147483647 255:1991:32767:32767 +-- !select -- +\N \N +103 255 +1001 1986, 1989 +1002 1989, 32767 +3021 1991, 1992, 32767 +5014 1985, 1991 +25699 1989 +2147483647 255, 1991, 32767, 32767 + +-- !select -- +\N \N +103 255 +1001 1986:1989 +1002 1989:32767 +3021 1991:1992:32767 +5014 1985:1991 +25699 1989 +2147483647 255:1991:32767:32767 + +-- !select -- +\N \N +103 255 +1001 1989, 1986 +1002 1989, 32767 +3021 1991, 32767, 1992 +5014 1985, 1991 +25699 1989 +2147483647 255, 1991, 32767, 32767 + +-- !select -- +\N \N +103 255 +1001 1989:1986 +1002 1989:32767 +3021 1991:32767:1992 +5014 1985:1991 +25699 1989 +2147483647 255:1991:32767:32767 + diff --git a/regression-test/suites/nereids_syntax_p0/group_concat.groovy b/regression-test/suites/nereids_syntax_p0/group_concat.groovy index b046b1a5d4..551fe37384 100644 --- a/regression-test/suites/nereids_syntax_p0/group_concat.groovy +++ b/regression-test/suites/nereids_syntax_p0/group_concat.groovy @@ -43,4 +43,11 @@ suite("group_concat") { sql "select group_concat(cast(number as string), ' : ') from numbers('number'='10')" result([["0 : 1 : 2 : 3 : 4 : 5 : 6 : 7 : 8 : 9"]]) } + + test { + sql "select group_concat(cast(number as string), NULL) from numbers('number'='10')" + result([[null]]) + } + + } diff --git a/regression-test/suites/query_p0/group_concat/test_group_concat.groovy b/regression-test/suites/query_p0/group_concat/test_group_concat.groovy index 1b4501ca8d..dcd8d5cc85 100644 --- a/regression-test/suites/query_p0/group_concat/test_group_concat.groovy +++ b/regression-test/suites/query_p0/group_concat/test_group_concat.groovy @@ -38,4 +38,21 @@ suite("test_group_concat") { SELECT abs(k3), group_concat(distinct cast(abs(k2) as char), ":" order by abs(k1), k2) FROM test_query_db.baseall group by abs(k3) order by abs(k3); """ + sql "set enable_nereids_planner=true" + sql "set enable_vectorized_engine=true" + sql "set enable_fallback_to_original_planner=false" + + qt_select """ + SELECT abs(k3), group_concat(cast(abs(k2) as varchar) order by abs(k2), k1) FROM test_query_db.baseall group by abs(k3) order by abs(k3) + """ + + qt_select """ + SELECT abs(k3), group_concat(cast(abs(k2) as varchar), ":" order by abs(k2), k1) FROM test_query_db.baseall group by abs(k3) order by abs(k3) + """ + qt_select """ + SELECT abs(k3), group_concat(distinct cast(abs(k2) as char) order by abs(k1), k2) FROM test_query_db.baseall group by abs(k3) order by abs(k3); + """ + qt_select """ + SELECT abs(k3), group_concat(distinct cast(abs(k2) as char), ":" order by abs(k1), k2) FROM test_query_db.baseall group by abs(k3) order by abs(k3); + """ }