335 lines
11 KiB
C++
335 lines
11 KiB
C++
// Licensed to the Apache Software Foundation (ASF) under one
|
|
// or more contributor license agreements. See the NOTICE file
|
|
// distributed with this work for additional information
|
|
// regarding copyright ownership. The ASF licenses this file
|
|
// to you under the Apache License, Version 2.0 (the
|
|
// "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing,
|
|
// software distributed under the License is distributed on an
|
|
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
// KIND, either express or implied. See the License for the
|
|
// specific language governing permissions and limitations
|
|
// under the License.
|
|
|
|
#include "exec/merge_join_node.h"
|
|
|
|
#include <sstream>
|
|
|
|
#include "exprs/expr.h"
|
|
#include "exprs/expr_context.h"
|
|
#include "exprs/in_predicate.h"
|
|
#include "gen_cpp/PlanNodes_types.h"
|
|
#include "runtime/row_batch.h"
|
|
#include "runtime/runtime_state.h"
|
|
#include "util/debug_util.h"
|
|
#include "util/runtime_profile.h"
|
|
|
|
namespace doris {
|
|
|
|
// Three-way comparison of two values of type T that are passed as untyped pointers.
//
// Returns -1 when *left_value < *right_value, 0 when the two compare equal with
// operator==, and 1 in every other case (including values for which neither
// operator< nor operator== holds).
// Both pointers must reference valid objects of type T; they are dereferenced
// unconditionally.
template <class T>
int compare_value(const void* left_value, const void* right_value) {
    const T& lhs = *static_cast<const T*>(left_value);
    const T& rhs = *static_cast<const T*>(right_value);

    if (lhs < rhs) {
        return -1;
    }
    if (lhs == rhs) {
        return 0;
    }
    return 1;
}
|
|
|
|
template <class T>
|
|
int compare_value(const StringValue* left_value, const StringValue* right_value) {
|
|
return left_value->compare(*right_value);
|
|
}
|
|
|
|
// Constructs the node by forwarding the plan description to ExecNode.
// _out_batch starts out as NULL; it is pointed at the caller's batch in get_next().
MergeJoinNode::MergeJoinNode(ObjectPool* pool, const TPlanNode& tnode, const DescriptorTbl& descs)
        : ExecNode(pool, tnode, descs),
          _out_batch(NULL) {
}
|
|
|
|
// Nothing is released explicitly here; expression contexts are closed in close().
MergeJoinNode::~MergeJoinNode() {}
|
|
|
|
// Builds the expression trees described by the thrift plan node.
//
// For every entry of merge_join_node.cmp_conjuncts one context is created for the
// left expression and one for the right expression; they are appended to
// _left_expr_ctxs / _right_expr_ctxs in the same order, so index i of both vectors
// refers to the same key pair. The remaining join conjuncts are created into
// _other_join_conjunct_ctxs.
// Returns the first error reported by ExecNode::init() or by expression creation.
Status MergeJoinNode::init(const TPlanNode& tnode, RuntimeState* state) {
    DCHECK(tnode.__isset.merge_join_node);
    RETURN_IF_ERROR(ExecNode::init(tnode, state));

    typedef std::vector<TEqJoinCondition> KeyConditions;
    const KeyConditions& key_conditions = tnode.merge_join_node.cmp_conjuncts;

    for (KeyConditions::const_iterator cond = key_conditions.begin();
         cond != key_conditions.end(); ++cond) {
        ExprContext* left_ctx = NULL;
        RETURN_IF_ERROR(Expr::create_expr_tree(_pool, cond->left, &left_ctx));
        _left_expr_ctxs.push_back(left_ctx);

        ExprContext* right_ctx = NULL;
        RETURN_IF_ERROR(Expr::create_expr_tree(_pool, cond->right, &right_ctx));
        _right_expr_ctxs.push_back(right_ctx);
    }

    RETURN_IF_ERROR(Expr::create_expr_trees(_pool, tnode.merge_join_node.other_join_conjuncts,
                                            &_other_join_conjunct_ctxs));
    return Status::OK();
}
|
|
|
|
// Prepares the node for execution:
//   - prepares the key expressions of each side against the row layout of the
//     child that produces them,
//   - selects one three-way comparator per key pair, based on the type of the
//     left key expression,
//   - prepares the remaining join conjuncts against this node's output row layout,
//   - caches the tuple counts of both children and the output positions of the
//     right child's tuples (used by create_output_row()),
//   - creates the initial reader contexts for both children.
//
// Returns InternalError("unsupported compare type.") when a key expression has a
// type without a comparator, or the first error from the base class / Expr::prepare.
Status MergeJoinNode::prepare(RuntimeState* state) {
    RETURN_IF_ERROR(ExecNode::prepare(state));

    // Left keys are evaluated over rows of child(0), right keys over rows of child(1).
    RETURN_IF_ERROR(
            Expr::prepare(_left_expr_ctxs, state, child(0)->row_desc(), expr_mem_tracker()));
    RETURN_IF_ERROR(
            Expr::prepare(_right_expr_ctxs, state, child(1)->row_desc(), expr_mem_tracker()));

    // One comparator per key pair, stored at the same index as the expressions.
    // NOTE(review): only the left expression's type is inspected; the right side is
    // presumably guaranteed by the planner to have the same type -- confirm.
    for (size_t i = 0; i < _left_expr_ctxs.size(); ++i) {
        switch (_left_expr_ctxs[i]->root()->type().type) {
        case TYPE_TINYINT:
            _cmp_func.push_back(compare_value<int8_t>);
            break;

        case TYPE_SMALLINT:
            _cmp_func.push_back(compare_value<int16_t>);
            break;

        case TYPE_INT:
            _cmp_func.push_back(compare_value<int32_t>);
            break;

        case TYPE_BIGINT:
            _cmp_func.push_back(compare_value<int64_t>);
            break;

        case TYPE_LARGEINT:
            _cmp_func.push_back(compare_value<__int128>);
            break;

        case TYPE_CHAR:
        case TYPE_VARCHAR:
            _cmp_func.push_back(compare_value<StringValue>);
            break;

        default:
            return Status::InternalError("unsupported compare type.");
        }
    }

    // _other_join_conjuncts are evaluated in the context of the rows produced by this node
    RETURN_IF_ERROR(
            Expr::prepare(_other_join_conjunct_ctxs, state, _row_descriptor, expr_mem_tracker()));

    // Size in bytes of one output row: one Tuple* per tuple descriptor.
    _result_tuple_row_size = _row_descriptor.tuple_descriptors().size() * sizeof(Tuple*);

    // Number of tuples (not bytes) in a row of each child.
    _left_tuple_size = child(0)->row_desc().tuple_descriptors().size();
    _right_tuple_size = child(1)->row_desc().tuple_descriptors().size();

    // Pre-compute where each tuple of the right child lands in the output row.
    _right_tuple_idx.reserve(_right_tuple_size);
    for (int i = 0; i < _right_tuple_size; ++i) {
        TupleDescriptor* right_tuple_desc = child(1)->row_desc().tuple_descriptors()[i];
        _right_tuple_idx.push_back(_row_descriptor.get_tuple_idx(right_tuple_desc->id()));
    }

    // Each reader context buffers rows of one child, so it is built with that child's
    // row descriptor -- the same descriptor get_input_row() uses when it replaces the
    // context for a fresh batch.
    _left_child_ctx.reset(new ChildReaderContext(child(0)->row_desc(), state->batch_size(),
                                                 state->instance_mem_tracker()));
    _right_child_ctx.reset(new ChildReaderContext(child(1)->row_desc(), state->batch_size(),
                                                  state->instance_mem_tracker()));

    return Status::OK();
}
|
|
|
|
// Closes the node. A node that is already closed returns OK without doing anything.
// Otherwise the CLOSE debug action runs first (its error, if any, is returned
// before anything is closed), then the three groups of expression contexts are
// closed and finally the base class is closed; its status is the result.
Status MergeJoinNode::close(RuntimeState* state) {
    if (!is_closed()) {
        RETURN_IF_ERROR(exec_debug_action(TExecNodePhase::CLOSE));

        Expr::close(_left_expr_ctxs, state);
        Expr::close(_right_expr_ctxs, state);
        Expr::close(_other_join_conjunct_ctxs, state);

        return ExecNode::close(state);
    }
    return Status::OK();
}
|
|
|
|
// Opens the node: opens the base class and all expression contexts, opens both
// children (left first, then right) and fetches the first input row of each side
// so that get_next_row() has a current row to compare.
// The first failing step aborts the sequence and its status is returned.
Status MergeJoinNode::open(RuntimeState* state) {
    SCOPED_TIMER(_runtime_profile->total_time_counter());
    RETURN_IF_CANCELLED(state);
    RETURN_IF_ERROR(ExecNode::open(state));

    RETURN_IF_ERROR(Expr::open(_left_expr_ctxs, state));
    RETURN_IF_ERROR(Expr::open(_right_expr_ctxs, state));
    RETURN_IF_ERROR(Expr::open(_other_join_conjunct_ctxs, state));

    _eos = false;

    // Both children are opened sequentially; an error from either one is returned
    // immediately.
    for (int child_idx = 0; child_idx < 2; ++child_idx) {
        RETURN_IF_ERROR(child(child_idx)->open(state));
    }

    // Prime the current row of the left (0) and right (1) input.
    for (int child_idx = 0; child_idx < 2; ++child_idx) {
        RETURN_IF_ERROR(get_input_row(state, child_idx));
    }

    return Status::OK();
}
|
|
|
|
// Fills out_batch with merged rows.
//
// Rows are produced one at a time by get_next_row(); a row is committed to
// out_batch only when it passes the remaining join conjuncts. The call returns
// when the batch is full or at its resource limit, when the row limit is reached,
// or when both inputs are exhausted.
//
// *eos is always written: true when the limit was already reached, the node had
// already hit end-of-stream, or both inputs ran dry during this call; false
// otherwise.
Status MergeJoinNode::get_next(RuntimeState* state, RowBatch* out_batch, bool* eos) {
    RETURN_IF_ERROR(exec_debug_action(TExecNodePhase::GETNEXT));
    RETURN_IF_CANCELLED(state);
    SCOPED_TIMER(_runtime_profile->total_time_counter());

    // Give the caller a defined value on every return path; get_next_row() only
    // ever writes true.
    *eos = false;

    if (reached_limit() || _eos) {
        *eos = true;
        return Status::OK();
    }

    // get_input_row() hands the resources of exhausted child batches over to this batch.
    _out_batch = out_batch;

    // Taking the address of element 0 of an empty vector is undefined behaviour,
    // so pass NULL together with a count of 0 when there are no extra conjuncts.
    const int num_conjuncts = _other_join_conjunct_ctxs.size();
    ExprContext** conjuncts = _other_join_conjunct_ctxs.empty() ? NULL
                                                                : &_other_join_conjunct_ctxs[0];

    while (true) {
        int row_idx = out_batch->add_row();
        DCHECK(row_idx != RowBatch::INVALID_ROW_INDEX);
        TupleRow* row = out_batch->get_row(row_idx);

        RETURN_IF_ERROR(get_next_row(state, row, eos));

        if (*eos) {
            // The row slot added above is simply left uncommitted.
            _eos = true;
            return Status::OK();
        }

        if (eval_conjuncts(conjuncts, num_conjuncts, row)) {
            out_batch->commit_last_row();
            ++_num_rows_returned;
            COUNTER_SET(_rows_returned_counter, _num_rows_returned);
        }

        if (out_batch->is_full() || out_batch->at_resource_limit() || reached_limit()) {
            break;
        }
    }

    return Status::OK();
}
|
|
|
|
// Assembles one output row from an optional left row and an optional right row.
//
//   out   - destination row; its left slots and the right-child slots are overwritten.
//   left  - row of child(0), or NULL to clear the left slots.
//   right - row of child(1), or NULL to set every right-child slot to NULL.
//
// _left_tuple_size is a tuple count (see prepare()), while memcpy/memset take a
// byte count, so it is scaled by sizeof(Tuple*) -- the per-tuple row width also
// used for _result_tuple_row_size.
// NOTE(review): like the original code, this assumes the left child's tuples
// occupy the leading slots of the output row -- confirm against the row descriptor.
void MergeJoinNode::create_output_row(TupleRow* out, TupleRow* left, TupleRow* right) {
    const size_t left_bytes = static_cast<size_t>(_left_tuple_size) * sizeof(Tuple*);

    if (left == NULL) {
        memset(out, 0, left_bytes);
    } else {
        memcpy(out, left, left_bytes);
    }

    // Right-child tuples go to their pre-computed positions in the output row.
    for (int i = 0; i < _right_tuple_size; ++i) {
        out->set_tuple(_right_tuple_idx[i], right != NULL ? right->get_tuple(i) : NULL);
    }
}
|
|
|
|
// Decides whether left_row sorts strictly before right_row on the join keys.
//
// *is_lt is set to:
//   - false when left_row is NULL (checked first, so also when both are NULL),
//   - true  when only right_row is NULL,
//   - otherwise the outcome of the first key pair that differs (true if the left
//     key is smaller), or false when every key pair compares equal.
// Always returns Status::OK().
//
// NOTE(review): the comparators dereference both value pointers; if get_value()
// can return NULL for a SQL NULL key this would crash -- confirm upstream guarantees.
Status MergeJoinNode::compare_row(TupleRow* left_row, TupleRow* right_row, bool* is_lt) {
    if (left_row == NULL) {
        *is_lt = false;
        return Status::OK();
    }
    if (right_row == NULL) {
        *is_lt = true;
        return Status::OK();
    }

    for (int key = 0; key < _left_expr_ctxs.size(); ++key) {
        void* lhs = _left_expr_ctxs[key]->get_value(left_row);
        void* rhs = _right_expr_ctxs[key]->get_value(right_row);
        const int order = _cmp_func[key](lhs, rhs);

        // Equal keys defer the decision to the next key pair.
        if (order != 0) {
            *is_lt = (order < 0);
            return Status::OK();
        }
    }

    // All key pairs are equal.
    *is_lt = false;
    return Status::OK();
}
|
|
|
|
// Produces the next merged row into out_row.
//
// When both inputs have no current row, *eos is set to true and out_row is left
// untouched. Otherwise the current rows are compared with compare_row(): if the
// left row sorts first, out_row receives the left row with NULL right tuples and
// the left input is advanced; in every other case out_row receives the right row
// with cleared left slots and the right input is advanced.
// *eos is only ever written with true here.
Status MergeJoinNode::get_next_row(RuntimeState* state, TupleRow* out_row, bool* eos) {
    TupleRow* const lhs = _left_child_ctx->current_row;
    TupleRow* const rhs = _right_child_ctx->current_row;

    const bool both_exhausted = (lhs == NULL) && (rhs == NULL);
    if (both_exhausted) {
        *eos = true;
        return Status::OK();
    }

    bool take_left = true;
    RETURN_IF_ERROR(compare_row(lhs, rhs, &take_left));

    // The output row must be built before the input is advanced.
    if (!take_left) {
        create_output_row(out_row, NULL, rhs);
        RETURN_IF_ERROR(get_input_row(state, 1));
        return Status::OK();
    }

    create_output_row(out_row, lhs, NULL);
    RETURN_IF_ERROR(get_input_row(state, 0));
    return Status::OK();
}
|
|
|
|
// Advances the given input (child_idx 0 = left, anything else = right) to its
// next row and stores it in the matching reader context's current_row.
//
// While the buffered batch is used up and the child has not reported end of
// stream, the context is replaced by a fresh one and a new batch is read from the
// child. Before the old context is dropped, the resources of its batch are
// transferred to _out_batch (when one is set).
// current_row becomes NULL once the child is exhausted.
// Returns the error of the child's get_next(), if any.
Status MergeJoinNode::get_input_row(RuntimeState* state, int child_idx) {
    const bool is_left = (child_idx == 0);
    ChildReaderContext* ctx = is_left ? _left_child_ctx.get() : _right_child_ctx.get();

    // Loop until a batch with at least one row has been read or the child is done.
    while (!ctx->is_eos && ctx->row_idx >= ctx->batch.num_rows()) {
        // Hand over the old batch's resources before the context is replaced.
        if (_out_batch != NULL) {
            ctx->batch.transfer_resource_ownership(_out_batch);
        }

        ChildReaderContext* fresh = new ChildReaderContext(
                child(child_idx)->row_desc(), state->batch_size(), state->instance_mem_tracker());
        if (is_left) {
            _left_child_ctx.reset(fresh);
        } else {
            _right_child_ctx.reset(fresh);
        }
        ctx = fresh;

        RETURN_IF_ERROR(child(child_idx)->get_next(state, &ctx->batch, &ctx->is_eos));
    }

    const bool has_row = ctx->row_idx < ctx->batch.num_rows();
    if (has_row) {
        ctx->current_row = ctx->batch.get_row(ctx->row_idx++);
    } else {
        ctx->current_row = NULL;
    }
    return Status::OK();
}
|
|
|
|
void MergeJoinNode::debug_string(int indentation_level, std::stringstream* out) const {
|
|
*out << string(indentation_level * 2, ' ');
|
|
*out << "MergeJoin(eos=" << (_eos ? "true" : "false")
|
|
<< " _left_child_pos=" << (_left_child_ctx.get() ? _left_child_ctx->row_idx : -1)
|
|
<< " _right_child_pos=" << (_right_child_ctx.get() ? _right_child_ctx->row_idx : -1)
|
|
<< " join_conjuncts=";
|
|
*out << "Conjunct(";
|
|
// << " left_exprs=" << Expr::debug_string(_left_exprs)
|
|
// << " right_exprs=" << Expr::debug_string(_right_exprs);
|
|
*out << ")";
|
|
ExecNode::debug_string(indentation_level, out);
|
|
*out << ")";
|
|
}
|
|
|
|
} // namespace doris
|