[fix](Nereids) collect_list and collect_set should always not null (#25592)
This commit is contained in:
@ -73,8 +73,9 @@ struct AggregateFunctionCollectSetData {
|
||||
|
||||
void merge(const SelfType& rhs) {
|
||||
if constexpr (HasLimit::value) {
|
||||
DCHECK(max_size == -1 || max_size == rhs.max_size);
|
||||
max_size = rhs.max_size;
|
||||
if (max_size == -1) {
|
||||
max_size = rhs.max_size;
|
||||
}
|
||||
|
||||
for (auto& rhs_elem : rhs.data_set) {
|
||||
if (size() >= max_size) {
|
||||
@ -130,7 +131,9 @@ struct AggregateFunctionCollectSetData<StringRef, HasLimit> {
|
||||
void merge(const SelfType& rhs, Arena* arena) {
|
||||
bool inserted;
|
||||
Set::LookupResult it;
|
||||
DCHECK(max_size == -1 || max_size == rhs.max_size);
|
||||
if (max_size == -1) {
|
||||
max_size = rhs.max_size;
|
||||
}
|
||||
max_size = rhs.max_size;
|
||||
|
||||
for (auto& rhs_elem : rhs.data_set) {
|
||||
@ -193,7 +196,9 @@ struct AggregateFunctionCollectListData {
|
||||
|
||||
void merge(const SelfType& rhs) {
|
||||
if constexpr (HasLimit::value) {
|
||||
DCHECK(max_size == -1 || max_size == rhs.max_size);
|
||||
if (max_size == -1) {
|
||||
max_size = rhs.max_size;
|
||||
}
|
||||
max_size = rhs.max_size;
|
||||
for (auto& rhs_elem : rhs.data) {
|
||||
if (size() >= max_size) {
|
||||
@ -245,7 +250,9 @@ struct AggregateFunctionCollectListData<StringRef, HasLimit> {
|
||||
|
||||
void merge(const AggregateFunctionCollectListData& rhs) {
|
||||
if constexpr (HasLimit::value) {
|
||||
DCHECK(max_size == -1 || max_size == rhs.max_size);
|
||||
if (max_size == -1) {
|
||||
max_size = rhs.max_size;
|
||||
}
|
||||
max_size = rhs.max_size;
|
||||
|
||||
data->insert_range_from(*rhs.data, 0,
|
||||
|
||||
@ -55,7 +55,7 @@ public class AggregateFunction extends Function {
|
||||
FunctionSet.COUNT, "approx_count_distinct", "ndv", FunctionSet.BITMAP_UNION_INT,
|
||||
FunctionSet.BITMAP_UNION_COUNT, "ndv_no_finalize", FunctionSet.WINDOW_FUNNEL, FunctionSet.RETENTION,
|
||||
FunctionSet.SEQUENCE_MATCH, FunctionSet.SEQUENCE_COUNT, FunctionSet.MAP_AGG, FunctionSet.BITMAP_AGG,
|
||||
FunctionSet.ARRAY_AGG);
|
||||
FunctionSet.ARRAY_AGG, FunctionSet.COLLECT_LIST, FunctionSet.COLLECT_SET);
|
||||
|
||||
public static ImmutableSet<String> ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET =
|
||||
ImmutableSet.of("stddev_samp", "variance_samp", "var_samp", "percentile_approx");
|
||||
|
||||
@ -47,7 +47,6 @@ import org.apache.doris.nereids.trees.expressions.WhenClause;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.Array;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.ConnectionId;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentCatalog;
|
||||
@ -551,8 +550,10 @@ public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule {
|
||||
}
|
||||
|
||||
private Optional<Expression> preProcess(Expression expression) {
|
||||
if (expression instanceof PropagateNullable && !(expression instanceof NullableAggregateFunction)
|
||||
&& argsHasNullLiteral(expression)) {
|
||||
if (expression instanceof AggregateFunction) {
|
||||
return Optional.of(expression);
|
||||
}
|
||||
if (expression instanceof PropagateNullable && argsHasNullLiteral(expression)) {
|
||||
return Optional.of(new NullLiteral(expression.getDataType()));
|
||||
}
|
||||
if (!allArgsIsAllLiteral(expression)) {
|
||||
|
||||
@ -19,8 +19,8 @@ package org.apache.doris.nereids.trees.expressions.functions.agg;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.ArrayType;
|
||||
@ -37,7 +37,7 @@ import java.util.List;
|
||||
* AggregateFunction 'collect_list'. This class is generated by GenerateFunction.
|
||||
*/
|
||||
public class CollectList extends AggregateFunction
|
||||
implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable {
|
||||
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable {
|
||||
|
||||
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
|
||||
FunctionSignature.ret(ArrayType.of(new FollowToAnyDataType(0))).args(new AnyDataType(0)),
|
||||
|
||||
@ -19,8 +19,8 @@ package org.apache.doris.nereids.trees.expressions.functions.agg;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.ArrayType;
|
||||
@ -37,7 +37,7 @@ import java.util.List;
|
||||
* AggregateFunction 'collect_set'. This class is generated by GenerateFunction.
|
||||
*/
|
||||
public class CollectSet extends AggregateFunction
|
||||
implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable {
|
||||
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable {
|
||||
|
||||
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
|
||||
FunctionSignature.ret(ArrayType.of(new FollowToAnyDataType(0))).args(new AnyDataType(0)),
|
||||
|
||||
Reference in New Issue
Block a user