From abdf56bfa573a46da9d08b9d28019920f533e015 Mon Sep 17 00:00:00 2001 From: morrySnow <101034200+morrySnow@users.noreply.github.com> Date: Thu, 19 Jan 2023 11:22:30 +0800 Subject: [PATCH] [fix](Nereids) wrong result of group_concat with order by or null args (#16081) 1. signatures without order element are wrong 2. signature with one arg is miss 3. group_concat should be NullableAggregateFunction 4. fold constant on fe should not fold NullableAggregateFunction with null arg TODO 1. reorder rewrite rules, and then only forbid fold constant on NullableAggregateFunction with alwaysNullable == true --- .../rewrite/rules/FoldConstantRuleOnFE.java | 5 +- .../functions/agg/GroupConcat.java | 61 +++++++++++++------ .../agg/NullableAggregateFunction.java | 17 ++++++ .../visitor/AggregateFunctionVisitor.java | 4 +- .../doris/nereids/util/ExpressionUtils.java | 14 ----- .../group_concat/test_group_concat.out | 40 ++++++++++++ .../nereids_syntax_p0/group_concat.groovy | 7 +++ .../group_concat/test_group_concat.groovy | 17 ++++++ 8 files changed, 130 insertions(+), 35 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRuleOnFE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRuleOnFE.java index 9bca83f27f..afd9a029ba 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRuleOnFE.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRuleOnFE.java @@ -44,6 +44,7 @@ import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.scalar.Array; import org.apache.doris.nereids.trees.expressions.functions.scalar.ConnectionId; import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentUser; @@ -79,7 +80,9 @@ public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule { } expr = rewriteChildren(expr, ctx); - if (expr instanceof PropagateNullable && argsHasNullLiteral(expr)) { + if (expr instanceof PropagateNullable + && !(expr instanceof NullableAggregateFunction) + && argsHasNullLiteral(expr)) { return new NullLiteral(expr.getDataType()); } return expr.accept(this, ctx); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java index 08c6f64d0f..d691d2108a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java @@ -22,7 +22,6 @@ import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.OrderExpression; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; -import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.VarcharType; @@ -33,16 +32,17 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import java.util.List; -import java.util.stream.Collectors; /** * AggregateFunction 'group_concat'. This class is generated by GenerateFunction. */ -public class GroupConcat extends AggregateFunction - implements ExplicitlyCastableSignature, PropagateNullable { +public class GroupConcat extends NullableAggregateFunction + implements ExplicitlyCastableSignature { public static final List SIGNATURES = ImmutableList.of( - FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).varArgs(VarcharType.SYSTEM_DEFAULT), + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarcharType.SYSTEM_DEFAULT), + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT) + .args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT), FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT) .varArgs(VarcharType.SYSTEM_DEFAULT, AnyDataType.INSTANCE), FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT) @@ -54,8 +54,8 @@ public class GroupConcat extends AggregateFunction /** * constructor with 1 argument. */ - public GroupConcat(Expression arg, OrderExpression... orders) { - super("group_concat", ExpressionUtils.mergeArguments(arg, orders)); + public GroupConcat(boolean distinct, boolean alwaysNullable, Expression arg, OrderExpression... orders) { + super("group_concat", distinct, alwaysNullable, ExpressionUtils.mergeArguments(arg, orders)); this.nonOrderArguments = 1; } @@ -63,15 +63,22 @@ public class GroupConcat extends AggregateFunction * constructor with 1 argument. */ public GroupConcat(boolean distinct, Expression arg, OrderExpression... orders) { - super("group_concat", distinct, ExpressionUtils.mergeArguments(arg, orders)); - this.nonOrderArguments = 1; + this(distinct, false, arg, orders); + } + + /** + * constructor with 1 argument, use for function search. + */ + public GroupConcat(Expression arg, OrderExpression... orders) { + this(false, arg, orders); } /** * constructor with 2 arguments. */ - public GroupConcat(Expression arg0, Expression arg1, OrderExpression... orders) { - super("group_concat", ExpressionUtils.mergeArguments(arg0, arg1, orders)); + public GroupConcat(boolean distinct, boolean alwaysNullable, + Expression arg0, Expression arg1, OrderExpression... orders) { + super("group_concat", distinct, alwaysNullable, ExpressionUtils.mergeArguments(arg0, arg1, orders)); this.nonOrderArguments = 2; } @@ -79,8 +86,22 @@ public class GroupConcat extends AggregateFunction * constructor with 2 arguments. */ public GroupConcat(boolean distinct, Expression arg0, Expression arg1, OrderExpression... orders) { - super("group_concat", distinct, ExpressionUtils.mergeArguments(arg0, arg1, orders)); - this.nonOrderArguments = 2; + this(distinct, false, arg0, arg1, orders); + } + + /** + * constructor with 2 arguments, use for function search. + */ + public GroupConcat(Expression arg0, Expression arg1, OrderExpression... orders) { + this(false, arg0, arg1, orders); + } + + /** + * constructor for always nullable. + */ + public GroupConcat(boolean distinct, boolean alwaysNullable, int nonOrderArguments, List args) { + super("group_concat", distinct, alwaysNullable, args); + this.nonOrderArguments = nonOrderArguments; } @Override @@ -100,15 +121,17 @@ public class GroupConcat extends AggregateFunction } } + @Override + public GroupConcat withAlwaysNullable(boolean alwaysNullable) { + return new GroupConcat(distinct, alwaysNullable, nonOrderArguments, children); + } + /** * withDistinctAndChildren. */ @Override public GroupConcat withDistinctAndChildren(boolean distinct, List children) { Preconditions.checkArgument(children().size() >= 1); - - // in type coercion, the orderExpression will be the child of cast, we should push down cast. - children = children.stream().map(ExpressionUtils::pushDownCastInOrderExpression).collect(Collectors.toList()); boolean foundOrderExpr = false; int firstOrderExrIndex = 0; for (int i = 0; i < children.size(); i++) { @@ -124,9 +147,11 @@ public class GroupConcat extends AggregateFunction List orders = (List) children.subList(firstOrderExrIndex, children.size()); if (firstOrderExrIndex == 1) { - return new GroupConcat(distinct, children.get(0), orders.toArray(new OrderExpression[0])); + return new GroupConcat(distinct, alwaysNullable, + children.get(0), orders.toArray(new OrderExpression[0])); } else if (firstOrderExrIndex == 2) { - return new GroupConcat(distinct, children.get(0), children.get(1), orders.toArray(new OrderExpression[0])); + return new GroupConcat(distinct, alwaysNullable, + children.get(0), children.get(1), orders.toArray(new OrderExpression[0])); } else { throw new AnalysisException("group_concat requires one or two parameters: " + children); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/NullableAggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/NullableAggregateFunction.java index 21484cacaa..bdb2b539a0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/NullableAggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/NullableAggregateFunction.java @@ -21,6 +21,7 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; +import java.util.List; import java.util.Objects; /** @@ -35,17 +36,33 @@ public abstract class NullableAggregateFunction extends AggregateFunction implem this.alwaysNullable = false; } + protected NullableAggregateFunction(String name, List expressions) { + super(name, false, expressions); + this.alwaysNullable = false; + } + protected NullableAggregateFunction(String name, boolean distinct, Expression ...expressions) { super(name, distinct, expressions); this.alwaysNullable = false; } + protected NullableAggregateFunction(String name, boolean distinct, List expressions) { + super(name, distinct, expressions); + this.alwaysNullable = false; + } + protected NullableAggregateFunction(String name, boolean distinct, boolean alwaysNullable, Expression ...expressions) { super(name, distinct, expressions); this.alwaysNullable = alwaysNullable; } + protected NullableAggregateFunction(String name, boolean distinct, boolean alwaysNullable, + List expressions) { + super(name, distinct, expressions); + this.alwaysNullable = alwaysNullable; + } + @Override public boolean nullable() { return alwaysNullable ? AlwaysNullable.super.nullable() : PropagateNullable.super.nullable(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java index 0ee8f2e45a..9c1b46e292 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java @@ -135,11 +135,11 @@ public interface AggregateFunctionVisitor { } default R visitGroupBitmapXor(GroupBitmapXor groupBitmapXor, C context) { - return visitAggregateFunction(groupBitmapXor, context); + return visitNullableAggregateFunction(groupBitmapXor, context); } default R visitGroupConcat(GroupConcat groupConcat, C context) { - return visitAggregateFunction(groupConcat, context); + return visitNullableAggregateFunction(groupConcat, context); } default R visitHistogram(Histogram histogram, C context) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 3490ec284f..7a7ed01f30 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -28,7 +28,6 @@ import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.Not; import org.apache.doris.nereids.trees.expressions.Or; -import org.apache.doris.nereids.trees.expressions.OrderExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; @@ -467,17 +466,4 @@ public class ExpressionUtils { .flatMap(expr -> expr.getInputSlots().stream()) .collect(ImmutableSet.toImmutableSet()); } - - /** - * cast push down in order expression - */ - public static Expression pushDownCastInOrderExpression(Expression expression) { - if (expression instanceof Cast - && ((Cast) expression).child() instanceof OrderExpression) { - Cast cast = ((Cast) expression); - OrderExpression order = ((OrderExpression) cast.child()); - return order.withChildren(cast.withChildren(order.getOrderKey().getExpr())); - } - return expression; - } } diff --git a/regression-test/data/query_p0/group_concat/test_group_concat.out b/regression-test/data/query_p0/group_concat/test_group_concat.out index 8a742347a8..9c6ff6c8b4 100644 --- a/regression-test/data/query_p0/group_concat/test_group_concat.out +++ b/regression-test/data/query_p0/group_concat/test_group_concat.out @@ -45,3 +45,43 @@ false 25699 1989 2147483647 255:1991:32767:32767 +-- !select -- +\N \N +103 255 +1001 1986, 1989 +1002 1989, 32767 +3021 1991, 1992, 32767 +5014 1985, 1991 +25699 1989 +2147483647 255, 1991, 32767, 32767 + +-- !select -- +\N \N +103 255 +1001 1986:1989 +1002 1989:32767 +3021 1991:1992:32767 +5014 1985:1991 +25699 1989 +2147483647 255:1991:32767:32767 + +-- !select -- +\N \N +103 255 +1001 1989, 1986 +1002 1989, 32767 +3021 1991, 32767, 1992 +5014 1985, 1991 +25699 1989 +2147483647 255, 1991, 32767, 32767 + +-- !select -- +\N \N +103 255 +1001 1989:1986 +1002 1989:32767 +3021 1991:32767:1992 +5014 1985:1991 +25699 1989 +2147483647 255:1991:32767:32767 + diff --git a/regression-test/suites/nereids_syntax_p0/group_concat.groovy b/regression-test/suites/nereids_syntax_p0/group_concat.groovy index b046b1a5d4..551fe37384 100644 --- a/regression-test/suites/nereids_syntax_p0/group_concat.groovy +++ b/regression-test/suites/nereids_syntax_p0/group_concat.groovy @@ -43,4 +43,11 @@ suite("group_concat") { sql "select group_concat(cast(number as string), ' : ') from numbers('number'='10')" result([["0 : 1 : 2 : 3 : 4 : 5 : 6 : 7 : 8 : 9"]]) } + + test { + sql "select group_concat(cast(number as string), NULL) from numbers('number'='10')" + result([[null]]) + } + + } diff --git a/regression-test/suites/query_p0/group_concat/test_group_concat.groovy b/regression-test/suites/query_p0/group_concat/test_group_concat.groovy index 1b4501ca8d..dcd8d5cc85 100644 --- a/regression-test/suites/query_p0/group_concat/test_group_concat.groovy +++ b/regression-test/suites/query_p0/group_concat/test_group_concat.groovy @@ -38,4 +38,21 @@ suite("test_group_concat") { SELECT abs(k3), group_concat(distinct cast(abs(k2) as char), ":" order by abs(k1), k2) FROM test_query_db.baseall group by abs(k3) order by abs(k3); """ + sql "set enable_nereids_planner=true" + sql "set enable_vectorized_engine=true" + sql "set enable_fallback_to_original_planner=false" + + qt_select """ + SELECT abs(k3), group_concat(cast(abs(k2) as varchar) order by abs(k2), k1) FROM test_query_db.baseall group by abs(k3) order by abs(k3) + """ + + qt_select """ + SELECT abs(k3), group_concat(cast(abs(k2) as varchar), ":" order by abs(k2), k1) FROM test_query_db.baseall group by abs(k3) order by abs(k3) + """ + qt_select """ + SELECT abs(k3), group_concat(distinct cast(abs(k2) as char) order by abs(k1), k2) FROM test_query_db.baseall group by abs(k3) order by abs(k3); + """ + qt_select """ + SELECT abs(k3), group_concat(distinct cast(abs(k2) as char), ":" order by abs(k1), k2) FROM test_query_db.baseall group by abs(k3) order by abs(k3); + """ }