diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp index 5875b831f3..7f9295d8e7 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp @@ -45,11 +45,23 @@ AggregateFunctionPtr create_aggregate_function_avg(const std::string& name, AggregateFunctionPtr res; DataTypePtr data_type = argument_types[0]; - if (is_decimal(data_type)) { - res.reset( - create_with_decimal_type(*data_type, *data_type, argument_types)); + if (data_type->is_nullable()) { + auto no_null_argument_types = remove_nullable(argument_types); + if (is_decimal(no_null_argument_types[0])) { + res.reset(create_with_decimal_type_null( + no_null_argument_types, parameters, *no_null_argument_types[0], + no_null_argument_types)); + } else { + res.reset(create_with_numeric_type_null( + no_null_argument_types, parameters, no_null_argument_types)); + } } else { - res.reset(create_with_numeric_type(*data_type, argument_types)); + if (is_decimal(data_type)) { + res.reset(create_with_decimal_type(*data_type, *data_type, + argument_types)); + } else { + res.reset(create_with_numeric_type(*data_type, argument_types)); + } } if (!res) { @@ -61,5 +73,6 @@ AggregateFunctionPtr create_aggregate_function_avg(const std::string& name, void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) { factory.register_function("avg", create_aggregate_function_avg); + factory.register_function("avg", create_aggregate_function_avg, true); } } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_count.h b/be/src/vec/aggregate_functions/aggregate_function_count.h index 960d4111cb..bc87e4bb10 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_count.h +++ b/be/src/vec/aggregate_functions/aggregate_function_count.h @@ -121,7 +121,8 @@ public: DataTypePtr get_serialized_type() const override { return std::make_shared(); } }; -/// Simply count number of not-NULL values. +// TODO: Maybe AggregateFunctionCountNotNullUnary should be a subclass of AggregateFunctionCount +// Simply count number of not-NULL values. class AggregateFunctionCountNotNullUnary final : public IAggregateFunctionDataHelper { diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp index 83045dbd00..a01e2ce51a 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp @@ -25,7 +25,6 @@ #include "vec/aggregate_functions/helpers.h" namespace doris::vectorized { - /// min, max, any template