[Improvement](agg) Improve count distinct distribute keys (#33167)

This commit is contained in:
Gabriel
2024-04-26 18:31:11 +08:00
committed by yiguolei
parent c125148deb
commit 970d0c80df
4 changed files with 329 additions and 1 deletions

View File

@ -40,6 +40,7 @@ void register_aggregate_function_count(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_count_by_enum(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_HLL_union_agg(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_uniq(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_uniq_distribute_key(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_bit(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_bitmap(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_quantile_state(AggregateFunctionSimpleFactory& factory);
@ -80,6 +81,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
register_aggregate_function_count(instance);
register_aggregate_function_count_by_enum(instance);
register_aggregate_function_uniq(instance);
register_aggregate_function_uniq_distribute_key(instance);
register_aggregate_function_bit(instance);
register_aggregate_function_bitmap(instance);
register_aggregate_function_group_array_intersect(instance);

View File

@ -75,7 +75,7 @@ struct AggregateFunctionUniqExactData {
Set set;
static String get_name() { return "uniqExact"; }
static String get_name() { return "multi_distinct"; }
};
namespace detail {

View File

@ -0,0 +1,73 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "vec/aggregate_functions/aggregate_function_uniq_distribute_key.h"
#include <string>
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/factory_helpers.h"
#include "vec/aggregate_functions/helpers.h"
namespace doris::vectorized {
/// Factory for AggregateFunctionUniqDistributeKey: picks the instantiation that matches the
/// (nullable-stripped) type of the single argument.
///
/// @tparam Data               state template; instantiated as Data<KeyRepr> for the chosen key.
/// @param name                function name supplied by the factory; not used here.
/// @param argument_types      argument types; only exactly one argument is supported.
/// @param result_is_nullable  forwarded unchanged to the creator helpers.
/// @return the created aggregate function, or nullptr when the argument count is not 1 or the
///         argument type is none of: numeric (as decided by creator_with_numeric_type),
///         decimal32/64/128v2/128v3, string or fixed string.
template <template <typename> class Data>
AggregateFunctionPtr create_aggregate_function_uniq(const std::string& name,
                                                    const DataTypes& argument_types,
                                                    const bool result_is_nullable) {
    if (argument_types.size() != 1) {
        return nullptr;
    }

    const IDataType& argument_type = *remove_nullable(argument_types[0]);
    WhichDataType which(argument_type);

    // Numeric types are tried first; a null result means "not handled" and we fall through to
    // the decimal / string cases below.
    AggregateFunctionPtr res(
            creator_with_numeric_type::create<AggregateFunctionUniqDistributeKey, Data>(
                    argument_types, result_is_nullable));
    if (res) {
        return res;
    }

    // Decimals are keyed by their underlying integer representation.
    if (which.is_decimal32()) {
        return creator_without_type::create<
                AggregateFunctionUniqDistributeKey<Decimal32, Data<Int32>>>(argument_types,
                                                                            result_is_nullable);
    }
    if (which.is_decimal64()) {
        return creator_without_type::create<
                AggregateFunctionUniqDistributeKey<Decimal64, Data<Int64>>>(argument_types,
                                                                            result_is_nullable);
    }
    if (which.is_decimal128v3()) {
        return creator_without_type::create<
                AggregateFunctionUniqDistributeKey<Decimal128V3, Data<Int128>>>(
                argument_types, result_is_nullable);
    }
    // The decimal128v3 case has already returned above, so only v2 needs to be tested here.
    if (which.is_decimal128v2()) {
        return creator_without_type::create<
                AggregateFunctionUniqDistributeKey<Decimal128V2, Data<Int128>>>(
                argument_types, result_is_nullable);
    }
    if (which.is_string_or_fixed_string()) {
        return creator_without_type::create<
                AggregateFunctionUniqDistributeKey<String, Data<String>>>(argument_types,
                                                                          result_is_nullable);
    }
    return nullptr;
}
/// Registers the distribute-key distinct-count aggregate with the given factory under the name
/// "multi_distinct_count_distribute_key", using register_function_both.
void register_aggregate_function_uniq_distribute_key(AggregateFunctionSimpleFactory& factory) {
    constexpr const char* function_name = "multi_distinct_count_distribute_key";
    // Creator bound to the distribute-key state template.
    AggregateFunctionCreator creator =
            create_aggregate_function_uniq<AggregateFunctionUniqDistributeKeyData>;
    factory.register_function_both(function_name, creator);
}
} // namespace doris::vectorized

View File

@ -0,0 +1,253 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionUniq.h
// and modified by Doris
#pragma once
#include <stddef.h>
#include <algorithm>
#include <boost/iterator/iterator_facade.hpp>
#include <memory>
#include <vector>
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/aggregate_function_uniq.h"
#include "vec/columns/column.h"
#include "vec/columns/column_fixed_length_object.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_vector.h"
#include "vec/columns/columns_number.h"
#include "vec/common/assert_cast.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_fixed_length_object.h"
#include "vec/data_types/data_type_number.h"
#include "vec/io/var_int.h"
// Forward declarations of types that this header only uses by pointer or reference in
// method signatures (Arena*, BufferReadable&, BufferWritable&).
namespace doris {
namespace vectorized {
class Arena;
class BufferReadable;
class BufferWritable;
} // namespace vectorized
} // namespace doris
// Forward declaration (global namespace) of the HashCRC32 hasher template, which is named by
// AggregateFunctionUniqDistributeKeyData::Hash for non-string keys.
template <typename T>
struct HashCRC32;
namespace doris::vectorized {
/// Per-group state used by AggregateFunctionUniqDistributeKey.
///
/// Holds the set of keys seen so far plus a separate running `count`. Non-string keys are stored
/// as T and hashed with HashCRC32<T>; string keys are stored as a UInt128 digest and hashed with
/// UInt128TrivialHash.
template <typename T>
struct AggregateFunctionUniqDistributeKeyData {
    static constexpr bool is_string_key = std::is_same_v<T, String>;

    using Key = std::conditional_t<is_string_key, UInt128, T>;
    using Hash = std::conditional_t<is_string_key, UInt128TrivialHash, HashCRC32<Key>>;
    using Set = flat_hash_set<Key, Hash>;

    /// Converts a string value into its set key: the 128-bit XXH128 digest of the bytes
    /// (seed 0), packed as {high64, low64}.
    static UInt128 ALWAYS_INLINE get_key(const StringRef& value) {
        const auto digest = XXH_INLINE_XXH128(value.data, value.size, 0);
        return UInt128 {digest.high64, digest.low64};
    }

    // Keys inserted by the add* paths.
    Set set;
    // Count accumulated by the merge / deserialize paths; starts at zero.
    UInt64 count = 0;
};
/// Distinct-count aggregate whose partial states are exchanged as a single UInt64 instead of the
/// full key set.
///
/// The add* paths insert keys into `Data::set`. Every serialization path emits only
/// `set.size()`, and every merge / deserialize path adds the received numbers into
/// `Data::count`; the final result is `count`. Summing the sizes is only a correct distinct
/// count if no key appears in more than one partial state -- presumably the planner guarantees
/// this by distributing the input on the distinct key (TODO confirm against the planner).
///
/// @tparam T     argument value type (String selects the hashed-string path).
/// @tparam Data  state type providing `Set set`, `UInt64 count` and `get_key(StringRef)`.
template <typename T, typename Data>
class AggregateFunctionUniqDistributeKey final
        : public IAggregateFunctionDataHelper<Data, AggregateFunctionUniqDistributeKey<T, Data>> {
public:
    /// Type of the values stored in the set: UInt128 digests for strings, T otherwise.
    using KeyType = std::conditional_t<std::is_same_v<T, String>, UInt128, T>;

    AggregateFunctionUniqDistributeKey(const DataTypes& argument_types_)
            : IAggregateFunctionDataHelper<Data, AggregateFunctionUniqDistributeKey<T, Data>>(
                      argument_types_) {}

    // NOTE(review): this differs from the name the function is registered under
    // ("multi_distinct_count_distribute_key") -- confirm whether that is intentional.
    String get_name() const override { return "multi_distinct_distribute_key"; }

    /// The result column is Int64.
    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); }

    /// Adds the value at `row_num` of the first argument column to the state at `place`,
    /// delegating to detail::OneAdder.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena*) const override {
        detail::OneAdder<T, Data>::add(this->data(place), *columns[0], row_num);
    }

    /// Returns a pointer to `batch_size` keys for the first rows of `column`.
    ///
    /// For strings the keys are computed into `keys_container` (which must outlive the returned
    /// pointer); for every other type the column's own data buffer is returned and
    /// `keys_container` is left untouched.
    static ALWAYS_INLINE const KeyType* get_keys(std::vector<KeyType>& keys_container,
                                                 const IColumn& column, size_t batch_size) {
        if constexpr (std::is_same_v<T, String>) {
            keys_container.resize(batch_size);
            for (size_t i = 0; i != batch_size; ++i) {
                StringRef value = column.get_data_at(i);
                keys_container[i] = Data::get_key(value);
            }
            return keys_container.data();
        } else {
            using ColumnType =
                    std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
            return assert_cast<const ColumnType&>(column).get_data().data();
        }
    }

    /// Inserts row i of the first argument column into the set of the state at
    /// `places[i] + place_offset`, for every i in [0, batch_size).
    void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
                   const IColumn** columns, Arena* arena, bool /*agg_many*/) const override {
        std::vector<KeyType> keys_container;
        const KeyType* keys = get_keys(keys_container, *columns[0], batch_size);

        // Resolve every row's destination set up front so the insert loop can prefetch ahead.
        std::vector<typename Data::Set*> array_of_data_set(batch_size);
        for (size_t i = 0; i != batch_size; ++i) {
            array_of_data_set[i] = &(this->data(places[i] + place_offset).set);
        }

        for (size_t i = 0; i != batch_size; ++i) {
            // Prefetch the slot of the row HASH_MAP_PREFETCH_DIST ahead, when there is one.
            if (i + HASH_MAP_PREFETCH_DIST < batch_size) {
                array_of_data_set[i + HASH_MAP_PREFETCH_DIST]->prefetch(
                        keys[i + HASH_MAP_PREFETCH_DIST]);
            }
            array_of_data_set[i]->insert(keys[i]);
        }
    }

    /// Inserts the first `batch_size` rows of the first argument column into the single state at
    /// `place`.
    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
                                Arena* arena) const override {
        std::vector<KeyType> keys_container;
        const KeyType* keys = get_keys(keys_container, *columns[0], batch_size);

        auto& set = this->data(place).set;
        for (size_t i = 0; i != batch_size; ++i) {
            if (i + HASH_MAP_PREFETCH_DIST < batch_size) {
                set.prefetch(keys[i + HASH_MAP_PREFETCH_DIST]);
            }
            set.insert(keys[i]);
        }
    }

    /// Adds `rhs.count` into `place.count`. The sets are not merged.
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena*) const override {
        this->data(place).count += this->data(rhs).count;
    }

    /// Writes the number of keys in the set (not `count`) as a var-uint.
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        write_var_uint(this->data(place).set.size(), buf);
    }

    /// Reads a var-uint into `count`; the set is left as it was.
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena*) const override {
        read_var_uint(this->data(place).count, buf);
    }

    /// Appends `count` to the Int64 result column. Keys held only in `set` are not reflected
    /// here; they contribute to a result only after going through a serialize path.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        assert_cast<ColumnInt64&>(to).get_data().push_back(this->data(place).count);
    }

    /// Creates `num_rows` states laid out back to back (stride sizeof(Data)) starting at
    /// `places`, and sets the i-th state's `count` to the i-th UInt64 of `column`.
    void deserialize_from_column(AggregateDataPtr places, const IColumn& column, Arena* arena,
                                 size_t num_rows) const override {
        auto data = reinterpret_cast<const UInt64*>(
                assert_cast<const ColumnFixedLengthObject&>(column).get_data().data());
        for (size_t i = 0; i != num_rows; ++i) {
            auto rhs_place = places + sizeof(Data) * i;
            this->create(rhs_place);
            (reinterpret_cast<Data*>(rhs_place))->count = data[i];
        }
    }

    /// Resizes `dst` to `num_rows` items and writes the set size of the state at
    /// `places[i] + offset` into item i. `dst` must have an item size of sizeof(UInt64).
    void serialize_to_column(const std::vector<AggregateDataPtr>& places, size_t offset,
                             MutableColumnPtr& dst, const size_t num_rows) const override {
        auto& col = assert_cast<ColumnFixedLengthObject&>(*dst);
        CHECK(col.item_size() == sizeof(UInt64))
                << "size is not equal: " << col.item_size() << " " << sizeof(UInt64);
        col.resize(num_rows);
        auto* data = reinterpret_cast<UInt64*>(col.get_data().data());
        for (size_t i = 0; i != num_rows; ++i) {
            data[i] = this->data(places[i] + offset).set.size();
        }
    }

    /// Resizes `dst` to `num_rows` items and writes 1 into each, without reading `columns`.
    /// Presumably each input row is treated as its own one-key state -- TODO confirm how
    /// callers handle NULL inputs.
    void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst,
                                           const size_t num_rows, Arena* arena) const override {
        auto& dst_col = assert_cast<ColumnFixedLengthObject&>(*dst);
        CHECK(dst_col.item_size() == sizeof(UInt64))
                << "size is not equal: " << dst_col.item_size() << " " << sizeof(UInt64);
        dst_col.resize(num_rows);
        auto* data = reinterpret_cast<UInt64*>(dst_col.get_data().data());
        for (size_t i = 0; i != num_rows; ++i) {
            data[i] = 1;
        }
    }

    /// Adds every UInt64 item of `column` into `place.count`.
    void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column,
                                           Arena* arena) const override {
        auto& col = assert_cast<const ColumnFixedLengthObject&>(column);
        const size_t num_rows = column.size();
        auto* data = reinterpret_cast<const UInt64*>(col.get_data().data());
        for (size_t i = 0; i != num_rows; ++i) {
            AggregateFunctionUniqDistributeKey::data(place).count += data[i];
        }
    }

    /// Adds the UInt64 items of `column` at indices begin..end (both inclusive) into
    /// `place.count`.
    // NOTE(review): the loop treats `end` as inclusive, yet the CHECK accepts
    // end == column.size(), in which case data[end] is one past the last item. Confirm the
    // caller contract and tighten the CHECK (or the loop bound) accordingly.
    void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict place,
                                                 const IColumn& column, size_t begin, size_t end,
                                                 Arena* arena) const override {
        CHECK(end <= column.size() && begin <= end)
                << ", begin:" << begin << ", end:" << end << ", column.size():" << column.size();
        auto& col = assert_cast<const ColumnFixedLengthObject&>(column);
        auto* data = reinterpret_cast<const UInt64*>(col.get_data().data());
        for (size_t i = begin; i <= end; ++i) {
            this->data(place).count += data[i];
        }
    }

    /// Deserializes `num_rows` temporary states from `column` into `rhs`, merges them into
    /// `places` via merge_vec, and destroys the temporaries on scope exit.
    // NOTE(review): `column` is typed ColumnString here but deserialize_from_column
    // assert_casts it to ColumnFixedLengthObject -- confirm what callers actually pass.
    void deserialize_and_merge_vec(const AggregateDataPtr* places, size_t offset,
                                   AggregateDataPtr rhs, const ColumnString* column, Arena* arena,
                                   const size_t num_rows) const override {
        this->deserialize_from_column(rhs, *column, arena, num_rows);
        DEFER({ this->destroy_vec(rhs, num_rows); });
        this->merge_vec(places, offset, rhs, arena, num_rows);
    }

    /// Same as deserialize_and_merge_vec, but merges through merge_vec_selected.
    // NOTE(review): same ColumnString vs ColumnFixedLengthObject question as above.
    void deserialize_and_merge_vec_selected(const AggregateDataPtr* places, size_t offset,
                                            AggregateDataPtr rhs, const ColumnString* column,
                                            Arena* arena, const size_t num_rows) const override {
        this->deserialize_from_column(rhs, *column, arena, num_rows);
        DEFER({ this->destroy_vec(rhs, num_rows); });
        this->merge_vec_selected(places, offset, rhs, arena, num_rows);
    }

    /// Appends one item to `to` holding the set size of the state at `place`. `to` must have an
    /// item size of sizeof(UInt64).
    void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
                                         IColumn& to) const override {
        auto& col = assert_cast<ColumnFixedLengthObject&>(to);
        CHECK(col.item_size() == sizeof(UInt64))
                << "size is not equal: " << col.item_size() << " " << sizeof(UInt64);
        const size_t old_size = col.size();
        col.resize(old_size + 1);
        // Cast to UInt64* first and index afterwards, exactly as serialize_to_column does.
        // Offsetting the raw data() pointer before the cast would advance by `old_size`
        // elements of the container's own type rather than by `old_size` UInt64 items, so the
        // value would land in the wrong slot whenever the column is not empty.
        auto* items = reinterpret_cast<UInt64*>(col.get_data().data());
        items[old_size] = AggregateFunctionUniqDistributeKey::data(place).set.size();
    }

    /// Serialized states travel in a fixed-length-object column with one UInt64 per item.
    MutableColumnPtr create_serialize_column() const override {
        return ColumnFixedLengthObject::create(sizeof(UInt64));
    }

    DataTypePtr get_serialized_type() const override {
        return std::make_shared<DataTypeFixedLengthObject>();
    }
};
} // namespace doris::vectorized