[improvement](agg) iterate aggregation data in memory-write order (#12704)

Iterating aggregate states in the hash table's own bucket order results in effectively random, cache-unfriendly access to the states, which is very inefficient. Traversing the states in the order they were written to memory makes reads largely sequential and significantly improves memory read efficiency.
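The gain is easiest to see in miniature: the hash table keeps only key-to-state-pointer mappings, while each aggregate state is appended to a chunked, append-only container, so the output phase scans memory in exactly the order it was written. Below is a minimal sketch of that layout idea, assuming a fixed state size; the names (OrderedStates, kChunkRows) are hypothetical and this is not the actual AggregateDataContainer API.

#include <cstdint>
#include <memory>
#include <vector>

// Hypothetical illustration: states are appended in arrival order into
// fixed-size chunks (so handed-out pointers stay stable), and a later scan
// walks the chunks sequentially instead of in hash-bucket order.
struct OrderedStates {
    static constexpr size_t kChunkRows = 4096;
    size_t state_size;
    size_t rows_in_last_chunk = kChunkRows; // forces a chunk on first append
    std::vector<std::unique_ptr<uint8_t[]>> chunks;

    explicit OrderedStates(size_t size) : state_size(size) {}

    uint8_t* append() { // slot for a newly seen key
        if (rows_in_last_chunk == kChunkRows) {
            chunks.emplace_back(new uint8_t[kChunkRows * state_size]);
            rows_in_last_chunk = 0;
        }
        return chunks.back().get() + (rows_in_last_chunk++) * state_size;
    }

    template <typename F>
    void for_each(F&& f) { // memory-write-order traversal
        for (size_t c = 0; c < chunks.size(); ++c) {
            size_t rows = (c + 1 == chunks.size()) ? rows_in_last_chunk : kChunkRows;
            for (size_t r = 0; r < rows; ++r) f(chunks[c].get() + r * state_size);
        }
    }
};

int main() {
    OrderedStates states(sizeof(long)); // hypothetical 8-byte state
    for (long i = 0; i < 10000; ++i) {
        *reinterpret_cast<long*>(states.append()) = i; // write in arrival order
    }
    long sum = 0;
    states.for_each([&](uint8_t* p) { sum += *reinterpret_cast<long*>(p); });
    return sum == 49995000L ? 0 : 1; // sequential scan sees every state once
}

Chunked storage matters because the hash table holds raw pointers into the container: growing in fixed chunks keeps previously handed-out pointers stable, which a single reallocating buffer would not.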

Test
Hash table item count: 3.35M

Before this optimization: inserting keys into the column takes 500 ms.
With this optimization: it takes only 80 ms (roughly 6x faster).
Author: Jerry Hu
Date: 2022-09-21 14:58:50 +08:00
Committed by: GitHub
Parent: 27f7ae258d
Commit: 8f4bb0f804
12 changed files with 523 additions and 51 deletions

@@ -97,6 +97,8 @@ AggregationNode::AggregationNode(ObjectPool* pool, const TPlanNode& tnode,
           _serialize_result_timer(nullptr),
           _deserialize_data_timer(nullptr),
           _hash_table_compute_timer(nullptr),
+          _hash_table_iterate_timer(nullptr),
+          _insert_keys_to_column_timer(nullptr),
           _streaming_agg_timer(nullptr),
           _hash_table_size_counter(nullptr),
           _hash_table_input_counter(nullptr) {
@@ -295,6 +297,8 @@ Status AggregationNode::prepare(RuntimeState* state) {
     _serialize_result_timer = ADD_TIMER(runtime_profile(), "SerializeResultTime");
     _deserialize_data_timer = ADD_TIMER(runtime_profile(), "DeserializeDataTime");
     _hash_table_compute_timer = ADD_TIMER(runtime_profile(), "HashTableComputeTime");
+    _hash_table_iterate_timer = ADD_TIMER(runtime_profile(), "HashTableIterateTime");
+    _insert_keys_to_column_timer = ADD_TIMER(runtime_profile(), "InsertKeysToColumnTime");
     _streaming_agg_timer = ADD_TIMER(runtime_profile(), "StreamingAggTime");
     _hash_table_size_counter = ADD_COUNTER(runtime_profile(), "HashTableSize", TUnit::UNIT);
     _hash_table_input_counter = ADD_COUNTER(runtime_profile(), "HashTableInputCount", TUnit::UNIT);
@@ -384,6 +388,20 @@ Status AggregationNode::prepare(RuntimeState* state) {
         _executor.close = std::bind<void>(&AggregationNode::_close_without_key, this);
     } else {
         _init_hash_method(_probe_expr_ctxs);
+        std::visit(
+                [&](auto&& agg_method) {
+                    using HashTableType = std::decay_t<decltype(agg_method.data)>;
+                    using KeyType = typename HashTableType::key_type;
+
+                    /// some aggregate functions (like AVG for decimal) have align issues.
+                    _aggregate_data_container.reset(new AggregateDataContainer(
+                            sizeof(KeyType),
+                            ((_total_size_of_aggregate_states + _align_aggregate_states - 1) /
+                             _align_aggregate_states) *
+                                    _align_aggregate_states));
+                },
+                _agg_data._aggregated_method_variant);
         if (_is_merge) {
             _executor.execute = std::bind<Status>(&AggregationNode::_merge_with_serialized_key,
                                                   this, std::placeholders::_1);
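For context: the second constructor argument rounds the per-state footprint up to _align_aggregate_states, since some aggregate states (AVG over decimal, per the comment) need stronger alignment than their raw size implies. A tiny self-check of the same ((n + a - 1) / a) * a round-up, with hypothetical numbers:

#include <cassert>
#include <cstddef>

// Round n up to a multiple of a — the same integer trick used above.
constexpr size_t round_up(size_t n, size_t a) { return ((n + a - 1) / a) * a; }

int main() {
    assert(round_up(44, 16) == 48); // e.g. a 44-byte state padded to 16-byte alignment
    assert(round_up(48, 16) == 48); // already aligned: unchanged
    return 0;
}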
@@ -787,43 +805,64 @@ void AggregationNode::_emplace_into_hash_table(AggregateDataPtr* places, ColumnR
                     }
                 }
+                auto creator = [this](const auto& ctor, const auto& key) {
+                    using KeyType = std::decay_t<decltype(key)>;
+                    if constexpr (HashTableTraits<HashTableType>::is_string_hash_table &&
+                                  !std::is_same_v<StringRef, KeyType>) {
+                        StringRef string_ref = to_string_ref(key);
+                        ArenaKeyHolder key_holder {string_ref, _agg_arena_pool};
+                        key_holder_persist_key(key_holder);
+                        auto mapped = _aggregate_data_container->append_data(key_holder.key);
+                        _create_agg_status(mapped);
+                        ctor(key, mapped);
+                    } else {
+                        auto mapped = _aggregate_data_container->append_data(key);
+                        _create_agg_status(mapped);
+                        ctor(key, mapped);
+                    }
+                };
+
+                auto creator_for_null_key = [this](auto& mapped) {
+                    mapped = _agg_arena_pool.aligned_alloc(_total_size_of_aggregate_states,
+                                                           _align_aggregate_states);
+                    _create_agg_status(mapped);
+                };
+
                 /// For all rows.
                 COUNTER_UPDATE(_hash_table_input_counter, num_rows);
                 for (size_t i = 0; i < num_rows; ++i) {
-                    AggregateDataPtr aggregate_data = nullptr;
-                    auto emplace_result = [&]() {
-                        if constexpr (HashTableTraits<HashTableType>::is_phmap) {
-                            if (LIKELY(i + HASH_MAP_PREFETCH_DIST < num_rows)) {
-                                if constexpr (HashTableTraits<HashTableType>::is_parallel_phmap) {
-                                    agg_method.data.prefetch_by_key(state.get_key_holder(
-                                            i + HASH_MAP_PREFETCH_DIST, _agg_arena_pool));
-                                } else
-                                    agg_method.data.prefetch_by_hash(
-                                            _hash_values[i + HASH_MAP_PREFETCH_DIST]);
-                            }
-                            return state.emplace_key(agg_method.data, _hash_values[i], i,
-                                                     _agg_arena_pool);
-                        } else {
-                            return state.emplace_key(agg_method.data, i, _agg_arena_pool);
-                        }
-                    }();
-                    /// If a new key is inserted, initialize the states of the aggregate functions, and possibly something related to the key.
-                    if (emplace_result.is_inserted()) {
-                        /// exception-safety - if you can not allocate memory or create states, then destructors will not be called.
-                        emplace_result.set_mapped(nullptr);
-                        aggregate_data = _agg_arena_pool.aligned_alloc(
-                                _total_size_of_aggregate_states, _align_aggregate_states);
-                        _create_agg_status(aggregate_data);
-                        emplace_result.set_mapped(aggregate_data);
-                    } else
-                        aggregate_data = emplace_result.get_mapped();
-                    places[i] = aggregate_data;
+                    AggregateDataPtr mapped = nullptr;
+                    if constexpr (HashTableTraits<HashTableType>::is_phmap) {
+                        if (LIKELY(i + HASH_MAP_PREFETCH_DIST < num_rows)) {
+                            if constexpr (HashTableTraits<HashTableType>::is_parallel_phmap) {
+                                agg_method.data.prefetch_by_key(state.get_key_holder(
+                                        i + HASH_MAP_PREFETCH_DIST, _agg_arena_pool));
+                            } else
+                                agg_method.data.prefetch_by_hash(
+                                        _hash_values[i + HASH_MAP_PREFETCH_DIST]);
+                        }
+                        if constexpr (ColumnsHashing::IsSingleNullableColumnMethod<
+                                              AggState>::value) {
+                            mapped = state.lazy_emplace_key(agg_method.data, _hash_values[i], i,
+                                                            _agg_arena_pool, creator,
+                                                            creator_for_null_key);
+                        } else {
+                            mapped = state.lazy_emplace_key(agg_method.data, _hash_values[i], i,
+                                                            _agg_arena_pool, creator);
+                        }
+                    } else {
+                        if constexpr (ColumnsHashing::IsSingleNullableColumnMethod<
+                                              AggState>::value) {
+                            mapped = state.lazy_emplace_key(agg_method.data, i, _agg_arena_pool,
+                                                            creator, creator_for_null_key);
+                        } else {
+                            mapped = state.lazy_emplace_key(agg_method.data, i, _agg_arena_pool,
+                                                            creator);
+                        }
+                    }
+                    places[i] = mapped;
                     assert(places[i] != nullptr);
                 }
             },
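This hunk is the write side of the change: lazy_emplace_key allocates a state only when the probe misses, and the creator callback appends that state to _aggregate_data_container (so states land in insertion order) before publishing the mapping through ctor; creator_for_null_key keeps the old arena path for the dedicated null-key slot. A reduced sketch of the lazy-emplace pattern, using std::unordered_map and a hypothetical helper rather than Doris's hash-table wrappers:

#include <cstdio>
#include <string>
#include <unordered_map>

// Hypothetical sketch: the expensive state is created only when the probe
// finds no existing entry, exactly once per distinct key.
template <typename Map, typename Creator>
typename Map::mapped_type& lazy_emplace(Map& map, const typename Map::key_type& key,
                                        Creator&& creator) {
    auto it = map.find(key);
    if (it == map.end()) {
        it = map.emplace(key, creator(key)).first; // create state only on miss
    }
    return it->second;
}

int main() {
    std::unordered_map<std::string, int*> states;
    auto creator = [](const std::string& key) {
        std::printf("creating state for %s\n", key.c_str());
        return new int(0); // stands in for append_data + _create_agg_status
    };
    *lazy_emplace(states, "a", creator) += 1; // creates
    *lazy_emplace(states, "a", creator) += 1; // reuses
    std::printf("a = %d\n", *states["a"]);    // a = 2
    for (auto& kv : states) delete kv.second;
    return 0;
}

A real lazy_emplace does this in a single probe, invoking the callback in place instead of the find-then-emplace pair above; the essential property is the same, though: state creation runs exactly once per distinct key.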
@@ -1051,24 +1090,33 @@ Status AggregationNode::_get_with_serialized_key_result(RuntimeState* state, Blo
         std::visit(
                 [&](auto&& agg_method) -> void {
                     auto& data = agg_method.data;
-                    auto& iter = agg_method.iterator;
-                    agg_method.init_once();
                     const auto size = std::min(data.size(), size_t(state->batch_size()));
-                    using KeyType = std::decay_t<decltype(iter->get_first())>;
+                    using KeyType = std::decay_t<decltype(agg_method.iterator->get_first())>;
                     std::vector<KeyType> keys(size);
                     if (_values.size() < size) {
                         _values.resize(size);
                     }
                     size_t num_rows = 0;
-                    while (iter != data.end() && num_rows < state->batch_size()) {
-                        keys[num_rows] = iter->get_first();
-                        _values[num_rows] = iter->get_second();
-                        ++iter;
-                        ++num_rows;
+                    _aggregate_data_container->init_once();
+                    auto& iter = _aggregate_data_container->iterator;
+
+                    {
+                        SCOPED_TIMER(_hash_table_iterate_timer);
+                        while (iter != _aggregate_data_container->end() &&
+                               num_rows < state->batch_size()) {
+                            keys[num_rows] = iter.get_key<KeyType>();
+                            _values[num_rows] = iter.get_aggregate_data();
+                            ++iter;
+                            ++num_rows;
+                        }
                     }
-                    agg_method.insert_keys_into_columns(keys, key_columns, num_rows, _probe_key_sz);
+
+                    {
+                        SCOPED_TIMER(_insert_keys_to_column_timer);
+                        agg_method.insert_keys_into_columns(keys, key_columns, num_rows,
+                                                            _probe_key_sz);
+                    }

                     for (size_t i = 0; i < _aggregate_evaluators.size(); ++i) {
                         _aggregate_evaluators[i]->insert_result_info_vec(
@@ -1076,7 +1124,7 @@ Status AggregationNode::_get_with_serialized_key_result(RuntimeState* state, Blo
                                 num_rows);
                     }

-                    if (iter == data.end()) {
+                    if (iter == _aggregate_data_container->end()) {
                         if (agg_method.data.has_null_key_data()) {
                             // only one key of group by support wrap null key
                             // here need additional processing logic on the null key / value
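The read side: the cursor is now _aggregate_data_container->iterator, primed once by init_once(), so each call of this function drains one batch_size() batch and resumes where the last one stopped; the end() comparison above is what finally triggers the null-key handling and eos. A reduced sketch of that resumable batch pattern, with hypothetical names:

#include <cstdio>
#include <vector>

// Hypothetical sketch: the cursor lives on the producer, so each call emits
// the next batch and a final call observes the cursor at end.
struct BatchSource {
    std::vector<int> data; // stands in for the aggregate-data container
    size_t cursor = 0;
    bool inited = false;

    void init_once() { // idempotent, like init_once() above
        if (!inited) { cursor = 0; inited = true; }
    }
    size_t next_batch(int* out, size_t batch_size) {
        size_t n = 0;
        while (cursor < data.size() && n < batch_size) out[n++] = data[cursor++];
        return n;
    }
    bool exhausted() const { return cursor == data.size(); }
};

int main() {
    BatchSource src {{1, 2, 3, 4, 5}};
    src.init_once();
    int buf[2];
    while (!src.exhausted()) {
        size_t n = src.next_batch(buf, 2); // one output block per batch
        for (size_t i = 0; i < n; ++i) std::printf("%d ", buf[i]);
        std::printf("| end of batch\n");
    }
    return 0;
}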
@@ -1137,27 +1185,37 @@ Status AggregationNode::_serialize_with_serialized_key_result(RuntimeState* stat
                 [&](auto&& agg_method) -> void {
-                    agg_method.init_once();
                     auto& data = agg_method.data;
-                    auto& iter = agg_method.iterator;
                     const auto size = std::min(data.size(), size_t(state->batch_size()));
-                    using KeyType = std::decay_t<decltype(iter->get_first())>;
+                    using KeyType = std::decay_t<decltype(agg_method.iterator->get_first())>;
                     std::vector<KeyType> keys(size);
                     if (_values.size() < size + 1) {
                         _values.resize(size + 1);
                     }
                     size_t num_rows = 0;
-                    while (iter != data.end() && num_rows < state->batch_size()) {
-                        keys[num_rows] = iter->get_first();
-                        _values[num_rows] = iter->get_second();
-                        ++iter;
-                        ++num_rows;
+                    _aggregate_data_container->init_once();
+                    auto& iter = _aggregate_data_container->iterator;
+
+                    {
+                        SCOPED_TIMER(_hash_table_iterate_timer);
+                        while (iter != _aggregate_data_container->end() &&
+                               num_rows < state->batch_size()) {
+                            keys[num_rows] = iter.get_key<KeyType>();
+                            _values[num_rows] = iter.get_aggregate_data();
+                            ++iter;
+                            ++num_rows;
+                        }
                     }
-                    agg_method.insert_keys_into_columns(keys, key_columns, num_rows, _probe_key_sz);
+
+                    {
+                        SCOPED_TIMER(_insert_keys_to_column_timer);
+                        agg_method.insert_keys_into_columns(keys, key_columns, num_rows,
+                                                            _probe_key_sz);
+                    }

-                    if (iter == data.end()) {
+                    if (iter == _aggregate_data_container->end()) {
                         if (agg_method.data.has_null_key_data()) {
                             // only one key of group by support wrap null key
                             // here need additional processing logic on the null key / value
                             DCHECK(key_columns.size() == 1);
                             DCHECK(key_columns[0]->is_nullable());
                             if (agg_method.data.has_null_key_data()) {