diff --git a/be/src/exprs/runtime_filter.cpp b/be/src/exprs/runtime_filter.cpp
index b8700be923..1245ce370d 100644
--- a/be/src/exprs/runtime_filter.cpp
+++ b/be/src/exprs/runtime_filter.cpp
@@ -1057,6 +1057,8 @@ public:
         }
     }
 
+    friend class IRuntimeFilter;
+
 private:
     RuntimeState* _state;
     ObjectPool* _pool;
@@ -1082,6 +1084,37 @@ Status IRuntimeFilter::create(RuntimeState* state, ObjectPool* pool, const TRunt
     return (*res)->init_with_desc(desc, query_options, fragment_instance_id, node_id);
 }
 
+Status IRuntimeFilter::apply_from_other(IRuntimeFilter* other) {
+    auto copy_hybrid_set = [](HybridSetBase* src, HybridSetBase* dst) {
+        auto it = src->begin();
+        while (it->has_next()) {
+            dst->insert(it->get_value());
+            it->next();
+        }
+    };
+    switch (other->_wrapper->_filter_type) {
+    case RuntimeFilterType::IN_FILTER:
+        copy_hybrid_set(other->_wrapper->_hybrid_set.get(), _wrapper->_hybrid_set.get());
+        break;
+    case RuntimeFilterType::BLOOM_FILTER:
+        _wrapper->_bloomfilter_func->light_copy(other->_wrapper->get_bloomfilter());
+        break;
+    case RuntimeFilterType::MINMAX_FILTER:
+        *(_wrapper->_minmax_func) = *(other->_wrapper->_minmax_func);
+        break;
+    case RuntimeFilterType::IN_OR_BLOOM_FILTER:
+        copy_hybrid_set(other->_wrapper->_hybrid_set.get(), _wrapper->_hybrid_set.get());
+        _wrapper->_bloomfilter_func->light_copy(other->_wrapper->get_bloomfilter());
+        break;
+    default:
+        return Status::InvalidArgument("unknown filter type");
+        break;
+    }
+    _wrapper->_filter_type = other->_wrapper->_filter_type;
+    _runtime_filter_type = other->_runtime_filter_type;
+    return Status::OK();
+}
+
 void IRuntimeFilter::insert(const void* data) {
     DCHECK(is_producer());
     if (!_is_ignored) {
@@ -1108,7 +1141,7 @@ Status IRuntimeFilter::publish() {
         RETURN_IF_ERROR(
                 _state->runtime_filter_mgr()->get_consume_filter(_filter_id, &consumer_filter));
         // push down
-        std::swap(this->_wrapper, consumer_filter->_wrapper);
+        consumer_filter->_wrapper = _wrapper;
         consumer_filter->update_runtime_filter_type_to_profile();
         consumer_filter->signal();
         return Status::OK();
diff --git a/be/src/exprs/runtime_filter.h b/be/src/exprs/runtime_filter.h
index 14ed389531..b777cf80f7 100644
--- a/be/src/exprs/runtime_filter.h
+++ b/be/src/exprs/runtime_filter.h
@@ -141,6 +141,8 @@ public:
                          const TQueryOptions* query_options, const RuntimeFilterRole role,
                          int node_id, IRuntimeFilter** res);
 
+    Status apply_from_other(IRuntimeFilter* other);
+
     // insert data to build filter
     // only used for producer
     void insert(const void* data);
@@ -246,6 +248,8 @@ public:
         return be_exec_version > 0 && (is_int_or_bool(type) || is_float_or_double(type));
     }
 
+    int filter_id() const { return _filter_id; }
+
 protected:
     // serialize _wrapper to protobuf
     void to_protobuf(PInFilter* filter);
diff --git a/be/src/exprs/runtime_filter_slots.h b/be/src/exprs/runtime_filter_slots.h
index 43fca0a71f..b13fa69625 100644
--- a/be/src/exprs/runtime_filter_slots.h
+++ b/be/src/exprs/runtime_filter_slots.h
@@ -229,6 +229,24 @@ public:
         }
     }
 
+    Status apply_from_other(RuntimeFilterSlotsBase* other) {
+        for (auto& it : _runtime_filters) {
+            auto& other_filters = other->_runtime_filters[it.first];
+            for (auto& filter : it.second) {
+                auto filter_id = filter->filter_id();
+                auto ret = std::find_if(other_filters.begin(), other_filters.end(),
+                                        [&](IRuntimeFilter* other_filter) {
+                                            return other_filter->filter_id() == filter_id;
+                                        });
+                if (ret == other_filters.end()) {
+                    return Status::Aborted("invalid runtime filter id: {}", filter_id);
+                }
+                filter->apply_from_other(*ret);
+            }
+        }
+        return Status::OK();
+    }
+
     bool empty() { return !_runtime_filters.size(); }
 
 private:
diff --git a/be/src/runtime/fragment_mgr.cpp b/be/src/runtime/fragment_mgr.cpp
index 8085ddf13b..58b3b1d95b 100644
--- a/be/src/runtime/fragment_mgr.cpp
+++ b/be/src/runtime/fragment_mgr.cpp
@@ -240,7 +240,13 @@ Status FragmentExecState::execute() {
     }
 #ifndef BE_TEST
     if (_executor.runtime_state()->is_cancelled()) {
-        return Status::Cancelled("cancelled before execution");
+        Status status = Status::Cancelled("cancelled before execution");
+        _executor.runtime_state()
+                ->get_query_fragments_ctx()
+                ->get_shared_hash_table_controller()
+                ->release_ref_count_if_need(_executor.runtime_state()->fragment_instance_id(),
+                                            status);
+        return status;
     }
 #endif
     int64_t duration_ns = 0;
@@ -248,10 +254,18 @@ Status FragmentExecState::execute() {
     {
         SCOPED_RAW_TIMER(&duration_ns);
         CgroupsMgr::apply_system_cgroup();
         opentelemetry::trace::Tracer::GetCurrentSpan()->AddEvent("start executing Fragment");
-        WARN_IF_ERROR(_executor.open(), strings::Substitute("Got error while opening fragment $0",
-                                                            print_id(_fragment_instance_id)));
+        Status status = _executor.open();
+        WARN_IF_ERROR(status, strings::Substitute("Got error while opening fragment $0",
+                                                  print_id(_fragment_instance_id)));
         _executor.close();
+        if (!status.ok()) {
+            _executor.runtime_state()
+                    ->get_query_fragments_ctx()
+                    ->get_shared_hash_table_controller()
+                    ->release_ref_count_if_need(_executor.runtime_state()->fragment_instance_id(),
+                                                status);
+        }
     }
     DorisMetrics::instance()->fragment_requests_total->increment(1);
     DorisMetrics::instance()->fragment_request_duration_us->increment(duration_ns / 1000);
@@ -697,6 +711,9 @@ Status FragmentMgr::exec_plan_fragment(const TExecPlanFragmentParams& params, Fi
     }
     g_fragmentmgr_prepare_latency << (duration_ns / 1000);
 
+    _setup_shared_hashtable_for_broadcast_join(params, exec_state->executor()->runtime_state(),
+                                               fragments_ctx.get());
+
     std::shared_ptr<RuntimeFilterMergeControllerEntity> handler;
     _runtimefilter_controller.add_entity(params, &handler, exec_state->executor()->runtime_state());
     exec_state->set_merge_controller_handler(handler);
@@ -766,6 +783,26 @@ void FragmentMgr::_set_scan_concurrency(const TExecPlanFragmentParams& params,
 #endif
 }
 
+void FragmentMgr::_setup_shared_hashtable_for_broadcast_join(const TExecPlanFragmentParams& params,
+                                                             RuntimeState* state,
+                                                             QueryFragmentsCtx* fragments_ctx) {
+    if (!params.__isset.fragment || !params.fragment.__isset.plan ||
+        params.fragment.plan.nodes.empty()) {
+        return;
+    }
+
+    for (auto& node : params.fragment.plan.nodes) {
+        if (node.node_type != TPlanNodeType::HASH_JOIN_NODE ||
+            !node.hash_join_node.__isset.is_broadcast_join ||
+            !node.hash_join_node.is_broadcast_join) {
+            continue;
+        }
+
+        std::lock_guard lock(_lock_for_shared_hash_table);
+        fragments_ctx->get_shared_hash_table_controller()->acquire_ref_count(state, node.node_id);
+    }
+}
+
 bool FragmentMgr::_is_scan_node(const TPlanNodeType::type& type) {
     return type == TPlanNodeType::OLAP_SCAN_NODE || type == TPlanNodeType::MYSQL_SCAN_NODE ||
            type == TPlanNodeType::SCHEMA_SCAN_NODE || type == TPlanNodeType::META_SCAN_NODE ||
diff --git a/be/src/runtime/fragment_mgr.h b/be/src/runtime/fragment_mgr.h
index 08be2edc0c..7c6a079e32 100644
--- a/be/src/runtime/fragment_mgr.h
+++ b/be/src/runtime/fragment_mgr.h
@@ -107,6 +107,10 @@ private:
 
     bool _is_scan_node(const TPlanNodeType::type& type);
 
+    void _setup_shared_hashtable_for_broadcast_join(const TExecPlanFragmentParams& params,
+                                                    RuntimeState* state,
+                                                    QueryFragmentsCtx* fragments_ctx);
+
     // This is input params
     ExecEnv* _exec_env;
 
@@ -114,6 +118,9 @@ private:
     std::condition_variable _cv;
 
+    std::mutex _lock_for_shared_hash_table;
+    std::condition_variable _cv_for_sharing_hashtable;
+
     // Make sure that remove this before no data reference FragmentExecState
     std::unordered_map<TUniqueId, std::shared_ptr<FragmentExecState>> _fragment_map;
 
     // query id -> QueryFragmentsCtx
diff --git a/be/src/runtime/query_fragments_ctx.h b/be/src/runtime/query_fragments_ctx.h
index 1fc58f2f28..1f7c053db1 100644
--- a/be/src/runtime/query_fragments_ctx.h
+++ b/be/src/runtime/query_fragments_ctx.h
@@ -18,6 +18,7 @@
 #pragma once
 
 #include <atomic>
+#include <memory>
 #include <string>
 
 #include "common/config.h"
@@ -29,6 +30,7 @@
 #include "runtime/memory/mem_tracker_limiter.h"
 #include "util/pretty_printer.h"
 #include "util/threadpool.h"
+#include "vec/runtime/shared_hash_table_controller.h"
 
 namespace doris {
 
@@ -41,6 +43,7 @@ public:
     QueryFragmentsCtx(int total_fragment_num, ExecEnv* exec_env)
             : fragment_num(total_fragment_num), timeout_second(-1), _exec_env(exec_env) {
         _start_time = DateTimeValue::local_time();
+        _shared_hash_table_controller.reset(new vectorized::SharedHashTableController());
     }
 
     ~QueryFragmentsCtx() {
@@ -97,6 +100,10 @@ public:
         return _ready_to_execute.load() && !_is_cancelled.load();
     }
 
+    vectorized::SharedHashTableController* get_shared_hash_table_controller() {
+        return _shared_hash_table_controller.get();
+    }
+
 public:
     TUniqueId query_id;
     DescriptorTbl* desc_tbl;
@@ -136,6 +143,8 @@ private:
     // And all fragments of this query will start execution when this is set to true.
     std::atomic<bool> _ready_to_execute {false};
     std::atomic<bool> _is_cancelled {false};
+
+    std::unique_ptr<vectorized::SharedHashTableController> _shared_hash_table_controller;
 };
 
 } // namespace doris
diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt
index 61cd8e666d..9be4c6acbe 100644
--- a/be/src/vec/CMakeLists.txt
+++ b/be/src/vec/CMakeLists.txt
@@ -228,8 +228,9 @@ set(VEC_FILES
     runtime/vpartition_info.cpp
     runtime/vparquet_writer.cpp
     runtime/vorc_writer.cpp
-    utils/arrow_column_to_doris_column.cpp
     runtime/vsorted_run_merger.cpp
+    runtime/shared_hash_table_controller.cpp
+    utils/arrow_column_to_doris_column.cpp
     exec/format/parquet/vparquet_column_chunk_reader.cpp
     exec/format/parquet/vparquet_group_reader.cpp
     exec/format/parquet/vparquet_page_index.cpp
diff --git a/be/src/vec/exec/join/vhash_join_node.cpp b/be/src/vec/exec/join/vhash_join_node.cpp
index 8c0d5b3f7a..f6a8db2e2f 100644
--- a/be/src/vec/exec/join/vhash_join_node.cpp
+++ b/be/src/vec/exec/join/vhash_join_node.cpp
@@ -188,21 +188,22 @@ struct ProcessRuntimeFilterBuild {
         if (_join_node->_runtime_filter_descs.empty()) {
             return Status::OK();
         }
-        VRuntimeFilterSlots runtime_filter_slots(_join_node->_probe_expr_ctxs,
-                                                 _join_node->_build_expr_ctxs,
-                                                 _join_node->_runtime_filter_descs);
+        _join_node->_runtime_filter_slots = _join_node->_pool->add(
+                new VRuntimeFilterSlots(_join_node->_probe_expr_ctxs, _join_node->_build_expr_ctxs,
+                                        _join_node->_runtime_filter_descs));
 
-        RETURN_IF_ERROR(runtime_filter_slots.init(state, hash_table_ctx.hash_table.get_size()));
+        RETURN_IF_ERROR(_join_node->_runtime_filter_slots->init(
+                state, hash_table_ctx.hash_table_ptr->get_size()));
 
-        if (!runtime_filter_slots.empty() && !_join_node->_inserted_rows.empty()) {
+        if (!_join_node->_runtime_filter_slots->empty() && !_join_node->_inserted_rows.empty()) {
             {
                 SCOPED_TIMER(_join_node->_push_compute_timer);
-                runtime_filter_slots.insert(_join_node->_inserted_rows);
+                _join_node->_runtime_filter_slots->insert(_join_node->_inserted_rows);
             }
         }
         {
             SCOPED_TIMER(_join_node->_push_down_timer);
-            runtime_filter_slots.publish();
+            _join_node->_runtime_filter_slots->publish();
         }
 
         return Status::OK();
@@ -395,14 +396,15 @@ Status ProcessHashTableProbe::do_process(HashTableType& hash_table_c
         }
         int last_offset = current_offset;
         auto find_result =
-                !need_null_map_for_probe
-                        ? key_getter.find_key(hash_table_ctx.hash_table, probe_index, _arena)
+                !need_null_map_for_probe ? key_getter.find_key(*hash_table_ctx.hash_table_ptr,
+                                                               probe_index, _arena)
                 : (*null_map)[probe_index]
-                        ? decltype(key_getter.find_key(hash_table_ctx.hash_table, probe_index,
-                                                       _arena)) {nullptr, false}
-                        : key_getter.find_key(hash_table_ctx.hash_table, probe_index, _arena);
+                        ? decltype(key_getter.find_key(*hash_table_ctx.hash_table_ptr,
+                                                       probe_index, _arena)) {nullptr, false}
+                        : key_getter.find_key(*hash_table_ctx.hash_table_ptr, probe_index,
+                                              _arena);
         if (probe_index + PREFETCH_STEP < probe_rows)
-            key_getter.template prefetch(hash_table_ctx.hash_table,
+            key_getter.template prefetch(*hash_table_ctx.hash_table_ptr,
                                          probe_index + PREFETCH_STEP, _arena);
 
         if constexpr (JoinOpType::value == TJoinOp::LEFT_ANTI_JOIN ||
@@ -559,14 +561,15 @@ Status ProcessHashTableProbe::do_process_with_other_join_conjuncts(
         auto last_offset = current_offset;
         auto find_result =
-                !need_null_map_for_probe
-                        ? key_getter.find_key(hash_table_ctx.hash_table, probe_index, _arena)
+                !need_null_map_for_probe ? key_getter.find_key(*hash_table_ctx.hash_table_ptr,
+                                                               probe_index, _arena)
                 : (*null_map)[probe_index]
-                        ? decltype(key_getter.find_key(hash_table_ctx.hash_table, probe_index,
-                                                       _arena)) {nullptr, false}
-                        : key_getter.find_key(hash_table_ctx.hash_table, probe_index, _arena);
+                        ? decltype(key_getter.find_key(*hash_table_ctx.hash_table_ptr,
+                                                       probe_index, _arena)) {nullptr, false}
+                        : key_getter.find_key(*hash_table_ctx.hash_table_ptr, probe_index,
+                                              _arena);
         if (probe_index + PREFETCH_STEP < probe_rows)
-            key_getter.template prefetch(hash_table_ctx.hash_table,
+            key_getter.template prefetch(*hash_table_ctx.hash_table_ptr,
                                          probe_index + PREFETCH_STEP, _arena);
         if (find_result.is_found()) {
             auto& mapped = find_result.get_mapped();
@@ -810,7 +813,7 @@ Status ProcessHashTableProbe::process_data_in_hashtable(HashTableTyp
         }
     };
 
-    for (; iter != hash_table_ctx.hash_table.end() && block_size < _batch_size; ++iter) {
+    for (; iter != hash_table_ctx.hash_table_ptr->end() && block_size < _batch_size; ++iter) {
         auto& mapped = iter->get_second();
         if constexpr (std::is_same_v) {
             if (mapped.visited) {
@@ -853,7 +856,7 @@ Status ProcessHashTableProbe::process_data_in_hashtable(HashTableTyp
             }
             _tuple_is_null_left_flags->resize_fill(block_size, 1);
         }
-        *eos = iter == hash_table_ctx.hash_table.end();
+        *eos = iter == hash_table_ctx.hash_table_ptr->end();
         output_block->swap(
                 mutable_block.to_block(right_semi_anti_without_other ? right_col_idx : 0));
         return Status::OK();
@@ -879,6 +882,8 @@ HashJoinNode::HashJoinNode(ObjectPool* pool, const TPlanNode& tnode, const Descr
           _is_right_semi_anti(_join_op == TJoinOp::RIGHT_ANTI_JOIN ||
                               _join_op == TJoinOp::RIGHT_SEMI_JOIN),
           _is_outer_join(_match_all_build || _match_all_probe),
+          _is_broadcast_join(tnode.hash_join_node.__isset.is_broadcast_join &&
+                             tnode.hash_join_node.is_broadcast_join),
          _hash_output_slot_ids(tnode.hash_join_node.__isset.hash_output_slot_ids
                                        ? tnode.hash_join_node.hash_output_slot_ids
                                        : std::vector<SlotId> {}),
@@ -1052,6 +1057,10 @@ Status HashJoinNode::prepare(RuntimeState* state) {
     _push_compute_timer = ADD_TIMER(runtime_profile(), "PushDownComputeTime");
     _build_buckets_counter = ADD_COUNTER(runtime_profile(), "BuildBuckets", TUnit::UNIT);
 
+    if (_is_broadcast_join) {
+        runtime_profile()->add_info_string("BroadcastJoin", "true");
+    }
+
     RETURN_IF_ERROR(VExpr::prepare(_build_expr_ctxs, state, child(1)->row_desc()));
     RETURN_IF_ERROR(VExpr::prepare(_probe_expr_ctxs, state, child(0)->row_desc()));
@@ -1079,6 +1088,11 @@ Status HashJoinNode::close(RuntimeState* state) {
         return Status::OK();
     }
 
+    if (_shared_hashtable_controller) {
+        _shared_hashtable_controller->release_ref_count(state, id());
+        _shared_hashtable_controller->wait_for_closable(state, id());
+    }
+
     START_AND_SCOPE_SPAN(state->get_tracer(), span, "HashJoinNode::close");
     VExpr::close(_build_expr_ctxs, state);
     VExpr::close(_probe_expr_ctxs, state);
@@ -1343,6 +1357,14 @@ Status HashJoinNode::_hash_table_build(RuntimeState* state) {
     // make one block for each 4 gigabytes
     constexpr static auto BUILD_BLOCK_MAX_SIZE = 4 * 1024UL * 1024UL * 1024UL;
 
+    auto should_build_hash_table = true;
+    if (_is_broadcast_join) {
+        _shared_hashtable_controller =
+                state->get_query_fragments_ctx()->get_shared_hash_table_controller();
+        should_build_hash_table =
+                _shared_hashtable_controller->should_build_hash_table(state, id());
+    }
+
     Block block;
     // If eos or have already met a null value using short-circuit strategy, we do not need to pull
     // data from data.
@@ -1352,6 +1374,10 @@ Status HashJoinNode::_hash_table_build(RuntimeState* state) {
         RETURN_IF_ERROR_AND_CHECK_SPAN(child(1)->get_next_after_projects(state, &block, &eos),
                                        child(1)->get_next_span(), eos);
 
+        if (!should_build_hash_table) {
+            continue;
+        }
+
         _mem_used += block.allocated_bytes();
 
         if (block.rows() != 0) {
@@ -1376,7 +1402,8 @@ Status HashJoinNode::_hash_table_build(RuntimeState* state) {
         }
     }
 
-    if (!mutable_block.empty() && !_short_circuit_for_null_in_probe_side) {
+    if (should_build_hash_table && !mutable_block.empty() &&
+        !_short_circuit_for_null_in_probe_side) {
         if (_build_blocks.size() == _MAX_BUILD_BLOCK_COUNT) {
             return Status::NotSupported(
                     strings::Substitute("data size of right table in hash join > $0",
@@ -1390,8 +1417,37 @@
             [&](auto&& arg) -> Status {
                 using HashTableCtxType = std::decay_t<decltype(arg)>;
                 if constexpr (!std::is_same_v<HashTableCtxType, std::monostate>) {
-                    ProcessRuntimeFilterBuild<HashTableCtxType> runtime_filter_build_process(this);
-                    return runtime_filter_build_process(state, arg);
+                    using HashTableType = typename HashTableCtxType::HashTable;
+                    if (!should_build_hash_table) {
+                        auto& ret = _shared_hashtable_controller->wait_for_hash_table(id());
+                        if (!ret.status.ok()) {
+                            return ret.status;
+                        }
+                        arg.hash_table_ptr = reinterpret_cast<HashTableType*>(ret.hash_table_ptr);
+                        _build_blocks = *ret.blocks;
+                        _runtime_filter_slots = _pool->add(new VRuntimeFilterSlots(
+                                _probe_expr_ctxs, _build_expr_ctxs, _runtime_filter_descs));
+                        RETURN_IF_ERROR(
+                                _runtime_filter_slots->init(state, arg.hash_table_ptr->get_size()));
+                        RETURN_IF_ERROR(
+                                _runtime_filter_slots->apply_from_other(ret.runtime_filter_slots));
+                        {
+                            SCOPED_TIMER(_push_down_timer);
+                            _runtime_filter_slots->publish();
+                        }
+                        return Status::OK();
+                    } else {
+                        arg.hash_table_ptr = &arg.hash_table;
+                        ProcessRuntimeFilterBuild<HashTableCtxType> runtime_filter_build_process(
+                                this);
+                        auto ret = runtime_filter_build_process(state, arg);
+                        if (_shared_hashtable_controller) {
+                            SharedHashTableEntry entry(ret, arg.hash_table_ptr, &_build_blocks,
+                                                       _runtime_filter_slots);
+                            _shared_hashtable_controller->put_hash_table(std::move(entry), id());
+                        }
+                        return ret;
+                    }
                 } else {
                     LOG(FATAL) << "FATAL: uninited hash table";
                 }
diff --git a/be/src/vec/exec/join/vhash_join_node.h b/be/src/vec/exec/join/vhash_join_node.h
index 977edcb29c..277795a4a7 100644
--- a/be/src/vec/exec/join/vhash_join_node.h
+++ b/be/src/vec/exec/join/vhash_join_node.h
@@ -28,6 +28,7 @@
 #include "vec/exec/join/join_op.h"
 #include "vec/exec/join/vacquire_list.hpp"
 #include "vec/functions/function.h"
+#include "vec/runtime/shared_hash_table_controller.h"
 
 namespace doris {
 namespace vectorized {
@@ -40,6 +41,7 @@ struct SerializedHashTableContext {
     using Iter = typename HashTable::iterator;
 
     HashTable hash_table;
+    HashTable* hash_table_ptr = &hash_table;
     Iter iter;
     bool inited = false;
 
@@ -71,6 +73,7 @@ struct PrimaryTypeHashTableContext {
     using Iter = typename HashTable::iterator;
 
     HashTable hash_table;
+    HashTable* hash_table_ptr = &hash_table;
     Iter iter;
     bool inited = false;
 
@@ -105,6 +108,7 @@ struct FixedKeyHashTableContext {
     using Iter = typename HashTable::iterator;
 
     HashTable hash_table;
+    HashTable* hash_table_ptr = &hash_table;
     Iter iter;
     bool inited = false;
 
@@ -353,6 +357,9 @@ private:
     // 3. In probe phase, if _short_circuit_for_null_in_probe_side is true, join node returns empty block directly. Otherwise, probing will continue as the same as generic left anti join.
     bool _short_circuit_for_null_in_build_side = false;
     bool _short_circuit_for_null_in_probe_side = false;
+    bool _is_broadcast_join = false;
+    SharedHashTableController* _shared_hashtable_controller = nullptr;
+    VRuntimeFilterSlots* _runtime_filter_slots;
 
     Block _join_block;
diff --git a/be/src/vec/runtime/shared_hash_table_controller.cpp b/be/src/vec/runtime/shared_hash_table_controller.cpp
new file mode 100644
index 0000000000..8ca3656ad4
--- /dev/null
+++ b/be/src/vec/runtime/shared_hash_table_controller.cpp
@@ -0,0 +1,132 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+ +#include "shared_hash_table_controller.h" + +#include + +namespace doris { +namespace vectorized { + +bool SharedHashTableController::should_build_hash_table(RuntimeState* state, int my_node_id) { + std::lock_guard lock(_mutex); + auto it = _builder_fragment_ids.find(my_node_id); + if (it == _builder_fragment_ids.cend()) { + _builder_fragment_ids[my_node_id] = state->fragment_instance_id(); + return true; + } + return false; +} + +bool SharedHashTableController::supposed_to_build_hash_table(RuntimeState* state, int my_node_id) { + std::lock_guard lock(_mutex); + auto it = _builder_fragment_ids.find(my_node_id); + if (it != _builder_fragment_ids.cend()) { + return _builder_fragment_ids[my_node_id] == state->fragment_instance_id(); + } + return false; +} + +void SharedHashTableController::put_hash_table(SharedHashTableEntry&& entry, int my_node_id) { + std::lock_guard lock(_mutex); + DCHECK(_hash_table_entries.find(my_node_id) == _hash_table_entries.cend()); + _hash_table_entries.insert({my_node_id, std::move(entry)}); + _cv.notify_all(); +} + +SharedHashTableEntry& SharedHashTableController::wait_for_hash_table(int my_node_id) { + std::unique_lock lock(_mutex); + auto it = _hash_table_entries.find(my_node_id); + if (it == _hash_table_entries.cend()) { + _cv.wait(lock, [this, &it, my_node_id]() { + it = _hash_table_entries.find(my_node_id); + return it != _hash_table_entries.cend(); + }); + } + return it->second; +} + +void SharedHashTableController::acquire_ref_count(RuntimeState* state, int my_node_id) { + std::unique_lock lock(_mutex); + _ref_fragments[my_node_id].emplace_back(state->fragment_instance_id()); +} + +Status SharedHashTableController::release_ref_count(RuntimeState* state, int my_node_id) { + std::unique_lock lock(_mutex); + auto id = state->fragment_instance_id(); + auto it = std::find(_ref_fragments[my_node_id].begin(), _ref_fragments[my_node_id].end(), id); + CHECK(it != _ref_fragments[my_node_id].end()); + _ref_fragments[my_node_id].erase(it); + _put_an_empty_entry_if_need(Status::Cancelled("hash table not build"), id, my_node_id); + _cv.notify_all(); + return Status::OK(); +} + +void SharedHashTableController::_put_an_empty_entry_if_need(Status status, TUniqueId fragment_id, + int node_id) { + auto builder_it = _builder_fragment_ids.find(node_id); + if (builder_it != _builder_fragment_ids.end()) { + if (builder_it->second == fragment_id) { + if (_hash_table_entries.find(builder_it->first) == _hash_table_entries.cend()) { + // "here put an empty SharedHashTableEntry to avoid deadlocking" + _hash_table_entries.insert( + {builder_it->first, SharedHashTableEntry::empty_entry_with_status(status)}); + } + } + } +} + +Status SharedHashTableController::release_ref_count_if_need(TUniqueId fragment_id, Status status) { + std::unique_lock lock(_mutex); + bool need_to_notify = false; + for (auto& ref : _ref_fragments) { + auto it = std::find(ref.second.begin(), ref.second.end(), fragment_id); + if (it == ref.second.end()) { + continue; + } + ref.second.erase(it); + need_to_notify = true; + LOG(INFO) << "release_ref_count in node: " << ref.first + << " for fragment id: " << fragment_id; + } + + for (auto& builder : _builder_fragment_ids) { + if (builder.second == fragment_id) { + if (_hash_table_entries.find(builder.first) == _hash_table_entries.cend()) { + _hash_table_entries.insert( + {builder.first, SharedHashTableEntry::empty_entry_with_status(status)}); + } + } + } + + if (need_to_notify) { + _cv.notify_all(); + } + return Status::OK(); +} + +Status 
+Status SharedHashTableController::wait_for_closable(RuntimeState* state, int my_node_id) {
+    std::unique_lock lock(_mutex);
+    RETURN_IF_CANCELLED(state);
+    if (!_ref_fragments[my_node_id].empty()) {
+        _cv.wait(lock, [&]() { return _ref_fragments[my_node_id].empty(); });
+    }
+    return Status::OK();
+}
+
+} // namespace vectorized
+} // namespace doris
\ No newline at end of file
diff --git a/be/src/vec/runtime/shared_hash_table_controller.h b/be/src/vec/runtime/shared_hash_table_controller.h
new file mode 100644
index 0000000000..8f45be3bec
--- /dev/null
+++ b/be/src/vec/runtime/shared_hash_table_controller.h
@@ -0,0 +1,89 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <condition_variable>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <vector>
+
+#include "vec/core/block.h"
+
+namespace doris {
+
+class RuntimeState;
+class TUniqueId;
+
+template <typename ExprCtxType>
+class RuntimeFilterSlotsBase;
+
+namespace vectorized {
+
+class VExprContext;
+
+struct SharedHashTableEntry {
+    SharedHashTableEntry(Status status_, void* hash_table_ptr_, std::vector<Block>* blocks_,
+                         RuntimeFilterSlotsBase<VExprContext>* runtime_filter_slots_)
+            : status(status_),
+              hash_table_ptr(hash_table_ptr_),
+              blocks(blocks_),
+              runtime_filter_slots(runtime_filter_slots_) {}
+    SharedHashTableEntry(SharedHashTableEntry&& entry)
+            : status(entry.status),
+              hash_table_ptr(entry.hash_table_ptr),
+              blocks(entry.blocks),
+              runtime_filter_slots(entry.runtime_filter_slots) {}
+
+    static SharedHashTableEntry empty_entry_with_status(const Status& status) {
+        return SharedHashTableEntry(status, nullptr, nullptr, nullptr);
+    }
+
+    Status status;
+    void* hash_table_ptr;
+    std::vector<Block>* blocks;
+    RuntimeFilterSlotsBase<VExprContext>* runtime_filter_slots;
+};
+
+class SharedHashTableController {
+public:
+    bool should_build_hash_table(RuntimeState* state, int my_node_id);
+    bool supposed_to_build_hash_table(RuntimeState* state, int my_node_id);
+    void acquire_ref_count(RuntimeState* state, int my_node_id);
+    SharedHashTableEntry& wait_for_hash_table(int my_node_id);
+    Status release_ref_count(RuntimeState* state, int my_node_id);
+    Status release_ref_count_if_need(TUniqueId fragment_id, Status status);
+    void put_hash_table(SharedHashTableEntry&& entry, int my_node_id);
+    Status wait_for_closable(RuntimeState* state, int my_node_id);
+
+private:
+    // If the fragment instance was supposed to build the hash table but did not,
+    // put an empty SharedHashTableEntry with a cancelled status here
+    // to avoid deadlocking the other fragment instances.
+    void _put_an_empty_entry_if_need(Status status, TUniqueId fragment_id, int node_id);
+
+private:
+    std::mutex _mutex;
+    std::condition_variable _cv;
+    std::map<int, TUniqueId> _builder_fragment_ids;
+    std::map<int, SharedHashTableEntry> _hash_table_entries;
+    std::map<int, std::vector<TUniqueId>> _ref_fragments;
+};
+
+} // namespace vectorized
+} // namespace doris
\ No newline at end of file
for fragment id: " << fragment_id; + } + if (need_to_notify) _cv.notify_all(); + return Status::OK(); +} + +Status SharedHashTableController::wait_for_closable(RuntimeState* state, int my_node_id) { + std::unique_lock lock(_mutex); + RETURN_IF_CANCELLED(state); + if (!_ref_fragments[my_node_id].empty()) { + _cv.wait(lock, [&]() { return _ref_fragments[my_node_id].empty(); }); + } + return Status::OK(); +} + +} // namespace vectorized +} // namespace doris \ No newline at end of file diff --git a/be/src/vec/runtime/shared_hashtable_controller.h b/be/src/vec/runtime/shared_hashtable_controller.h new file mode 100644 index 0000000000..842dc89e90 --- /dev/null +++ b/be/src/vec/runtime/shared_hashtable_controller.h @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "vec/core/block.h" + +namespace doris { + +class RuntimeState; +class TUniqueId; + +namespace vectorized { + +class VExprContext; + +struct SharedHashTableEntry { + SharedHashTableEntry(void* hash_table_ptr_, std::vector& blocks_, + std::unordered_map>& inserted_rows_, + const std::vector& exprs) + : hash_table_ptr(hash_table_ptr_), + blocks(blocks_), + inserted_rows(inserted_rows_), + build_exprs(exprs) {} + SharedHashTableEntry(SharedHashTableEntry&& entry) + : hash_table_ptr(entry.hash_table_ptr), + blocks(entry.blocks), + inserted_rows(entry.inserted_rows), + build_exprs(entry.build_exprs) {} + void* hash_table_ptr; + std::vector& blocks; + std::unordered_map>& inserted_rows; + std::vector build_exprs; +}; + +class SharedHashTableController { +public: + bool should_build_hash_table(RuntimeState* state, int my_node_id); + void acquire_ref_count(RuntimeState* state, int my_node_id); + SharedHashTableEntry& wait_for_hash_table(int my_node_id); + Status release_ref_count(RuntimeState* state, int my_node_id); + Status release_ref_count_if_need(TUniqueId fragment_id); + void put_hash_table(SharedHashTableEntry&& entry, int my_node_id); + Status wait_for_closable(RuntimeState* state, int my_node_id); + +private: + std::mutex _mutex; + std::condition_variable _cv; + std::map _builder_fragment_ids; + std::map _hash_table_entries; + std::map> _ref_fragments; +}; + +} // namespace vectorized +} // namespace doris \ No newline at end of file diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/HashJoinNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/HashJoinNode.java index 23aa12eec5..d04c5b28c1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/HashJoinNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/HashJoinNode.java @@ -1068,6 +1068,7 @@ public class HashJoinNode extends PlanNode { msg.node_type = TPlanNodeType.HASH_JOIN_NODE; msg.hash_join_node = 
         msg.hash_join_node.join_op = joinOp.toThrift();
+        msg.hash_join_node.setIsBroadcastJoin(distrMode == DistributionMode.BROADCAST);
         for (BinaryPredicate eqJoinPredicate : eqJoinConjuncts) {
             TEqJoinCondition eqJoinCondition =
                     new TEqJoinCondition(eqJoinPredicate.getChild(0).treeToThrift(),
                             eqJoinPredicate.getChild(1).treeToThrift());
diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift
index f8566c2b5d..7e41a54d6f 100644
--- a/gensrc/thrift/PlanNodes.thrift
+++ b/gensrc/thrift/PlanNodes.thrift
@@ -579,6 +579,8 @@ struct THashJoinNode {
   8: optional Types.TTupleId voutput_tuple_id
 
   9: optional list<Types.TTupleId> vintermediate_tuple_id_list
+
+  10: optional bool is_broadcast_join;
 }
 
 struct TMergeJoinNode {
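Note on the coordination pattern: the hand-off implemented by SharedHashTableController — the first fragment instance that asks for a join node id becomes the builder, publishes the finished table (or an error entry), and every other instance blocks on a condition variable until that entry appears — can be reduced to the standalone sketch below. It is illustrative only: the names (ToyController, ToyEntry, fake_hash_table, run_instance) are invented for this example and are not part of the patch, and the real controller additionally reference-counts consumer fragments (acquire_ref_count / release_ref_count / wait_for_closable) so the builder does not tear the table down while other instances are still probing, and publishes an error entry on cancellation so waiters cannot deadlock.

// Minimal sketch of the builder/waiter hand-off; not the Doris API.
#include <condition_variable>
#include <iostream>
#include <map>
#include <mutex>
#include <thread>
#include <vector>

struct ToyEntry {
    const std::vector<int>* table = nullptr; // stands in for the shared hash table
};

class ToyController {
public:
    // The first caller for a node id becomes the builder; everyone else waits.
    bool should_build(int node_id, int instance_id) {
        std::lock_guard<std::mutex> lock(_mutex);
        return _builders.emplace(node_id, instance_id).second;
    }

    // Builder publishes the finished table and wakes all waiters.
    void publish(int node_id, ToyEntry entry) {
        std::lock_guard<std::mutex> lock(_mutex);
        _entries[node_id] = entry;
        _cv.notify_all();
    }

    // Waiters block until an entry for the node id exists.
    ToyEntry wait_for(int node_id) {
        std::unique_lock<std::mutex> lock(_mutex);
        _cv.wait(lock, [&] { return _entries.count(node_id) > 0; });
        return _entries[node_id];
    }

private:
    std::mutex _mutex;
    std::condition_variable _cv;
    std::map<int, int> _builders;     // node id -> builder instance id
    std::map<int, ToyEntry> _entries; // node id -> published table
};

int main() {
    ToyController controller;
    std::vector<int> fake_hash_table; // built once, then read by every instance
    const int node_id = 0;

    auto run_instance = [&](int instance_id) {
        if (controller.should_build(node_id, instance_id)) {
            fake_hash_table = {1, 2, 3}; // pretend this is the expensive build phase
            controller.publish(node_id, ToyEntry {&fake_hash_table});
        }
        ToyEntry entry = controller.wait_for(node_id);
        std::cout << "instance " << instance_id << " sees " << entry.table->size()
                  << " rows\n";
    };

    std::thread t1(run_instance, 1), t2(run_instance, 2), t3(run_instance, 3);
    t1.join();
    t2.join();
    t3.join();
    return 0;
}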