diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp index 8bda389f4a..4f493c9529 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp @@ -52,7 +52,6 @@ AggregateFunctionPtr create_aggregate_function_avg(const std::string& name, } void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) { - factory.register_function("avg", create_aggregate_function_avg); - factory.register_function("avg", create_aggregate_function_avg, true); + factory.register_function_both("avg", create_aggregate_function_avg); } } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_bit.cpp b/be/src/vec/aggregate_functions/aggregate_function_bit.cpp index 379df49559..6b9be5c92c 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bit.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_bit.cpp @@ -47,21 +47,12 @@ AggregateFunctionPtr createAggregateFunctionBitwise(const std::string& name, } void register_aggregate_function_bit(AggregateFunctionSimpleFactory& factory) { - factory.register_function("group_bit_or", - createAggregateFunctionBitwise); - factory.register_function("group_bit_and", - createAggregateFunctionBitwise); - factory.register_function("group_bit_xor", - createAggregateFunctionBitwise); - - factory.register_function( - "group_bit_or", createAggregateFunctionBitwise, true); - factory.register_function("group_bit_and", - createAggregateFunctionBitwise, - true); - factory.register_function("group_bit_xor", - createAggregateFunctionBitwise, - true); + factory.register_function_both("group_bit_or", + createAggregateFunctionBitwise); + factory.register_function_both( + "group_bit_and", createAggregateFunctionBitwise); + factory.register_function_both( + "group_bit_xor", createAggregateFunctionBitwise); } } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp b/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp index eb9a8fb35c..e2dd7e309d 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp @@ -18,50 +18,45 @@ #include "vec/aggregate_functions/aggregate_function_bitmap.h" #include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/helpers.h" namespace doris::vectorized { template class AggregateFunctionTemplate> -static IAggregateFunction* createWithIntDataType(const DataTypes& argument_type) { - auto type = argument_type[0].get(); - if (type->is_nullable()) { - type = assert_cast(type)->get_nested_type().get(); - } +static IAggregateFunction* create_with_int_data_type(const DataTypes& argument_type) { + auto type = remove_nullable(argument_type[0]); WhichDataType which(type); - if (which.idx == TypeIndex::Int8) { - return new AggregateFunctionTemplate>(argument_type); - } - if (which.idx == TypeIndex::Int16) { - return new AggregateFunctionTemplate>(argument_type); - } - if (which.idx == TypeIndex::Int32) { - return new AggregateFunctionTemplate>(argument_type); - } - if (which.idx == TypeIndex::Int64) { - return new AggregateFunctionTemplate>(argument_type); +#define DISPATCH(TYPE) \ + if (which.idx == TypeIndex::TYPE) { \ + return new AggregateFunctionTemplate>(argument_type); \ } + FOR_INTEGER_TYPES(DISPATCH) +#undef DISPATCH return nullptr; } AggregateFunctionPtr create_aggregate_function_bitmap_union(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - return std::make_shared>( - argument_types); + return AggregateFunctionPtr( + creator_without_type::create>( + result_is_nullable, argument_types)); } AggregateFunctionPtr create_aggregate_function_bitmap_intersect(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - return std::make_shared>( - argument_types); + return AggregateFunctionPtr(creator_without_type::create< + AggregateFunctionBitmapOp>( + result_is_nullable, argument_types)); } AggregateFunctionPtr create_aggregate_function_group_bitmap_xor(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - return std::make_shared>( - argument_types); + return AggregateFunctionPtr(creator_without_type::create< + AggregateFunctionBitmapOp>( + result_is_nullable, argument_types)); } AggregateFunctionPtr create_aggregate_function_bitmap_union_count(const std::string& name, @@ -81,22 +76,19 @@ AggregateFunctionPtr create_aggregate_function_bitmap_union_int(const std::strin const bool arg_is_nullable = argument_types[0]->is_nullable(); if (arg_is_nullable) { return std::shared_ptr( - createWithIntDataType(argument_types)); + create_with_int_data_type(argument_types)); } else { return std::shared_ptr( - createWithIntDataType(argument_types)); + create_with_int_data_type(argument_types)); } } void register_aggregate_function_bitmap(AggregateFunctionSimpleFactory& factory) { - factory.register_function("bitmap_union", create_aggregate_function_bitmap_union); - factory.register_function("bitmap_intersect", create_aggregate_function_bitmap_intersect); - factory.register_function("group_bitmap_xor", create_aggregate_function_group_bitmap_xor); - factory.register_function("bitmap_union_count", create_aggregate_function_bitmap_union_count); - factory.register_function("bitmap_union_count", create_aggregate_function_bitmap_union_count, - true); - - factory.register_function("bitmap_union_int", create_aggregate_function_bitmap_union_int); - factory.register_function("bitmap_union_int", create_aggregate_function_bitmap_union_int, true); + factory.register_function_both("bitmap_union", create_aggregate_function_bitmap_union); + factory.register_function_both("bitmap_intersect", create_aggregate_function_bitmap_intersect); + factory.register_function_both("group_bitmap_xor", create_aggregate_function_group_bitmap_xor); + factory.register_function_both("bitmap_union_count", + create_aggregate_function_bitmap_union_count); + factory.register_function_both("bitmap_union_int", create_aggregate_function_bitmap_union_int); } } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp b/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp index bcd7becc5e..5bd070ada3 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp @@ -17,6 +17,8 @@ #include "vec/aggregate_functions/aggregate_function_group_concat.h" +#include "vec/aggregate_functions/helpers.h" + namespace doris::vectorized { const std::string AggregateFunctionGroupConcatImplStr::separator = ", "; @@ -26,12 +28,14 @@ AggregateFunctionPtr create_aggregate_function_group_concat(const std::string& n const bool result_is_nullable) { if (argument_types.size() == 1) { return AggregateFunctionPtr( - new AggregateFunctionGroupConcat( - argument_types)); + creator_without_type::create< + AggregateFunctionGroupConcat>( + result_is_nullable, argument_types)); } else if (argument_types.size() == 2) { return AggregateFunctionPtr( - new AggregateFunctionGroupConcat( - argument_types)); + creator_without_type::create< + AggregateFunctionGroupConcat>( + result_is_nullable, argument_types)); } LOG(WARNING) << fmt::format("Illegal number {} of argument for aggregate function {}", @@ -40,6 +44,6 @@ AggregateFunctionPtr create_aggregate_function_group_concat(const std::string& n } void register_aggregate_function_group_concat(AggregateFunctionSimpleFactory& factory) { - factory.register_function("group_concat", create_aggregate_function_group_concat); + factory.register_function_both("group_concat", create_aggregate_function_group_concat); } } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp b/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp index 81dece0c95..77e67ab29a 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp @@ -23,56 +23,46 @@ namespace doris::vectorized { template -AggregateFunctionPtr create_agg_function_histogram(const DataTypes& argument_types) { +AggregateFunctionPtr create_agg_function_histogram(const DataTypes& argument_types, + const bool result_is_nullable) { bool has_input_param = (argument_types.size() == 3); if (has_input_param) { return AggregateFunctionPtr( - new AggregateFunctionHistogram, T, true>( - argument_types)); + creator_without_type::create< + AggregateFunctionHistogram, T, true>>( + result_is_nullable, argument_types)); } else { return AggregateFunctionPtr( - new AggregateFunctionHistogram, T, false>( - argument_types)); + creator_without_type::create< + AggregateFunctionHistogram, T, false>>( + result_is_nullable, argument_types)); } } AggregateFunctionPtr create_aggregate_function_histogram(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - WhichDataType type(argument_types[0]); + WhichDataType type(remove_nullable(argument_types[0])); - LOG(INFO) << fmt::format("supported input type {} for aggregate function {}", - argument_types[0]->get_name(), name); - -#define DISPATCH(TYPE) \ - if (type.idx == TypeIndex::TYPE) return create_agg_function_histogram(argument_types); +#define DISPATCH(TYPE) \ + if (type.idx == TypeIndex::TYPE) \ + return create_agg_function_histogram(argument_types, result_is_nullable); FOR_NUMERIC_TYPES(DISPATCH) + FOR_DECIMAL_TYPES(DISPATCH) #undef DISPATCH if (type.idx == TypeIndex::String) { - return create_agg_function_histogram(argument_types); + return create_agg_function_histogram(argument_types, result_is_nullable); } if (type.idx == TypeIndex::DateTime || type.idx == TypeIndex::Date) { - return create_agg_function_histogram(argument_types); + return create_agg_function_histogram(argument_types, result_is_nullable); } if (type.idx == TypeIndex::DateV2) { - return create_agg_function_histogram(argument_types); + return create_agg_function_histogram(argument_types, result_is_nullable); } if (type.idx == TypeIndex::DateTimeV2) { - return create_agg_function_histogram(argument_types); - } - if (type.idx == TypeIndex::Decimal32) { - return create_agg_function_histogram(argument_types); - } - if (type.idx == TypeIndex::Decimal64) { - return create_agg_function_histogram(argument_types); - } - if (type.idx == TypeIndex::Decimal128) { - return create_agg_function_histogram(argument_types); - } - if (type.idx == TypeIndex::Decimal128I) { - return create_agg_function_histogram(argument_types); + return create_agg_function_histogram(argument_types, result_is_nullable); } LOG(WARNING) << fmt::format("unsupported input type {} for aggregate function {}", @@ -81,7 +71,7 @@ AggregateFunctionPtr create_aggregate_function_histogram(const std::string& name } void register_aggregate_function_histogram(AggregateFunctionSimpleFactory& factory) { - factory.register_function("histogram", create_aggregate_function_histogram); + factory.register_function_both("histogram", create_aggregate_function_histogram); factory.register_alias("histogram", "hist"); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp index 882b532c7e..46606142b2 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp @@ -97,14 +97,9 @@ AggregateFunctionPtr create_aggregate_function_any(const std::string& name, } void register_aggregate_function_minmax(AggregateFunctionSimpleFactory& factory) { - factory.register_function("max", create_aggregate_function_max); - factory.register_function("min", create_aggregate_function_min); - factory.register_function("any", create_aggregate_function_any); - - factory.register_function("max", create_aggregate_function_max, true); - factory.register_function("min", create_aggregate_function_min, true); - factory.register_function("any", create_aggregate_function_any, true); - + factory.register_function_both("max", create_aggregate_function_max); + factory.register_function_both("min", create_aggregate_function_min); + factory.register_function_both("any", create_aggregate_function_any); factory.register_alias("any", "any_value"); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp index 8a4ad945f9..2252da7721 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp @@ -26,101 +26,95 @@ namespace doris::vectorized { /// min_by, max_by -template