From 26f8c7e35260010da5132848e1fc80f6ec430736 Mon Sep 17 00:00:00 2001 From: morrySnow <101034200+morrySnow@users.noreply.github.com> Date: Fri, 20 Oct 2023 14:54:00 +0800 Subject: [PATCH] [fix](Nereids) collect_list and collect_set should always not null (#25592) --- .../aggregate_function_collect.h | 17 ++++++++++++----- .../apache/doris/catalog/AggregateFunction.java | 2 +- .../expression/rules/FoldConstantRuleOnFE.java | 7 ++++--- .../expressions/functions/agg/CollectList.java | 4 ++-- .../expressions/functions/agg/CollectSet.java | 4 ++-- 5 files changed, 21 insertions(+), 13 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_collect.h b/be/src/vec/aggregate_functions/aggregate_function_collect.h index 0a2e9c443f..63ff0a680b 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_collect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_collect.h @@ -73,8 +73,9 @@ struct AggregateFunctionCollectSetData { void merge(const SelfType& rhs) { if constexpr (HasLimit::value) { - DCHECK(max_size == -1 || max_size == rhs.max_size); - max_size = rhs.max_size; + if (max_size == -1) { + max_size = rhs.max_size; + } for (auto& rhs_elem : rhs.data_set) { if (size() >= max_size) { @@ -130,7 +131,9 @@ struct AggregateFunctionCollectSetData { void merge(const SelfType& rhs, Arena* arena) { bool inserted; Set::LookupResult it; - DCHECK(max_size == -1 || max_size == rhs.max_size); + if (max_size == -1) { + max_size = rhs.max_size; + } max_size = rhs.max_size; for (auto& rhs_elem : rhs.data_set) { @@ -193,7 +196,9 @@ struct AggregateFunctionCollectListData { void merge(const SelfType& rhs) { if constexpr (HasLimit::value) { - DCHECK(max_size == -1 || max_size == rhs.max_size); + if (max_size == -1) { + max_size = rhs.max_size; + } max_size = rhs.max_size; for (auto& rhs_elem : rhs.data) { if (size() >= max_size) { @@ -245,7 +250,9 @@ struct AggregateFunctionCollectListData { void merge(const AggregateFunctionCollectListData& rhs) { if constexpr (HasLimit::value) { - DCHECK(max_size == -1 || max_size == rhs.max_size); + if (max_size == -1) { + max_size = rhs.max_size; + } max_size = rhs.max_size; data->insert_range_from(*rhs.data, 0, diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java index 2dceb302b7..1989466836 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java @@ -55,7 +55,7 @@ public class AggregateFunction extends Function { FunctionSet.COUNT, "approx_count_distinct", "ndv", FunctionSet.BITMAP_UNION_INT, FunctionSet.BITMAP_UNION_COUNT, "ndv_no_finalize", FunctionSet.WINDOW_FUNNEL, FunctionSet.RETENTION, FunctionSet.SEQUENCE_MATCH, FunctionSet.SEQUENCE_COUNT, FunctionSet.MAP_AGG, FunctionSet.BITMAP_AGG, - FunctionSet.ARRAY_AGG); + FunctionSet.ARRAY_AGG, FunctionSet.COLLECT_LIST, FunctionSet.COLLECT_SET); public static ImmutableSet ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET = ImmutableSet.of("stddev_samp", "variance_samp", "var_samp", "percentile_approx"); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java index 321465082d..85dd0cc579 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java @@ -47,7 +47,6 @@ import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; -import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.scalar.Array; import org.apache.doris.nereids.trees.expressions.functions.scalar.ConnectionId; import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentCatalog; @@ -551,8 +550,10 @@ public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule { } private Optional preProcess(Expression expression) { - if (expression instanceof PropagateNullable && !(expression instanceof NullableAggregateFunction) - && argsHasNullLiteral(expression)) { + if (expression instanceof AggregateFunction) { + return Optional.of(expression); + } + if (expression instanceof PropagateNullable && argsHasNullLiteral(expression)) { return Optional.of(new NullLiteral(expression.getDataType())); } if (!allArgsIsAllLiteral(expression)) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/CollectList.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/CollectList.java index 9e02435ef5..2aef07b481 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/CollectList.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/CollectList.java @@ -19,8 +19,8 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; -import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.ArrayType; @@ -37,7 +37,7 @@ import java.util.List; * AggregateFunction 'collect_list'. This class is generated by GenerateFunction. */ public class CollectList extends AggregateFunction - implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable { + implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable { public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(ArrayType.of(new FollowToAnyDataType(0))).args(new AnyDataType(0)), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/CollectSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/CollectSet.java index 57af28a957..5eeab663fd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/CollectSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/CollectSet.java @@ -19,8 +19,8 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; -import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.ArrayType; @@ -37,7 +37,7 @@ import java.util.List; * AggregateFunction 'collect_set'. This class is generated by GenerateFunction. */ public class CollectSet extends AggregateFunction - implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable { + implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable { public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(ArrayType.of(new FollowToAnyDataType(0))).args(new AnyDataType(0)),