From c4edea5936d2357c537002acf07d2970142e575c Mon Sep 17 00:00:00 2001 From: Pxl Date: Fri, 24 Feb 2023 10:05:11 +0800 Subject: [PATCH] [Enchancement](function) refact and optimize some function register (#16955) refact and optimize some function register --- .../aggregate_function_avg.cpp | 21 +- .../aggregate_function_avg.h | 7 +- .../aggregate_function_bit.cpp | 34 +- .../aggregate_function_distinct.cpp | 31 +- .../aggregate_function_min_max.cpp | 74 ++--- .../aggregate_function_min_max.h | 15 +- .../aggregate_function_null.h | 94 ++++-- .../aggregate_function_orthogonal_bitmap.cpp | 29 +- .../aggregate_function_product.h | 7 +- .../aggregate_function_sum.cpp | 21 +- .../aggregate_function_sum.h | 7 +- .../aggregate_function_uniq.cpp | 38 +-- be/src/vec/aggregate_functions/helpers.h | 301 ++++++++++-------- be/src/vec/data_types/data_type_decimal.h | 3 +- be/src/vec/data_types/data_type_nullable.cpp | 15 +- be/src/vec/data_types/data_type_nullable.h | 5 + be/src/vec/data_types/get_least_supertype.cpp | 2 +- .../array/function_array_aggregation.cpp | 9 +- 18 files changed, 370 insertions(+), 343 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp index 19b5549582..8bda389f4a 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp @@ -41,25 +41,8 @@ AggregateFunctionPtr create_aggregate_function_avg(const std::string& name, const bool result_is_nullable) { assert_unary(name, argument_types); - AggregateFunctionPtr res; - DataTypePtr data_type = argument_types[0]; - if (data_type->is_nullable()) { - auto no_null_argument_types = remove_nullable(argument_types); - if (is_decimal(no_null_argument_types[0])) { - res.reset(create_with_decimal_type_null( - no_null_argument_types, *no_null_argument_types[0], no_null_argument_types)); - } else { - res.reset(create_with_numeric_type_null(no_null_argument_types, - no_null_argument_types)); - } - } else { - if (is_decimal(data_type)) { - res.reset(create_with_decimal_type(*data_type, *data_type, - argument_types)); - } else { - res.reset(create_with_numeric_type(*data_type, argument_types)); - } - } + AggregateFunctionPtr res( + creator_with_type::create(result_is_nullable, argument_types)); if (!res) { LOG(WARNING) << fmt::format("Illegal type {} of argument for aggregate function {}", diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.h b/be/src/vec/aggregate_functions/aggregate_function_avg.h index e74a38793f..125ab1b8ee 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg.h +++ b/be/src/vec/aggregate_functions/aggregate_function_avg.h @@ -95,12 +95,7 @@ public: /// ctor for native types AggregateFunctionAvg(const DataTypes& argument_types_) : IAggregateFunctionDataHelper>(argument_types_), - scale(0) {} - - /// ctor for Decimals - AggregateFunctionAvg(const IDataType& data_type, const DataTypes& argument_types_) - : IAggregateFunctionDataHelper>(argument_types_), - scale(get_decimal_scale(data_type)) {} + scale(get_decimal_scale(*argument_types_[0])) {} String get_name() const override { return "avg"; } diff --git a/be/src/vec/aggregate_functions/aggregate_function_bit.cpp b/be/src/vec/aggregate_functions/aggregate_function_bit.cpp index afe7fde3a1..379df49559 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bit.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_bit.cpp @@ -21,6 +21,7 @@ #include "vec/aggregate_functions/aggregate_function_bit.h" #include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/helpers.h" namespace doris::vectorized { @@ -34,26 +35,10 @@ AggregateFunctionPtr createAggregateFunctionBitwise(const std::string& name, " is illegal, because it cannot be used in bitwise operations"); } - auto type = argument_types[0].get(); - if (type->is_nullable()) { - type = assert_cast(type)->get_nested_type().get(); - } - - WhichDataType which(*type); - if (which.is_int8()) { - return AggregateFunctionPtr(new AggregateFunctionBitwise>(argument_types)); - } else if (which.is_int16()) { - return AggregateFunctionPtr( - new AggregateFunctionBitwise>(argument_types)); - } else if (which.is_int32()) { - return AggregateFunctionPtr( - new AggregateFunctionBitwise>(argument_types)); - } else if (which.is_int64()) { - return AggregateFunctionPtr( - new AggregateFunctionBitwise>(argument_types)); - } else if (which.is_int128()) { - return AggregateFunctionPtr( - new AggregateFunctionBitwise>(argument_types)); + AggregateFunctionPtr res(creator_with_integer_type::create( + result_is_nullable, argument_types)); + if (res) { + return res; } LOG(WARNING) << fmt::format("Illegal type " + argument_types[0]->get_name() + @@ -68,6 +53,15 @@ void register_aggregate_function_bit(AggregateFunctionSimpleFactory& factory) { createAggregateFunctionBitwise); factory.register_function("group_bit_xor", createAggregateFunctionBitwise); + + factory.register_function( + "group_bit_or", createAggregateFunctionBitwise, true); + factory.register_function("group_bit_and", + createAggregateFunctionBitwise, + true); + factory.register_function("group_bit_xor", + createAggregateFunctionBitwise, + true); } } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp b/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp index 218c11e73b..a98a1b8e19 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp @@ -49,30 +49,30 @@ public: return nullptr; } - AggregateFunctionPtr res; if (arguments.size() == 1) { - res.reset(create_with_numeric_type( - *arguments[0], nested_function, arguments)); - + AggregateFunctionPtr res( + creator_with_numeric_type::create( + result_is_nullable, arguments, nested_function)); if (res) { return res; } if (arguments[0]->is_value_unambiguously_represented_in_contiguous_memory_region()) { - return std::make_shared>>(nested_function, - arguments); + res.reset(creator_without_type::create>>( + result_is_nullable, arguments, nested_function)); } else { - return std::make_shared>>(nested_function, - arguments); + res.reset(creator_without_type::create>>( + result_is_nullable, arguments, nested_function)); } + return res; } - - return std::make_shared< - AggregateFunctionDistinct>( - nested_function, arguments); + return AggregateFunctionPtr( + creator_without_type::create< + AggregateFunctionDistinct>( + result_is_nullable, arguments, nested_function)); } }; @@ -93,5 +93,6 @@ void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFact result_is_nullable); }; factory.register_distinct_function_combinator(creator, DISTINCT_FUNCTION_PREFIX); + factory.register_distinct_function_combinator(creator, DISTINCT_FUNCTION_PREFIX, true); } } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp index 075aa1e943..882b532c7e 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp @@ -26,50 +26,45 @@ namespace doris::vectorized { /// min, max, any -template