// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// NOTE(review): this copy of the file appears to have lost every piece of text
// that was written between angle brackets: the header names of the system
// includes, the template parameter lists after each `template` keyword, and the
// template arguments of assert_cast / std::pair / std::vector / std::make_shared
// and friends. The code tokens below are left exactly as found; the missing
// pieces must be restored from the project's version control before this header
// can be compiled. Comments only describe what the remaining tokens show.

#pragma once

// NOTE(review): header names are missing from the bare directives below.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/columns/column.h"
#include "vec/columns/column_array.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_string.h"
#include "vec/columns/column_vector.h"
#include "vec/columns/columns_number.h"
#include "vec/common/assert_cast.h"
#include "vec/common/hash_table/phmap_fwd_decl.h"
#include "vec/common/string_ref.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type_array.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_string.h"
#include "vec/io/io_helper.h"

// Forward declarations of types that are only used by reference or pointer here.
namespace doris {
namespace vectorized {
class Arena;
class BufferReadable;
class BufferWritable;
template
class ColumnDecimal;
} // namespace vectorized
} // namespace doris

namespace doris::vectorized {

// space-saving algorithm
//
// Aggregation state shared by all topn variants: a map from value to its
// accumulated count (`counter_map`), the number of results to report
// (`top_num`) and the number of entries kept when the state is serialized
// (`capacity`).
template
struct AggregateFunctionTopNData {
    // Column type used to store values of type T in insert_result_into().
    // NOTE(review): the condition and the column template arguments are missing.
    using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>;

    // Records the requested result size and derives
    // capacity = top_num * space_expand_rate (the rate defaults to 50).
    // The multiplication is done in uint64_t because top_num is cast first.
    void set_paramenters(int input_top_num, int space_expand_rate = 50) {
        top_num = input_top_num;
        capacity = (uint64_t)top_num * space_expand_rate;
    }

    // Adds `increment` to the counter of a string value. The referenced bytes
    // are first converted to a std::string, which is the key used for lookup
    // and insertion. A value seen for the first time starts at `increment`.
    void add(const StringRef& value, const UInt64& increment = 1) {
        std::string data = value.to_string();
        auto it = counter_map.find(data);
        if (it != counter_map.end()) {
            it->second = it->second + increment;
        } else {
            counter_map.insert({data, increment});
        }
    }

    // Same as above for a value that is already of the key type T.
    void add(const T& value, const UInt64& increment = 1) {
        auto it = counter_map.find(value);
        if (it != counter_map.end()) {
            it->second = it->second + increment;
        } else {
            counter_map.insert({value, increment});
        }
    }

    // Folds the state `rhs` into this state.
    //
    // Does nothing when rhs.top_num is 0. Otherwise this state adopts rhs's
    // top_num and capacity, and a side counts as "full" when its map holds at
    // least that (rhs) capacity. For a full side the smallest counter of that
    // side is used as an offset; for a non-full side the offset is 0.
    //
    // Resulting counters, as follows from the arithmetic below:
    //   key on both sides : lhs_count + rhs_count
    //   key only in lhs   : lhs_count + rhs_min
    //   key only in rhs   : rhs_count + lhs_min
    // Presumably the minimum of a full side stands in for the count of a key
    // that side may have dropped -- confirm against the algorithm reference.
    void merge(const AggregateFunctionTopNData& rhs) {
        if (!rhs.top_num) {
            return;
        }

        top_num = rhs.top_num;
        capacity = rhs.capacity;

        bool lhs_full = (counter_map.size() >= capacity);
        bool rhs_full = (rhs.counter_map.size() >= capacity);

        uint64_t lhs_min = 0;
        uint64_t rhs_min = 0;

        if (lhs_full) {
            // lhs_min is taken before any rhs offset is applied to lhs counters.
            lhs_min = UINT64_MAX;
            for (auto it : counter_map) {
                lhs_min = std::min(lhs_min, it.second);
            }
        }

        if (rhs_full) {
            rhs_min = UINT64_MAX;
            for (auto it : rhs.counter_map) {
                rhs_min = std::min(rhs_min, it.second);
            }

            // Every lhs counter is raised by rhs_min up front; keys that do
            // exist in rhs get the remainder (rhs_count - rhs_min) added below.
            for (auto& it : counter_map) {
                it.second += rhs_min;
            }
        }

        for (auto rhs_it : rhs.counter_map) {
            auto lhs_it = counter_map.find(rhs_it.first);
            if (lhs_it != counter_map.end()) {
                lhs_it->second += rhs_it.second - rhs_min;
            } else {
                counter_map.insert({rhs_it.first, rhs_it.second + lhs_min});
            }
        }
    }

    // Returns every map entry as a (count, value) pair, sorted with
    // std::greater, i.e. in descending order.
    // NOTE(review): the element type of the vector and the argument of
    // std::greater are missing from this copy.
    std::vector> get_remain_vector() const {
        std::vector> counter_vector;
        for (auto it : counter_map) {
            counter_vector.emplace_back(it.second, it.first);
        }
        std::sort(counter_vector.begin(), counter_vector.end(), std::greater>());
        return counter_vector;
    }

    // Serializes the state: top_num, capacity, the number of entries that
    // follow, then for each entry the value followed by its count. Only the
    // first min(capacity, map size) entries of the descending-sorted vector are
    // written, so entries beyond `capacity` are dropped at this point.
    void write(BufferWritable& buf) const {
        write_binary(top_num, buf);
        write_binary(capacity, buf);

        uint64_t element_number = std::min(capacity, (uint64_t)counter_map.size());
        write_binary(element_number, buf);

        auto counter_vector = get_remain_vector();

        // NOTE(review): `i` is deduced as int while element_number is uint64_t.
        for (auto i = 0; i < element_number; i++) {
            auto element = counter_vector[i];
            // element is (count, value): the value goes out first, then the count.
            write_binary(element.second, buf);
            write_binary(element.first, buf);
        }
    }

    // Deserializes a state produced by write(): reads top_num, capacity and the
    // entry count, discards the current map contents, then reads each entry's
    // value (pair.first) and count (pair.second) and inserts it into the map.
    void read(BufferReadable& buf) {
        read_binary(top_num, buf);
        read_binary(capacity, buf);

        uint64_t element_number = 0;
        read_binary(element_number, buf);

        counter_map.clear();
        // NOTE(review): the pair's template arguments are missing from this copy.
        std::pair element;
        // NOTE(review): `i` is deduced as int while element_number is uint64_t.
        for (auto i = 0; i < element_number; i++) {
            read_binary(element.first, buf);
            read_binary(element.second, buf);
            counter_map.insert(element);
        }
    }

    // Renders the result as a JSON object built with rapidjson: one member per
    // reported entry, keyed by the value and holding the count as an unsigned
    // 64-bit number. At most top_num entries are emitted, taken in the
    // descending order of get_remain_vector().
    // The key is passed via c_str(), so this requires a value type that
    // provides c_str() (presumably only the string instantiation calls this).
    std::string get() const {
        auto counter_vector = get_remain_vector();

        rapidjson::StringBuffer buffer;
        // NOTE(review): the Writer's template argument is missing from this copy.
        rapidjson::Writer writer(buffer);

        writer.StartObject();
        for (int i = 0; i < std::min((int)counter_vector.size(), top_num); i++) {
            const auto& element = counter_vector[i];
            writer.Key(element.second.c_str());
            writer.Uint64(element.first);
        }
        writer.EndObject();

        return buffer.GetString();
    }

    // Appends the reported values (without their counts) to column `to`, at
    // most top_num of them, in the descending order of get_remain_vector().
    // One compile-time branch inserts the bytes through insert_data(); the
    // other pushes the value onto the column's underlying data container.
    // NOTE(review): the is_same_v arguments and both assert_cast target types
    // are missing; the first branch looks like the string case -- confirm.
    void insert_result_into(IColumn& to) const {
        auto counter_vector = get_remain_vector();
        for (int i = 0; i < std::min((int)counter_vector.size(), top_num); i++) {
            const auto& element = counter_vector[i];
            if constexpr (std::is_same_v) {
                assert_cast(to).insert_data(element.second.c_str(), element.second.length());
            } else {
                assert_cast(to).get_data().push_back(element.second);
            }
        }
    }

    // Drops all counters. top_num and capacity are left untouched.
    void reset() { counter_map.clear(); }

    // Number of entries reported by get() / insert_result_into().
    int top_num = 0;
    // Maximum number of entries written by write(); also the "full" threshold
    // used by merge().
    uint64_t capacity = 0;
    // Value -> accumulated count.
    // NOTE(review): the map's template arguments are missing from this copy.
    flat_hash_map counter_map;
};

// Row handler taking the value from columns[0] and the result size from
// columns[1]; the expand rate keeps set_paramenters()'s default.
// The parameters are re-read and re-applied for every row.
struct AggregateFunctionTopNImplInt {
    static void add(AggregateFunctionTopNData& __restrict place, const IColumn** columns,
                    size_t row_num) {
        place.set_paramenters(assert_cast(columns[1])->get_element(row_num));
        place.add(assert_cast(*columns[0]).get_data_at(row_num));
    }
};

// Row handler taking the value from columns[0], the result size from
// columns[1] and the expand rate from columns[2].
struct AggregateFunctionTopNImplIntInt {
    static void add(AggregateFunctionTopNData& __restrict place, const IColumn** columns,
                    size_t row_num) {
        place.set_paramenters(assert_cast(columns[1])->get_element(row_num),
                              assert_cast(columns[2])->get_element(row_num));
        place.add(assert_cast(*columns[0]).get_data_at(row_num));
    }
};

//for topn_array agg
//
// Row handler for the array-returning variant. When has_default_param is true
// both columns[1] and columns[2] are forwarded to set_paramenters(); otherwise
// only columns[1] is. The value comes from columns[0], either through
// get_data_at() or by indexing the column's data, chosen at compile time.
// NOTE(review): the template parameter list and the is_same_v arguments are
// missing from this copy.
template
struct AggregateFunctionTopNImplArray {
    using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>;
    static void add(AggregateFunctionTopNData& __restrict place, const IColumn** columns,
                    size_t row_num) {
        if constexpr (has_default_param) {
            place.set_paramenters(
                    assert_cast(columns[1])->get_element(row_num),
                    assert_cast(columns[2])->get_element(row_num));

        } else {
            place.set_paramenters(
                    assert_cast(columns[1])->get_element(row_num));
        }
        if constexpr (std::is_same_v) {
            place.add(assert_cast(*columns[0]).get_data_at(row_num));
        } else {
            T val = assert_cast(*columns[0]).get_data()[row_num];
            place.add(val);
        }
    }
};

//for topn_weighted agg
//
// Row handler for the weighted variant. The value comes from columns[0] and
// its weight from columns[1]; the weight is passed to add() as the increment.
// The result size is read from columns[2] and, when has_default_param is true,
// the expand rate from columns[3].
// NOTE(review): the template parameter list, the is_same_v arguments and the
// weight column's element type are missing from this copy.
template
struct AggregateFunctionTopNImplWeight {
    using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>;
    static void add(AggregateFunctionTopNData& __restrict place, const IColumn** columns,
                    size_t row_num) {
        if constexpr (has_default_param) {
            place.set_paramenters(
                    assert_cast(columns[2])->get_element(row_num),
                    assert_cast(columns[3])->get_element(row_num));

        } else {
            place.set_paramenters(
                    assert_cast(columns[2])->get_element(row_num));
        }
        if constexpr (std::is_same_v) {
            auto weight = assert_cast&>(*columns[1]).get_data()[row_num];
            place.add(assert_cast(*columns[0]).get_data_at(row_num), weight);
        } else {
            T val = assert_cast(*columns[0]).get_data()[row_num];
            auto weight = assert_cast&>(*columns[1]).get_data()[row_num];
            place.add(val, weight);
        }
    }
};

//base function
//
// Common part of the topn aggregate functions: every operation is forwarded to
// the state object obtained through this->data(place). Per-row handling is
// delegated to Impl::add. Result naming, typing and output are left to the
// derived classes.
// NOTE(review): the template parameter list and the base-class template
// arguments are missing from this copy.
template
class AggregateFunctionTopNBase
        : public IAggregateFunctionDataHelper, AggregateFunctionTopNBase> {
public:
    AggregateFunctionTopNBase(const DataTypes& argument_types_)
            : IAggregateFunctionDataHelper, AggregateFunctionTopNBase>(argument_types_) {}

    // Processes row `row_num` of `columns`; the arena argument is unused.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
             Arena*) const override {
        Impl::add(this->data(place), columns, row_num);
    }

    // Clears the counters of the state at `place`.
    void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); }

    // Folds the state at `rhs` into the state at `place`; the arena is unused.
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena*) const override {
        this->data(place).merge(this->data(rhs));
    }

    // Writes the state at `place` to `buf`.
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        this->data(place).write(buf);
    }

    // Replaces the state at `place` with the one read from `buf`; the arena is
    // unused.
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena*) const override {
        this->data(place).read(buf);
    }
};

//topn function return string
//
// Variant named "topn": the result is the string produced by the state's
// get(), inserted into the destination column as one value.
// NOTE(review): the template parameter list, the base-class arguments, the
// make_shared type and the assert_cast target type are missing from this copy.
template
class AggregateFunctionTopN final : public AggregateFunctionTopNBase {
public:
    AggregateFunctionTopN(const DataTypes& argument_types_)
            : AggregateFunctionTopNBase(argument_types_) {}

    String get_name() const override { return "topn"; }

    DataTypePtr get_return_type() const override { return std::make_shared(); }

    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        std::string result = this->data(place).get();
        assert_cast(to).insert_data(result.c_str(), result.length());
    }
};

//topn function return array
//
// Variant named "topn_weighted" when is_weighted is true and "topn_array"
// otherwise. The return type is built from the nullable form of the first
// argument's type, and the result values are appended to the destination
// column's nested data.
// NOTE(review): the template parameter list, the base-class arguments, the
// make_shared type and the cast target types are missing from this copy.
template
class AggregateFunctionTopNArray final : public AggregateFunctionTopNBase {
public:
    AggregateFunctionTopNArray(const DataTypes& argument_types_)
            : AggregateFunctionTopNBase(argument_types_),
              _argument_type(argument_types_[0]) {}

    String get_name() const override {
        if constexpr (is_weighted) {
            return "topn_weighted";
        } else {
            return "topn_array";
        }
    }

    DataTypePtr get_return_type() const override {
        return std::make_shared(make_nullable(_argument_type));
    }

    // Appends one result row to `to`: the reported values go into the nested
    // data column and the nested column's new size is pushed as the row's end
    // offset.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& to_arr = assert_cast(to);
        auto& to_nested_col = to_arr.get_data();
        if (to_nested_col.is_nullable()) {
            // NOTE(review): reinterpret_cast is used here where the rest of the
            // file uses assert_cast; the target type is missing from this copy.
            auto col_null = reinterpret_cast(&to_nested_col);
            this->data(place).insert_result_into(col_null->get_nested_column());
            // Grow the null map to match the nested column, filling new slots
            // with 0 (presumably "not null" -- confirm).
            col_null->get_null_map_data().resize_fill(col_null->get_nested_column().size(), 0);
        } else {
            this->data(place).insert_result_into(to_nested_col);
        }
        to_arr.get_offsets().push_back(to_nested_col.size());
    }

private:
    // Type of the first argument, i.e. of the values being counted.
    DataTypePtr _argument_type;
};

} // namespace doris::vectorized