From 527eb5b059e15c22411ba44f2434b47aace0dc9f Mon Sep 17 00:00:00 2001 From: Pxl Date: Thu, 2 Mar 2023 00:00:01 +0800 Subject: [PATCH] =?UTF-8?q?[Enchancement](function)=20nullable=20inline=20?= =?UTF-8?q?refactor=20of=20min=5Fmax=5Fby/bitmap=20&&=20add=20register=5Ff?= =?UTF-8?q?unctio=E2=80=A6=20(#17228)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. nullable inline refactor of min_max_by/bitmap/group_concat/histogram/topn 2. add register_function_both method 3. add datetimev2 type creator of min_max_by 4. remove uint16/32/64 in FOR_INTEGER_TYPES --- .../aggregate_function_avg.cpp | 3 +- .../aggregate_function_bit.cpp | 21 +-- .../aggregate_function_bitmap.cpp | 58 ++++---- .../aggregate_function_group_concat.cpp | 14 +- .../aggregate_function_histogram.cpp | 46 +++---- .../aggregate_function_min_max.cpp | 11 +- .../aggregate_function_min_max_by.cpp | 124 +++++++++--------- .../aggregate_function_min_max_by.h | 16 +-- .../aggregate_function_orthogonal_bitmap.cpp | 28 +--- .../aggregate_function_simple_factory.h | 5 + .../aggregate_function_sum.cpp | 5 +- .../aggregate_function_topn.cpp | 83 ++++++------ .../aggregate_function_uniq.cpp | 3 +- .../aggregate_function_window.cpp | 12 +- be/src/vec/aggregate_functions/helpers.h | 119 ++++------------- be/src/vec/core/types.h | 9 ++ 16 files changed, 230 insertions(+), 327 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp index 8bda389f4a..4f493c9529 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp @@ -52,7 +52,6 @@ AggregateFunctionPtr create_aggregate_function_avg(const std::string& name, } void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) { - factory.register_function("avg", create_aggregate_function_avg); - factory.register_function("avg", create_aggregate_function_avg, true); + factory.register_function_both("avg", create_aggregate_function_avg); } } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_bit.cpp b/be/src/vec/aggregate_functions/aggregate_function_bit.cpp index 379df49559..6b9be5c92c 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bit.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_bit.cpp @@ -47,21 +47,12 @@ AggregateFunctionPtr createAggregateFunctionBitwise(const std::string& name, } void register_aggregate_function_bit(AggregateFunctionSimpleFactory& factory) { - factory.register_function("group_bit_or", - createAggregateFunctionBitwise); - factory.register_function("group_bit_and", - createAggregateFunctionBitwise); - factory.register_function("group_bit_xor", - createAggregateFunctionBitwise); - - factory.register_function( - "group_bit_or", createAggregateFunctionBitwise, true); - factory.register_function("group_bit_and", - createAggregateFunctionBitwise, - true); - factory.register_function("group_bit_xor", - createAggregateFunctionBitwise, - true); + factory.register_function_both("group_bit_or", + createAggregateFunctionBitwise); + factory.register_function_both( + "group_bit_and", createAggregateFunctionBitwise); + factory.register_function_both( + "group_bit_xor", createAggregateFunctionBitwise); } } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp b/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp index eb9a8fb35c..e2dd7e309d 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp @@ -18,50 +18,45 @@ #include "vec/aggregate_functions/aggregate_function_bitmap.h" #include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/helpers.h" namespace doris::vectorized { template class AggregateFunctionTemplate> -static IAggregateFunction* createWithIntDataType(const DataTypes& argument_type) { - auto type = argument_type[0].get(); - if (type->is_nullable()) { - type = assert_cast(type)->get_nested_type().get(); - } +static IAggregateFunction* create_with_int_data_type(const DataTypes& argument_type) { + auto type = remove_nullable(argument_type[0]); WhichDataType which(type); - if (which.idx == TypeIndex::Int8) { - return new AggregateFunctionTemplate>(argument_type); - } - if (which.idx == TypeIndex::Int16) { - return new AggregateFunctionTemplate>(argument_type); - } - if (which.idx == TypeIndex::Int32) { - return new AggregateFunctionTemplate>(argument_type); - } - if (which.idx == TypeIndex::Int64) { - return new AggregateFunctionTemplate>(argument_type); +#define DISPATCH(TYPE) \ + if (which.idx == TypeIndex::TYPE) { \ + return new AggregateFunctionTemplate>(argument_type); \ } + FOR_INTEGER_TYPES(DISPATCH) +#undef DISPATCH return nullptr; } AggregateFunctionPtr create_aggregate_function_bitmap_union(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - return std::make_shared>( - argument_types); + return AggregateFunctionPtr( + creator_without_type::create>( + result_is_nullable, argument_types)); } AggregateFunctionPtr create_aggregate_function_bitmap_intersect(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - return std::make_shared>( - argument_types); + return AggregateFunctionPtr(creator_without_type::create< + AggregateFunctionBitmapOp>( + result_is_nullable, argument_types)); } AggregateFunctionPtr create_aggregate_function_group_bitmap_xor(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - return std::make_shared>( - argument_types); + return AggregateFunctionPtr(creator_without_type::create< + AggregateFunctionBitmapOp>( + result_is_nullable, argument_types)); } AggregateFunctionPtr create_aggregate_function_bitmap_union_count(const std::string& name, @@ -81,22 +76,19 @@ AggregateFunctionPtr create_aggregate_function_bitmap_union_int(const std::strin const bool arg_is_nullable = argument_types[0]->is_nullable(); if (arg_is_nullable) { return std::shared_ptr( - createWithIntDataType(argument_types)); + create_with_int_data_type(argument_types)); } else { return std::shared_ptr( - createWithIntDataType(argument_types)); + create_with_int_data_type(argument_types)); } } void register_aggregate_function_bitmap(AggregateFunctionSimpleFactory& factory) { - factory.register_function("bitmap_union", create_aggregate_function_bitmap_union); - factory.register_function("bitmap_intersect", create_aggregate_function_bitmap_intersect); - factory.register_function("group_bitmap_xor", create_aggregate_function_group_bitmap_xor); - factory.register_function("bitmap_union_count", create_aggregate_function_bitmap_union_count); - factory.register_function("bitmap_union_count", create_aggregate_function_bitmap_union_count, - true); - - factory.register_function("bitmap_union_int", create_aggregate_function_bitmap_union_int); - factory.register_function("bitmap_union_int", create_aggregate_function_bitmap_union_int, true); + factory.register_function_both("bitmap_union", create_aggregate_function_bitmap_union); + factory.register_function_both("bitmap_intersect", create_aggregate_function_bitmap_intersect); + factory.register_function_both("group_bitmap_xor", create_aggregate_function_group_bitmap_xor); + factory.register_function_both("bitmap_union_count", + create_aggregate_function_bitmap_union_count); + factory.register_function_both("bitmap_union_int", create_aggregate_function_bitmap_union_int); } } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp b/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp index bcd7becc5e..5bd070ada3 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp @@ -17,6 +17,8 @@ #include "vec/aggregate_functions/aggregate_function_group_concat.h" +#include "vec/aggregate_functions/helpers.h" + namespace doris::vectorized { const std::string AggregateFunctionGroupConcatImplStr::separator = ", "; @@ -26,12 +28,14 @@ AggregateFunctionPtr create_aggregate_function_group_concat(const std::string& n const bool result_is_nullable) { if (argument_types.size() == 1) { return AggregateFunctionPtr( - new AggregateFunctionGroupConcat( - argument_types)); + creator_without_type::create< + AggregateFunctionGroupConcat>( + result_is_nullable, argument_types)); } else if (argument_types.size() == 2) { return AggregateFunctionPtr( - new AggregateFunctionGroupConcat( - argument_types)); + creator_without_type::create< + AggregateFunctionGroupConcat>( + result_is_nullable, argument_types)); } LOG(WARNING) << fmt::format("Illegal number {} of argument for aggregate function {}", @@ -40,6 +44,6 @@ AggregateFunctionPtr create_aggregate_function_group_concat(const std::string& n } void register_aggregate_function_group_concat(AggregateFunctionSimpleFactory& factory) { - factory.register_function("group_concat", create_aggregate_function_group_concat); + factory.register_function_both("group_concat", create_aggregate_function_group_concat); } } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp b/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp index 81dece0c95..77e67ab29a 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp @@ -23,56 +23,46 @@ namespace doris::vectorized { template -AggregateFunctionPtr create_agg_function_histogram(const DataTypes& argument_types) { +AggregateFunctionPtr create_agg_function_histogram(const DataTypes& argument_types, + const bool result_is_nullable) { bool has_input_param = (argument_types.size() == 3); if (has_input_param) { return AggregateFunctionPtr( - new AggregateFunctionHistogram, T, true>( - argument_types)); + creator_without_type::create< + AggregateFunctionHistogram, T, true>>( + result_is_nullable, argument_types)); } else { return AggregateFunctionPtr( - new AggregateFunctionHistogram, T, false>( - argument_types)); + creator_without_type::create< + AggregateFunctionHistogram, T, false>>( + result_is_nullable, argument_types)); } } AggregateFunctionPtr create_aggregate_function_histogram(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - WhichDataType type(argument_types[0]); + WhichDataType type(remove_nullable(argument_types[0])); - LOG(INFO) << fmt::format("supported input type {} for aggregate function {}", - argument_types[0]->get_name(), name); - -#define DISPATCH(TYPE) \ - if (type.idx == TypeIndex::TYPE) return create_agg_function_histogram(argument_types); +#define DISPATCH(TYPE) \ + if (type.idx == TypeIndex::TYPE) \ + return create_agg_function_histogram(argument_types, result_is_nullable); FOR_NUMERIC_TYPES(DISPATCH) + FOR_DECIMAL_TYPES(DISPATCH) #undef DISPATCH if (type.idx == TypeIndex::String) { - return create_agg_function_histogram(argument_types); + return create_agg_function_histogram(argument_types, result_is_nullable); } if (type.idx == TypeIndex::DateTime || type.idx == TypeIndex::Date) { - return create_agg_function_histogram(argument_types); + return create_agg_function_histogram(argument_types, result_is_nullable); } if (type.idx == TypeIndex::DateV2) { - return create_agg_function_histogram(argument_types); + return create_agg_function_histogram(argument_types, result_is_nullable); } if (type.idx == TypeIndex::DateTimeV2) { - return create_agg_function_histogram(argument_types); - } - if (type.idx == TypeIndex::Decimal32) { - return create_agg_function_histogram(argument_types); - } - if (type.idx == TypeIndex::Decimal64) { - return create_agg_function_histogram(argument_types); - } - if (type.idx == TypeIndex::Decimal128) { - return create_agg_function_histogram(argument_types); - } - if (type.idx == TypeIndex::Decimal128I) { - return create_agg_function_histogram(argument_types); + return create_agg_function_histogram(argument_types, result_is_nullable); } LOG(WARNING) << fmt::format("unsupported input type {} for aggregate function {}", @@ -81,7 +71,7 @@ AggregateFunctionPtr create_aggregate_function_histogram(const std::string& name } void register_aggregate_function_histogram(AggregateFunctionSimpleFactory& factory) { - factory.register_function("histogram", create_aggregate_function_histogram); + factory.register_function_both("histogram", create_aggregate_function_histogram); factory.register_alias("histogram", "hist"); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp index 882b532c7e..46606142b2 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp @@ -97,14 +97,9 @@ AggregateFunctionPtr create_aggregate_function_any(const std::string& name, } void register_aggregate_function_minmax(AggregateFunctionSimpleFactory& factory) { - factory.register_function("max", create_aggregate_function_max); - factory.register_function("min", create_aggregate_function_min); - factory.register_function("any", create_aggregate_function_any); - - factory.register_function("max", create_aggregate_function_max, true); - factory.register_function("min", create_aggregate_function_min, true); - factory.register_function("any", create_aggregate_function_any, true); - + factory.register_function_both("max", create_aggregate_function_max); + factory.register_function_both("min", create_aggregate_function_min); + factory.register_function_both("any", create_aggregate_function_any); factory.register_alias("any", "any_value"); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp index 8a4ad945f9..2252da7721 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp @@ -26,101 +26,95 @@ namespace doris::vectorized { /// min_by, max_by -template