diff --git a/be/src/vec/aggregate_functions/aggregate_function.h b/be/src/vec/aggregate_functions/aggregate_function.h index a1cdc1a93e..ae9a3667e5 100644 --- a/be/src/vec/aggregate_functions/aggregate_function.h +++ b/be/src/vec/aggregate_functions/aggregate_function.h @@ -37,6 +37,9 @@ class IDataType; template class AggregateFunctionBitmapCount; +template +class AggregateFunctionBitmapOp; +struct AggregateFunctionBitmapUnionOp; using DataTypePtr = std::shared_ptr; using DataTypes = std::vector; @@ -184,7 +187,9 @@ public: void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset, const IColumn** columns, Arena* arena, bool agg_many) const override { if constexpr (std::is_same_v> || - std::is_same_v>) { + std::is_same_v> || + std::is_same_v>) { if (agg_many) { phmap::flat_hash_map> place_rows; for (int i = 0; i < batch_size; ++i) { diff --git a/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp b/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp index e429e93f44..abe445fb0f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp @@ -58,34 +58,41 @@ AggregateFunctionPtr create_aggregate_function_bitmap_intersect(const std::strin return std::make_shared>( argument_types); } -template + AggregateFunctionPtr create_aggregate_function_bitmap_union_count(const std::string& name, const DataTypes& argument_types, const Array& parameters, const bool result_is_nullable) { - return std::make_shared>(argument_types); + const bool arg_is_nullable = argument_types[0]->is_nullable(); + if (arg_is_nullable) { + return std::make_shared>(argument_types); + } else { + return std::make_shared>(argument_types); + } } -template AggregateFunctionPtr create_aggregate_function_bitmap_union_int(const std::string& name, const DataTypes& argument_types, const Array& parameters, const bool result_is_nullable) { - return std::shared_ptr( - createWithIntDataType(argument_types)); + const bool arg_is_nullable = argument_types[0]->is_nullable(); + if (arg_is_nullable) { + return std::shared_ptr( + createWithIntDataType(argument_types)); + } else { + return std::shared_ptr( + createWithIntDataType(argument_types)); + } } void register_aggregate_function_bitmap(AggregateFunctionSimpleFactory& factory) { factory.register_function("bitmap_union", create_aggregate_function_bitmap_union); factory.register_function("bitmap_intersect", create_aggregate_function_bitmap_intersect); - factory.register_function("bitmap_union_count", - create_aggregate_function_bitmap_union_count); - factory.register_function("bitmap_union_count", - create_aggregate_function_bitmap_union_count, true); - - factory.register_function("bitmap_union_int", - create_aggregate_function_bitmap_union_int); - factory.register_function("bitmap_union_int", create_aggregate_function_bitmap_union_int, + factory.register_function("bitmap_union_count", create_aggregate_function_bitmap_union_count); + factory.register_function("bitmap_union_count", create_aggregate_function_bitmap_union_count, true); + + factory.register_function("bitmap_union_int", create_aggregate_function_bitmap_union_int); + factory.register_function("bitmap_union_int", create_aggregate_function_bitmap_union_int, true); } } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_bitmap.h b/be/src/vec/aggregate_functions/aggregate_function_bitmap.h index b24041b962..3b7022e51a 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bitmap.h +++ b/be/src/vec/aggregate_functions/aggregate_function_bitmap.h @@ -19,6 +19,7 @@ #include +#include "util/bitmap_value.h" #include "vec/aggregate_functions/aggregate_function.h" #include "vec/columns/column_complex.h" #include "vec/columns/column_nullable.h" @@ -129,6 +130,18 @@ public: this->data(place).add(column.get_data()[row_num]); } + void add_many(AggregateDataPtr __restrict place, const IColumn** columns, + std::vector& rows, Arena*) const override { + if constexpr (std::is_same_v) { + const auto& column = static_cast(*columns[0]); + std::vector values; + for (int i = 0; i < rows.size(); ++i) { + values.push_back(&(column.get_data()[rows[i]])); + } + this->data(place).add_batch(values); + } + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena*) const override { this->data(place).merge( @@ -153,11 +166,11 @@ public: void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); } }; -template +template class AggregateFunctionBitmapCount final : public IAggregateFunctionDataHelper< AggregateFunctionBitmapData, - AggregateFunctionBitmapCount> { + AggregateFunctionBitmapCount> { public: // using ColVecType = ColumnBitmap; using ColVecResult = ColumnVector; @@ -166,14 +179,15 @@ public: AggregateFunctionBitmapCount(const DataTypes& argument_types_) : IAggregateFunctionDataHelper< AggregateFunctionBitmapData, - AggregateFunctionBitmapCount>(argument_types_, {}) {} + AggregateFunctionBitmapCount>(argument_types_, + {}) {} String get_name() const override { return "count"; } DataTypePtr get_return_type() const override { return std::make_shared(); } void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena*) const override { - if constexpr (nullable) { + if constexpr (arg_is_nullable) { auto& nullable_column = assert_cast(*columns[0]); if (!nullable_column.is_null_at(row_num)) { const auto& column = @@ -188,7 +202,7 @@ public: void add_many(AggregateDataPtr __restrict place, const IColumn** columns, std::vector& rows, Arena*) const override { - if constexpr (nullable && std::is_same_v) { + if constexpr (arg_is_nullable && std::is_same_v) { auto& nullable_column = assert_cast(*columns[0]); const auto& column = static_cast(nullable_column.get_nested_column());