From 65b8dfc7ff8eb2af77d76631d898a41814774aa7 Mon Sep 17 00:00:00 2001 From: Pxl Date: Thu, 9 Mar 2023 10:39:04 +0800 Subject: [PATCH] [Enhancement](function) Inline some aggregate functions && remove nullable combinator (#17328) 1. Inline some aggregate functions 2. remove nullable combinator --- .gitignore | 1 + be/src/vec/CMakeLists.txt | 1 - .../aggregate_functions/aggregate_function.h | 4 +- ...gregate_function_approx_count_distinct.cpp | 18 +- .../aggregate_function_avg_weighted.cpp | 27 +- .../aggregate_function_avg_weighted.h | 8 +- .../aggregate_function_bit.cpp | 6 - .../aggregate_function_collect.cpp | 56 +-- .../aggregate_function_collect.h | 6 +- .../aggregate_function_hll_union_agg.cpp | 34 +- .../aggregate_function_hll_union_agg.h | 33 +- .../aggregate_function_null.cpp | 99 ----- .../aggregate_function_null.h | 368 ------------------ .../aggregate_function_percentile_approx.cpp | 34 +- .../aggregate_function_percentile_approx.h | 34 +- .../aggregate_function_reader.cpp | 18 +- .../aggregate_function_retention.cpp | 5 +- .../aggregate_function_sequence_match.cpp | 20 +- .../aggregate_function_simple_factory.cpp | 4 - .../aggregate_function_stddev.cpp | 83 ++-- .../aggregate_function_stddev.h | 1 + .../aggregate_function_window.cpp | 14 +- .../aggregate_function_window_funnel.cpp | 18 +- .../array/function_array_aggregation.cpp | 30 +- be/src/vec/utils/template_helpers.hpp | 21 +- .../agg_histogram_test.cpp | 5 +- 26 files changed, 206 insertions(+), 742 deletions(-) delete mode 100644 be/src/vec/aggregate_functions/aggregate_function_null.cpp diff --git a/.gitignore b/.gitignore index c6229863aa..7b5868f79b 100644 --- a/.gitignore +++ b/.gitignore @@ -96,5 +96,6 @@ tools/single-node-cluster/fe* # be-ut data_test +lru_cache_test /conf/log4j2-spring.xml diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt index d9e897e67b..ac94673475 100644 --- a/be/src/vec/CMakeLists.txt +++ b/be/src/vec/CMakeLists.txt @@ -31,7 +31,6 @@ set(VEC_FILES 
aggregate_functions/aggregate_function_sort.cpp aggregate_functions/aggregate_function_min_max.cpp aggregate_functions/aggregate_function_min_max_by.cpp - aggregate_functions/aggregate_function_null.cpp aggregate_functions/aggregate_function_uniq.cpp aggregate_functions/aggregate_function_hll_union_agg.cpp aggregate_functions/aggregate_function_bit.cpp diff --git a/be/src/vec/aggregate_functions/aggregate_function.h b/be/src/vec/aggregate_functions/aggregate_function.h index d4a906231f..e86415b729 100644 --- a/be/src/vec/aggregate_functions/aggregate_function.h +++ b/be/src/vec/aggregate_functions/aggregate_function.h @@ -317,13 +317,13 @@ public: void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst, const size_t num_rows, Arena* arena) const override { - VectorBufferWriter writter(static_cast(*dst)); + VectorBufferWriter writter(assert_cast(*dst)); streaming_agg_serialize(columns, writter, num_rows, arena); } void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place, MutableColumnPtr& dst) const override { - VectorBufferWriter writter(static_cast(*dst)); + VectorBufferWriter writter(assert_cast(*dst)); static_cast(this)->serialize(place, writter); writter.commit(); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp index 0fa50b5194..2c22586d43 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp @@ -17,6 +17,7 @@ #include "vec/aggregate_functions/aggregate_function_approx_count_distinct.h" +#include "vec/aggregate_functions/helpers.h" #include "vec/utils/template_helpers.hpp" namespace doris::vectorized { @@ -24,13 +25,14 @@ namespace doris::vectorized { AggregateFunctionPtr create_aggregate_function_approx_count_distinct( const std::string& name, const DataTypes& argument_types, 
const bool result_is_nullable) { AggregateFunctionPtr res = nullptr; - WhichDataType which(argument_types[0]->is_nullable() - ? reinterpret_cast(argument_types[0].get()) - ->get_nested_type() - : argument_types[0]); + WhichDataType which(remove_nullable(argument_types[0])); - res.reset(create_class_with_type(*argument_types[0], - argument_types)); +#define DISPATCH(TYPE, COLUMN_TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + res.reset(creator_without_type::create>( \ + result_is_nullable, argument_types)); + TYPE_TO_COLUMN_TYPE(DISPATCH) +#undef DISPATCH if (!res) { LOG(WARNING) << fmt::format("Illegal type {} of argument for aggregate function {}", @@ -41,8 +43,8 @@ AggregateFunctionPtr create_aggregate_function_approx_count_distinct( } void register_aggregate_function_approx_count_distinct(AggregateFunctionSimpleFactory& factory) { - factory.register_function("approx_count_distinct", - create_aggregate_function_approx_count_distinct); + factory.register_function_both("approx_count_distinct", + create_aggregate_function_approx_count_distinct); factory.register_alias("approx_count_distinct", "ndv"); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.cpp b/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.cpp index ea4e058550..c81bf4b42f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.cpp @@ -19,37 +19,18 @@ #include "vec/aggregate_functions/aggregate_function_simple_factory.h" #include "vec/aggregate_functions/helpers.h" +#include "vec/data_types/data_type_nullable.h" namespace doris::vectorized { AggregateFunctionPtr create_aggregate_function_avg_weight(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - auto type = argument_types[0].get(); - if (type->is_nullable()) { - type = assert_cast(type)->get_nested_type().get(); - } - - WhichDataType which(*type); - -#define DISPATCH(TYPE) \ - 
if (which.idx == TypeIndex::TYPE) \ - return AggregateFunctionPtr(new AggregateFunctionAvgWeight(argument_types)); - FOR_NUMERIC_TYPES(DISPATCH) -#undef DISPATCH - if (which.is_decimal128()) { - return AggregateFunctionPtr(new AggregateFunctionAvgWeight(argument_types)); - } - if (which.is_decimal()) { - return AggregateFunctionPtr(new AggregateFunctionAvgWeight(argument_types)); - } - - LOG(WARNING) << fmt::format("Illegal argument type for aggregate function topn_array is: {}", - type->get_name()); - return nullptr; + return AggregateFunctionPtr(creator_with_type::create( + result_is_nullable, argument_types)); } void register_aggregate_function_avg_weighted(AggregateFunctionSimpleFactory& factory) { - factory.register_function("avg_weighted", create_aggregate_function_avg_weight); + factory.register_function_both("avg_weighted", create_aggregate_function_avg_weight); } } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.h b/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.h index aa3b70d4de..cc14f1e1b3 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.h +++ b/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.h @@ -31,7 +31,7 @@ struct AggregateFunctionAvgWeightedData { void add(const T& data_val, double weight_val) { if constexpr (IsDecimalV2) { DecimalV2Value value = binary_cast(data_val); - data_sum = data_sum + (static_cast(value) * weight_val); + data_sum = data_sum + (double(value) * weight_val); } else { data_sum = data_sum + (data_val * weight_val); } @@ -81,8 +81,8 @@ public: void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena*) const override { - const auto& column = static_cast(*columns[0]); - const auto& weight = static_cast&>(*columns[1]); + const auto& column = assert_cast(*columns[0]); + const auto& weight = assert_cast&>(*columns[1]); this->data(place).add(column.get_data()[row_num], 
weight.get_element(row_num)); } @@ -103,7 +103,7 @@ public: } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { - auto& column = static_cast&>(to); + auto& column = assert_cast&>(to); column.get_data().push_back(this->data(place).get()); } }; diff --git a/be/src/vec/aggregate_functions/aggregate_function_bit.cpp b/be/src/vec/aggregate_functions/aggregate_function_bit.cpp index 6b9be5c92c..bdc51daaf9 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bit.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_bit.cpp @@ -29,12 +29,6 @@ template