[fix](Nereids) collect_list and collect_set should always be not nullable (#25592)

This commit is contained in:
morrySnow
2023-10-20 14:54:00 +08:00
committed by GitHub
parent 9f31914018
commit 26f8c7e352
5 changed files with 21 additions and 13 deletions

View File

@ -73,8 +73,9 @@ struct AggregateFunctionCollectSetData {
void merge(const SelfType& rhs) {
if constexpr (HasLimit::value) {
DCHECK(max_size == -1 || max_size == rhs.max_size);
max_size = rhs.max_size;
if (max_size == -1) {
max_size = rhs.max_size;
}
for (auto& rhs_elem : rhs.data_set) {
if (size() >= max_size) {
@ -130,7 +131,9 @@ struct AggregateFunctionCollectSetData<StringRef, HasLimit> {
void merge(const SelfType& rhs, Arena* arena) {
bool inserted;
Set::LookupResult it;
DCHECK(max_size == -1 || max_size == rhs.max_size);
if (max_size == -1) {
max_size = rhs.max_size;
}
max_size = rhs.max_size;
for (auto& rhs_elem : rhs.data_set) {
@ -193,7 +196,9 @@ struct AggregateFunctionCollectListData {
void merge(const SelfType& rhs) {
if constexpr (HasLimit::value) {
DCHECK(max_size == -1 || max_size == rhs.max_size);
if (max_size == -1) {
max_size = rhs.max_size;
}
max_size = rhs.max_size;
for (auto& rhs_elem : rhs.data) {
if (size() >= max_size) {
@ -245,7 +250,9 @@ struct AggregateFunctionCollectListData<StringRef, HasLimit> {
void merge(const AggregateFunctionCollectListData& rhs) {
if constexpr (HasLimit::value) {
DCHECK(max_size == -1 || max_size == rhs.max_size);
if (max_size == -1) {
max_size = rhs.max_size;
}
max_size = rhs.max_size;
data->insert_range_from(*rhs.data, 0,

View File

@ -55,7 +55,7 @@ public class AggregateFunction extends Function {
FunctionSet.COUNT, "approx_count_distinct", "ndv", FunctionSet.BITMAP_UNION_INT,
FunctionSet.BITMAP_UNION_COUNT, "ndv_no_finalize", FunctionSet.WINDOW_FUNNEL, FunctionSet.RETENTION,
FunctionSet.SEQUENCE_MATCH, FunctionSet.SEQUENCE_COUNT, FunctionSet.MAP_AGG, FunctionSet.BITMAP_AGG,
FunctionSet.ARRAY_AGG);
FunctionSet.ARRAY_AGG, FunctionSet.COLLECT_LIST, FunctionSet.COLLECT_SET);
public static ImmutableSet<String> ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET =
ImmutableSet.of("stddev_samp", "variance_samp", "var_samp", "percentile_approx");

View File

@ -47,7 +47,6 @@ import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Array;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ConnectionId;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentCatalog;
@ -551,8 +550,10 @@ public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule {
}
private Optional<Expression> preProcess(Expression expression) {
if (expression instanceof PropagateNullable && !(expression instanceof NullableAggregateFunction)
&& argsHasNullLiteral(expression)) {
if (expression instanceof AggregateFunction) {
return Optional.of(expression);
}
if (expression instanceof PropagateNullable && argsHasNullLiteral(expression)) {
return Optional.of(new NullLiteral(expression.getDataType()));
}
if (!allArgsIsAllLiteral(expression)) {

View File

@ -19,8 +19,8 @@ package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
@ -37,7 +37,7 @@ import java.util.List;
* AggregateFunction 'collect_list'. This class is generated by GenerateFunction.
*/
public class CollectList extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable {
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(ArrayType.of(new FollowToAnyDataType(0))).args(new AnyDataType(0)),

View File

@ -19,8 +19,8 @@ package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
@ -37,7 +37,7 @@ import java.util.List;
* AggregateFunction 'collect_set'. This class is generated by GenerateFunction.
*/
public class CollectSet extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable {
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(ArrayType.of(new FollowToAnyDataType(0))).args(new AnyDataType(0)),