From f8d086d87fa3291d0efce2dee46f999ce9087c94 Mon Sep 17 00:00:00 2001 From: Zhengguo Yang Date: Tue, 8 Feb 2022 09:25:09 +0800 Subject: [PATCH] [feature](rpc) (experimental)Support implement UDF through GRPC protocol. (#7519) Support implementing UDFs through the gRPC protocol. This brings several benefits: 1. The UDF implementation language is no longer limited to C++; users can implement UDFs in any language they are familiar with. 2. UDFs are decoupled from Doris: a faulty UDF cannot cause a Doris core dump, and UDF computing resources are separated from Doris, so Doris services are not affected. However, an RPC UDF has a fixed per-call overhead, so it is much slower than a C++ UDF, especially when the amount of data is large. Create a function like this: ``` CREATE FUNCTION rpc_add(INT, INT) RETURNS INT PROPERTIES ( "SYMBOL"="add_int", "OBJECT_FILE"="127.0.0.1:9999", "TYPE"="RPC" ); ``` The function service needs to implement the `check_fn` and `fn_call` methods. Note: THIS IS AN EXPERIMENTAL FEATURE; THE INTERFACE AND DATA STRUCTURES MAY CHANGE IN THE FUTURE!!! 
--- be/src/common/config.h | 6 +- be/src/common/status.h | 2 +- be/src/exec/tablet_sink.cpp | 127 +++-- be/src/exprs/CMakeLists.txt | 1 + be/src/exprs/expr.cpp | 3 + be/src/exprs/expr_context.h | 1 + be/src/exprs/rpc_fn_call.cpp | 327 +++++++++++ be/src/exprs/rpc_fn_call.h | 63 +++ be/src/exprs/runtime_filter_rpc.cpp | 7 +- be/src/gen_cpp/CMakeLists.txt | 2 +- .../http/action/check_rpc_channel_action.cpp | 4 +- .../http/action/reset_rpc_channel_action.cpp | 12 +- be/src/plugin/plugin_loader.cpp | 6 +- be/src/runtime/data_stream_sender.cpp | 4 +- be/src/runtime/exec_env.h | 21 +- be/src/runtime/exec_env_init.cpp | 8 +- be/src/runtime/runtime_filter_mgr.cpp | 4 +- be/src/service/internal_service.cpp | 12 +- be/src/udf/udf.cpp | 15 + be/src/udf/udf_internal.h | 4 + be/src/util/CMakeLists.txt | 2 +- ...c_stub_cache.cpp => brpc_client_cache.cpp} | 20 +- be/src/util/brpc_client_cache.h | 150 +++++ be/src/util/brpc_stub_cache.h | 159 ------ be/src/util/doris_metrics.h | 16 +- be/src/vec/CMakeLists.txt | 1 + be/src/vec/columns/column_decimal.h | 9 +- be/src/vec/exprs/vectorized_fn_call.cpp | 10 +- be/src/vec/functions/function_rpc.cpp | 527 ++++++++++++++++++ be/src/vec/functions/function_rpc.h | 68 +++ be/src/vec/sink/vdata_stream_sender.cpp | 8 +- be/src/vec/sink/vdata_stream_sender.h | 24 +- be/test/exec/tablet_sink_test.cpp | 8 +- be/test/http/stream_load_test.cpp | 11 +- be/test/util/CMakeLists.txt | 2 +- ...he_test.cpp => brpc_client_cache_test.cpp} | 24 +- be/test/vec/runtime/vdata_stream_test.cpp | 17 +- .../doris/analysis/CreateFunctionStmt.java | 159 +++++- .../apache/doris/catalog/ScalarFunction.java | 5 +- .../java/org/apache/doris/common/Status.java | 2 +- .../java/org/apache/doris/qe/Coordinator.java | 2 +- .../load/sync/canal/CanalSyncDataTest.java | 13 +- .../doris/utframe/MockedBackendFactory.java | 12 +- gensrc/proto/function_service.proto | 63 +++ gensrc/proto/internal_service.proto | 10 - gensrc/proto/status.proto | 27 - gensrc/proto/types.proto | 151 
+++++ gensrc/thrift/Types.thrift | 5 +- run-be-ut.sh | 1 + 49 files changed, 1765 insertions(+), 370 deletions(-) create mode 100644 be/src/exprs/rpc_fn_call.cpp create mode 100644 be/src/exprs/rpc_fn_call.h rename be/src/util/{brpc_stub_cache.cpp => brpc_client_cache.cpp} (64%) create mode 100644 be/src/util/brpc_client_cache.h delete mode 100644 be/src/util/brpc_stub_cache.h create mode 100644 be/src/vec/functions/function_rpc.cpp create mode 100644 be/src/vec/functions/function_rpc.h rename be/test/util/{brpc_stub_cache_test.cpp => brpc_client_cache_test.cpp} (73%) create mode 100644 gensrc/proto/function_service.proto delete mode 100644 gensrc/proto/status.proto diff --git a/be/src/common/config.h b/be/src/common/config.h index f4b38bad4c..2730b62163 100644 --- a/be/src/common/config.h +++ b/be/src/common/config.h @@ -653,7 +653,7 @@ CONF_mInt32(default_remote_storage_s3_max_conn, "50"); CONF_mInt32(default_remote_storage_s3_request_timeout_ms, "3000"); CONF_mInt32(default_remote_storage_s3_conn_timeout_ms, "1000"); // Set to true to disable the minidump feature. -CONF_Bool(disable_minidump , "false"); +CONF_Bool(disable_minidump, "false"); // The dir to save minidump file. // Make sure that the user who run Doris has permission to create and visit this dir, @@ -688,7 +688,11 @@ CONF_mInt32(load_task_high_priority_threshold_second, "120"); // Increase this config may avoid rpc timeout. 
CONF_mInt32(min_load_rpc_timeout_ms, "20000"); +// use which protocol to access function service, candicate is baidu_std/h2:grpc +CONF_String(function_service_protocol, "h2:grpc"); +// use which load balancer to select server to connect +CONF_String(rpc_load_balancer, "rr"); } // namespace config diff --git a/be/src/common/status.h b/be/src/common/status.h index 89bd6fe972..23bf764192 100644 --- a/be/src/common/status.h +++ b/be/src/common/status.h @@ -10,7 +10,7 @@ #include "common/compiler_util.h" #include "common/logging.h" #include "gen_cpp/Status_types.h" // for TStatus -#include "gen_cpp/status.pb.h" // for PStatus +#include "gen_cpp/types.pb.h" // for PStatus #include "util/slice.h" // for Slice namespace doris { diff --git a/be/src/exec/tablet_sink.cpp b/be/src/exec/tablet_sink.cpp index 7f3a62caff..2fe349294d 100644 --- a/be/src/exec/tablet_sink.cpp +++ b/be/src/exec/tablet_sink.cpp @@ -18,6 +18,7 @@ #include "exec/tablet_sink.h" #include + #include #include @@ -31,7 +32,7 @@ #include "runtime/tuple_row.h" #include "service/backend_options.h" #include "service/brpc.h" -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include "util/debug/sanitizer_scopes.h" #include "util/monotime.h" #include "util/proto_util.h" @@ -85,7 +86,8 @@ Status NodeChannel::init(RuntimeState* state) { _batch_size = state->batch_size(); _cur_batch.reset(new RowBatch(*_row_desc, _batch_size, _parent->_mem_tracker.get())); - _stub = state->exec_env()->brpc_stub_cache()->get_stub(_node_info.host, _node_info.brpc_port); + _stub = state->exec_env()->brpc_internal_client_cache()->get_client(_node_info.host, + _node_info.brpc_port); if (_stub == nullptr) { LOG(WARNING) << "Get rpc stub failed, host=" << _node_info.host << ", port=" << _node_info.brpc_port; @@ -156,9 +158,10 @@ void NodeChannel::_cancel_with_msg(const std::string& msg) { Status NodeChannel::open_wait() { _open_closure->join(); if (_open_closure->cntl.Failed()) { - if 
(!ExecEnv::GetInstance()->brpc_stub_cache()->available(_stub, _node_info.host, - _node_info.brpc_port)) { - ExecEnv::GetInstance()->brpc_stub_cache()->erase(_open_closure->cntl.remote_side()); + if (!ExecEnv::GetInstance()->brpc_internal_client_cache()->available( + _stub, _node_info.host, _node_info.brpc_port)) { + ExecEnv::GetInstance()->brpc_internal_client_cache()->erase( + _open_closure->cntl.remote_side()); } std::stringstream ss; ss << "failed to open tablet writer, error=" << berror(_open_closure->cntl.ErrorCode()) @@ -193,7 +196,7 @@ Status NodeChannel::open_wait() { bool is_last_rpc) { Status status(result.status()); if (status.ok()) { - // if has error tablet, handle them first + // if has error tablet, handle them first for (auto& error : result.tablet_errors()) { _index_channel->mark_as_failed(this, error.msg(), error.tablet_id()); } @@ -313,8 +316,9 @@ Status NodeChannel::add_row(BlockRow& block_row, int64_t tablet_id) { } DCHECK_NE(row_no, RowBatch::INVALID_ROW_INDEX); - _cur_batch->get_row(row_no)->set_tuple(0, - block_row.first->deep_copy_tuple(*_tuple_desc, _cur_batch->tuple_data_pool(), block_row.second, 0, true)); + _cur_batch->get_row(row_no)->set_tuple( + 0, block_row.first->deep_copy_tuple(*_tuple_desc, _cur_batch->tuple_data_pool(), + block_row.second, 0, true)); _cur_batch->commit_last_row(); _cur_add_batch_request.add_tablet_ids(tablet_id); return Status::OK(); @@ -338,7 +342,8 @@ Status NodeChannel::mark_close() { _pending_batches.emplace(std::move(_cur_batch), _cur_add_batch_request); _pending_batches_num++; DCHECK(_pending_batches.back().second.eos()); - LOG(INFO) << channel_info() << " mark closed, left pending batch size: " << _pending_batches.size(); + LOG(INFO) << channel_info() + << " mark closed, left pending batch size: " << _pending_batches.size(); } _eos_is_produced = true; @@ -377,7 +382,7 @@ Status NodeChannel::close_wait(RuntimeState* state) { std::make_move_iterator(_tablet_commit_infos.begin()), 
std::make_move_iterator(_tablet_commit_infos.end())); - _index_channel->set_error_tablet_in_state(state); + _index_channel->set_error_tablet_in_state(state); return Status::OK(); } @@ -455,7 +460,7 @@ void NodeChannel::try_send_batch() { size_t uncompressed_bytes = 0, compressed_bytes = 0; Status st = row_batch->serialize(request.mutable_row_batch(), &uncompressed_bytes, &compressed_bytes, _tuple_data_buffer_ptr); if (!st.ok()) { - cancel(fmt::format("{}, err: {}", channel_info(), st.get_error_msg())); + cancel(fmt::format("{}, err: {}", channel_info(), st.get_error_msg())); return; } if (compressed_bytes >= double(config::brpc_max_body_size) * 0.95f) { @@ -541,8 +546,8 @@ Status IndexChannel::init(RuntimeState* state, const std::vector_pool->add( - new NodeChannel(_parent, this, node_id, _schema_hash)); + channel = + _parent->_pool->add(new NodeChannel(_parent, this, node_id, _schema_hash)); _node_channels.emplace(node_id, channel); } else { channel = it->second; @@ -586,41 +591,44 @@ void IndexChannel::add_row(BlockRow& block_row, int64_t tablet_id) { } } -void IndexChannel::mark_as_failed(const NodeChannel* ch, const std::string& err, int64_t tablet_id) { +void IndexChannel::mark_as_failed(const NodeChannel* ch, const std::string& err, + int64_t tablet_id) { const auto& it = _tablets_by_channel.find(ch->node_id()); if (it == _tablets_by_channel.end()) { return; } { - std::lock_guard l(_fail_lock); + std::lock_guard l(_fail_lock); if (tablet_id == -1) { for (const auto the_tablet_id : it->second) { _failed_channels[the_tablet_id].insert(ch->node_id()); _failed_channels_msgs.emplace(the_tablet_id, err + ", host: " + ch->host()); if (_failed_channels[the_tablet_id].size() >= ((_parent->_num_replicas + 1) / 2)) { - _intolerable_failure_status = Status::InternalError(_failed_channels_msgs[the_tablet_id]); + _intolerable_failure_status = + Status::InternalError(_failed_channels_msgs[the_tablet_id]); } } } else { _failed_channels[tablet_id].insert(ch->node_id()); 
_failed_channels_msgs.emplace(tablet_id, err + ", host: " + ch->host()); if (_failed_channels[tablet_id].size() >= ((_parent->_num_replicas + 1) / 2)) { - _intolerable_failure_status = Status::InternalError(_failed_channels_msgs[tablet_id]); + _intolerable_failure_status = + Status::InternalError(_failed_channels_msgs[tablet_id]); } } } } Status IndexChannel::check_intolerable_failure() { - std::lock_guard l(_fail_lock); + std::lock_guard l(_fail_lock); return _intolerable_failure_status; } void IndexChannel::set_error_tablet_in_state(RuntimeState* state) { std::vector& error_tablet_infos = state->error_tablet_infos(); - std::lock_guard l(_fail_lock); + std::lock_guard l(_fail_lock); for (const auto& it : _failed_channels_msgs) { TErrorTabletInfo error_info; error_info.__set_tabletId(it.first); @@ -684,7 +692,8 @@ Status OlapTableSink::prepare(RuntimeState* state) { _sender_id = state->per_fragment_instance_idx(); _num_senders = state->num_per_fragment_instances(); - _is_high_priority = (state->query_options().query_timeout <= config::load_task_high_priority_threshold_second); + _is_high_priority = (state->query_options().query_timeout <= + config::load_task_high_priority_threshold_second); // profile must add to state's object pool _profile = state->obj_pool()->add(new RuntimeProfile("OlapTableSink")); @@ -810,7 +819,10 @@ Status OlapTableSink::open(RuntimeState* state) { // The open() phase is mainly to generate DeltaWriter instances on the nodes corresponding to each node channel. // This phase will not fail due to a single tablet. // Therefore, if the open() phase fails, all tablets corresponding to the node need to be marked as failed. 
- index_channel->mark_as_failed(ch, fmt::format("{}, open failed, err: {}", ch->channel_info(), st.get_error_msg()), -1); + index_channel->mark_as_failed(ch, + fmt::format("{}, open failed, err: {}", + ch->channel_info(), st.get_error_msg()), + -1); } }); @@ -851,7 +863,8 @@ Status OlapTableSink::send(RuntimeState* state, RowBatch* input_batch) { SCOPED_RAW_TIMER(&_validate_data_ns); _filter_bitmap.Reset(batch->num_rows()); bool stop_processing = false; - RETURN_IF_ERROR(_validate_data(state, batch, &_filter_bitmap, &filtered_rows, &stop_processing)); + RETURN_IF_ERROR( + _validate_data(state, batch, &_filter_bitmap, &filtered_rows, &stop_processing)); _number_filtered_rows += filtered_rows; if (stop_processing) { // should be returned after updating "_number_filtered_rows", to make sure that load job can be cancelled @@ -870,12 +883,15 @@ Status OlapTableSink::send(RuntimeState* state, RowBatch* input_batch) { const OlapTablePartition* partition = nullptr; uint32_t dist_hash = 0; if (!_partition->find_tablet(tuple, &partition, &dist_hash)) { - RETURN_IF_ERROR(state->append_error_msg_to_file([]() -> std::string { return ""; }, + RETURN_IF_ERROR(state->append_error_msg_to_file( + []() -> std::string { return ""; }, [&]() -> std::string { - fmt::memory_buffer buf; - fmt::format_to(buf, "no partition for this tuple. tuple={}", Tuple::to_string(tuple, *_output_tuple_desc)); - return buf.data(); - }, &stop_processing)); + fmt::memory_buffer buf; + fmt::format_to(buf, "no partition for this tuple. 
tuple={}", + Tuple::to_string(tuple, *_output_tuple_desc)); + return buf.data(); + }, + &stop_processing)); _number_filtered_rows++; if (stop_processing) { return Status::EndOfFile("Encountered unqualified data, stop processing"); @@ -892,7 +908,7 @@ Status OlapTableSink::send(RuntimeState* state, RowBatch* input_batch) { } // check intolerable failure - for (auto index_channel : _channels) { + for (auto index_channel : _channels) { RETURN_IF_ERROR(index_channel->check_intolerable_failure()); } return Status::OK(); @@ -953,7 +969,6 @@ Status OlapTableSink::close(RuntimeState* state, Status close_status) { status = index_st; } } // end for index channels - } // TODO need to be improved LOG(INFO) << "total mem_exceeded_block_ns=" << mem_exceeded_block_ns @@ -1031,7 +1046,8 @@ Status OlapTableSink::_convert_batch(RuntimeState* state, RowBatch* input_batch, // Only when the expr return value is null, we will check the error message. std::string expr_error = _output_expr_ctxs[j]->get_error_msg(); if (!expr_error.empty()) { - RETURN_IF_ERROR(state->append_error_msg_to_file([&]() -> std::string { return slot_desc->col_name(); }, + RETURN_IF_ERROR(state->append_error_msg_to_file( + [&]() -> std::string { return slot_desc->col_name(); }, [&]() -> std::string { return expr_error; }, &stop_processing)); _number_filtered_rows++; ignore_this_row = true; @@ -1040,12 +1056,15 @@ Status OlapTableSink::_convert_batch(RuntimeState* state, RowBatch* input_batch, break; } if (!slot_desc->is_nullable()) { - RETURN_IF_ERROR(state->append_error_msg_to_file([]() -> std::string { return ""; }, + RETURN_IF_ERROR(state->append_error_msg_to_file( + []() -> std::string { return ""; }, [&]() -> std::string { - fmt::memory_buffer buf; - fmt::format_to(buf, "null value for not null column, column={}", slot_desc->col_name()); - return buf.data(); - }, &stop_processing)); + fmt::memory_buffer buf; + fmt::format_to(buf, "null value for not null column, column={}", + slot_desc->col_name()); + return 
buf.data(); + }, + &stop_processing)); _number_filtered_rows++; ignore_this_row = true; break; @@ -1073,8 +1092,8 @@ Status OlapTableSink::_convert_batch(RuntimeState* state, RowBatch* input_batch, return Status::OK(); } -Status OlapTableSink::_validate_data(RuntimeState* state, RowBatch* batch, Bitmap* filter_bitmap, int* filtered_rows, - bool* stop_processing) { +Status OlapTableSink::_validate_data(RuntimeState* state, RowBatch* batch, Bitmap* filter_bitmap, + int* filtered_rows, bool* stop_processing) { for (int row_no = 0; row_no < batch->num_rows(); ++row_no) { Tuple* tuple = batch->get_row(row_no)->get_tuple(0); bool row_valid = true; @@ -1083,8 +1102,9 @@ Status OlapTableSink::_validate_data(RuntimeState* state, RowBatch* batch, Bitma SlotDescriptor* desc = _output_tuple_desc->slots()[i]; if (desc->is_nullable() && tuple->is_null(desc->null_indicator_offset())) { if (desc->type().type == TYPE_OBJECT) { - fmt::format_to(error_msg, "null is not allowed for bitmap column, column_name: {}; ", - desc->col_name()); + fmt::format_to(error_msg, + "null is not allowed for bitmap column, column_name: {}; ", + desc->col_name()); row_valid = false; } continue; @@ -1096,9 +1116,11 @@ Status OlapTableSink::_validate_data(RuntimeState* state, RowBatch* batch, Bitma // Fixed length string StringValue* str_val = (StringValue*)slot; if (str_val->len > desc->type().len) { - fmt::format_to(error_msg, "{}", "the length of input is too long than schema. "); + fmt::format_to(error_msg, "{}", + "the length of input is too long than schema. 
"); fmt::format_to(error_msg, "column_name: {}; ", desc->col_name()); - fmt::format_to(error_msg, "input str: [{}] ", std::string(str_val->ptr, str_val->len)); + fmt::format_to(error_msg, "input str: [{}] ", + std::string(str_val->ptr, str_val->len)); fmt::format_to(error_msg, "schema length: {}; ", desc->type().len); fmt::format_to(error_msg, "actual length: {}; ", str_val->len); row_valid = false; @@ -1118,9 +1140,11 @@ Status OlapTableSink::_validate_data(RuntimeState* state, RowBatch* batch, Bitma case TYPE_STRING: { StringValue* str_val = (StringValue*)slot; if (str_val->len > OLAP_STRING_MAX_LENGTH) { - fmt::format_to(error_msg, "{}", "the length of input is too long than schema. "); + fmt::format_to(error_msg, "{}", + "the length of input is too long than schema. "); fmt::format_to(error_msg, "column_name: {}; ", desc->col_name()); - fmt::format_to(error_msg, "first 128 bytes of input str: [{}] ", std::string(str_val->ptr, 128)); + fmt::format_to(error_msg, "first 128 bytes of input str: [{}] ", + std::string(str_val->ptr, 128)); fmt::format_to(error_msg, "schema length: {}; ", OLAP_STRING_MAX_LENGTH); fmt::format_to(error_msg, "actual length: {}; ", str_val->len); row_valid = false; @@ -1134,15 +1158,19 @@ Status OlapTableSink::_validate_data(RuntimeState* state, RowBatch* batch, Bitma int code = dec_val.round(&dec_val, desc->type().scale, HALF_UP); reinterpret_cast(slot)->value = dec_val.value(); if (code != E_DEC_OK) { - fmt::format_to(error_msg, "round one decimal failed.value={}; ", dec_val.to_string()); + fmt::format_to(error_msg, "round one decimal failed.value={}; ", + dec_val.to_string()); row_valid = false; continue; } } if (dec_val > _max_decimalv2_val[i] || dec_val < _min_decimalv2_val[i]) { - fmt::format_to(error_msg, "decimal value is not valid for definition, column={}", desc->col_name()); + fmt::format_to(error_msg, + "decimal value is not valid for definition, column={}", + desc->col_name()); fmt::format_to(error_msg, ", value={}", 
dec_val.to_string()); - fmt::format_to(error_msg, ", precision={}, scale={}; ", desc->type().precision, desc->type().scale); + fmt::format_to(error_msg, ", precision={}, scale={}; ", desc->type().precision, + desc->type().scale); row_valid = false; continue; } @@ -1151,7 +1179,9 @@ Status OlapTableSink::_validate_data(RuntimeState* state, RowBatch* batch, Bitma case TYPE_HLL: { Slice* hll_val = (Slice*)slot; if (!HyperLogLog::is_valid(*hll_val)) { - fmt::format_to(error_msg, "Content of HLL type column is invalid. column name: {}; ", desc->col_name()); + fmt::format_to(error_msg, + "Content of HLL type column is invalid. column name: {}; ", + desc->col_name()); row_valid = false; continue; } @@ -1165,7 +1195,8 @@ Status OlapTableSink::_validate_data(RuntimeState* state, RowBatch* batch, Bitma if (!row_valid) { (*filtered_rows)++; filter_bitmap->Set(row_no, true); - RETURN_IF_ERROR(state->append_error_msg_to_file([]() -> std::string { return ""; }, + RETURN_IF_ERROR(state->append_error_msg_to_file( + []() -> std::string { return ""; }, [&]() -> std::string { return error_msg.data(); }, stop_processing)); } } diff --git a/be/src/exprs/CMakeLists.txt b/be/src/exprs/CMakeLists.txt index 5d69aa79e5..3b4b86cb97 100644 --- a/be/src/exprs/CMakeLists.txt +++ b/be/src/exprs/CMakeLists.txt @@ -54,6 +54,7 @@ add_library(Exprs math_functions.cpp null_literal.cpp scalar_fn_call.cpp + rpc_fn_call.cpp slot_ref.cpp string_functions.cpp array_functions.cpp diff --git a/be/src/exprs/expr.cpp b/be/src/exprs/expr.cpp index 97352dade5..1c29d5d840 100644 --- a/be/src/exprs/expr.cpp +++ b/be/src/exprs/expr.cpp @@ -38,6 +38,7 @@ #include "exprs/is_null_predicate.h" #include "exprs/literal.h" #include "exprs/null_literal.h" +#include "exprs/rpc_fn_call.h" #include "exprs/scalar_fn_call.h" #include "exprs/slot_ref.h" #include "exprs/tuple_is_null_predicate.h" @@ -357,6 +358,8 @@ Status Expr::create_expr(ObjectPool* pool, const TExprNode& texpr_node, Expr** e *expr = pool->add(new 
IfNullExpr(texpr_node)); } else if (texpr_node.fn.name.function_name == "coalesce") { *expr = pool->add(new CoalesceExpr(texpr_node)); + } else if (texpr_node.fn.binary_type == TFunctionBinaryType::RPC) { + *expr = pool->add(new RPCFnCall(texpr_node)); } else { *expr = pool->add(new ScalarFnCall(texpr_node)); } diff --git a/be/src/exprs/expr_context.h b/be/src/exprs/expr_context.h index 45896a2636..f176240f72 100644 --- a/be/src/exprs/expr_context.h +++ b/be/src/exprs/expr_context.h @@ -153,6 +153,7 @@ public: private: friend class Expr; friend class ScalarFnCall; + friend class RPCFnCall; friend class InPredicate; friend class RuntimePredicateWrapper; friend class BloomFilterPredicate; diff --git a/be/src/exprs/rpc_fn_call.cpp b/be/src/exprs/rpc_fn_call.cpp new file mode 100644 index 0000000000..92b67e5949 --- /dev/null +++ b/be/src/exprs/rpc_fn_call.cpp @@ -0,0 +1,327 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "exprs/rpc_fn_call.h" + +#include "exprs/anyval_util.h" +#include "exprs/expr_context.h" +#include "fmt/format.h" +#include "gen_cpp/function_service.pb.h" +#include "runtime/runtime_state.h" +#include "runtime/user_function_cache.h" +#include "service/brpc.h" +#include "util/brpc_client_cache.h" + +namespace doris { + +RPCFnCall::RPCFnCall(const TExprNode& node) : Expr(node), _fn_context_index(-1) { + DCHECK_EQ(_fn.binary_type, TFunctionBinaryType::RPC); +} + +Status RPCFnCall::prepare(RuntimeState* state, const RowDescriptor& desc, ExprContext* context) { + RETURN_IF_ERROR(Expr::prepare(state, desc, context)); + DCHECK(!_fn.scalar_fn.symbol.empty()); + + FunctionContext::TypeDesc return_type = AnyValUtil::column_type_to_type_desc(_type); + std::vector arg_types; + bool char_arg = false; + for (int i = 0; i < _children.size(); ++i) { + arg_types.push_back(AnyValUtil::column_type_to_type_desc(_children[i]->type())); + char_arg = char_arg || (_children[i]->type().type == TYPE_CHAR); + } + _fn_context_index = context->register_func(state, return_type, arg_types, 0); + + // _fn.scalar_fn.symbol + _rpc_function_symbol = _fn.scalar_fn.symbol; + + _client = state->exec_env()->brpc_function_client_cache()->get_client(_fn.hdfs_location); + + if (_client == nullptr) { + return Status::InternalError( + fmt::format("rpc env init error: {}/{}", _fn.hdfs_location, _rpc_function_symbol)); + } + return Status::OK(); +} + +Status RPCFnCall::open(RuntimeState* state, ExprContext* ctx, + FunctionContext::FunctionStateScope scope) { + RETURN_IF_ERROR(Expr::open(state, ctx, scope)); + return Status::OK(); +} + +void RPCFnCall::close(RuntimeState* state, ExprContext* context, + FunctionContext::FunctionStateScope scope) { + Expr::close(state, context, scope); +} + +Status RPCFnCall::_eval_children(ExprContext* context, TupleRow* row, + PFunctionCallResponse* response) { + PFunctionCallRequest request; + request.set_function_name(_rpc_function_symbol); + for (int i = 0; i < 
_children.size(); ++i) { + PValues* arg = request.add_args(); + void* src_slot = context->get_value(_children[i], row); + PGenericType* ptype = arg->mutable_type(); + if (src_slot == nullptr) { + arg->set_has_null(true); + arg->add_null_map(true); + } else { + arg->set_has_null(false); + } + switch (_children[i]->type().type) { + case TYPE_BOOLEAN: { + ptype->set_id(PGenericType::BOOLEAN); + arg->add_bool_value(*(bool*)src_slot); + break; + } + case TYPE_TINYINT: { + ptype->set_id(PGenericType::INT8); + arg->add_int32_value(*(int8_t*)src_slot); + break; + } + case TYPE_SMALLINT: { + ptype->set_id(PGenericType::INT16); + arg->add_int32_value(*(int16_t*)src_slot); + break; + } + case TYPE_INT: { + ptype->set_id(PGenericType::INT32); + arg->add_int32_value(*(int*)src_slot); + break; + } + case TYPE_BIGINT: { + ptype->set_id(PGenericType::INT64); + arg->add_int64_value(*(int64_t*)src_slot); + break; + } + case TYPE_LARGEINT: { + ptype->set_id(PGenericType::INT128); + char buffer[sizeof(__int128)]; + memcpy(buffer, src_slot, sizeof(__int128)); + arg->add_bytes_value(buffer, sizeof(__int128)); + break; + } + case TYPE_DOUBLE: { + ptype->set_id(PGenericType::DOUBLE); + arg->add_double_value(*(double*)src_slot); + break; + } + case TYPE_FLOAT: { + ptype->set_id(PGenericType::FLOAT); + arg->add_float_value(*(float*)src_slot); + break; + } + case TYPE_VARCHAR: + case TYPE_STRING: + case TYPE_CHAR: { + ptype->set_id(PGenericType::STRING); + StringValue value = *reinterpret_cast(src_slot); + arg->add_string_value(value.ptr, value.len); + break; + } + case TYPE_HLL: { + ptype->set_id(PGenericType::HLL); + StringValue value = *reinterpret_cast(src_slot); + arg->add_string_value(value.ptr, value.len); + break; + } + case TYPE_OBJECT: { + ptype->set_id(PGenericType::BITMAP); + StringValue value = *reinterpret_cast(src_slot); + arg->add_string_value(value.ptr, value.len); + break; + } + case TYPE_DECIMALV2: { + ptype->set_id(PGenericType::DECIMAL128); + 
ptype->mutable_decimal_type()->set_precision(_children[i]->type().precision); + ptype->mutable_decimal_type()->set_scale(_children[i]->type().scale); + char buffer[sizeof(__int128)]; + memcpy(buffer, src_slot, sizeof(__int128)); + arg->add_bytes_value(buffer, sizeof(__int128)); + break; + } + case TYPE_DATE: { + ptype->set_id(PGenericType::DATE); + const auto* time_val = (const DateTimeValue*)(src_slot); + PDateTime* date_time = arg->add_datetime_value(); + date_time->set_day(time_val->day()); + date_time->set_month(time_val->month()); + date_time->set_year(time_val->year()); + break; + } + case TYPE_DATETIME: { + ptype->set_id(PGenericType::DATETIME); + const auto* time_val = (const DateTimeValue*)(src_slot); + PDateTime* date_time = arg->add_datetime_value(); + date_time->set_day(time_val->day()); + date_time->set_month(time_val->month()); + date_time->set_year(time_val->year()); + date_time->set_hour(time_val->hour()); + date_time->set_minute(time_val->minute()); + date_time->set_second(time_val->second()); + date_time->set_microsecond(time_val->microsecond()); + break; + } + case TYPE_TIME: { + ptype->set_id(PGenericType::DATETIME); + const auto* time_val = (const DateTimeValue*)(src_slot); + PDateTime* date_time = arg->add_datetime_value(); + date_time->set_hour(time_val->hour()); + date_time->set_minute(time_val->minute()); + date_time->set_second(time_val->second()); + date_time->set_microsecond(time_val->microsecond()); + break; + } + default: { + FunctionContext* fn_ctx = context->fn_context(_fn_context_index); + fn_ctx->set_error( + fmt::format("data time not supported: {}", _children[i]->type().type).c_str()); + break; + } + } + } + + brpc::Controller cntl; + _client->fn_call(&cntl, &request, response, nullptr); + if (cntl.Failed()) { + FunctionContext* fn_ctx = context->fn_context(_fn_context_index); + fn_ctx->set_error(cntl.ErrorText().c_str()); + return Status::InternalError(fmt::format("call rpc function {} failed: {}", + _rpc_function_symbol, 
cntl.ErrorText()) + .c_str()); + } + if (response->status().status_code() != 0) { + FunctionContext* fn_ctx = context->fn_context(_fn_context_index); + fn_ctx->set_error(response->status().DebugString().c_str()); + return Status::InternalError(fmt::format("call rpc function {} failed: {}", + _rpc_function_symbol, + response->status().DebugString())); + } + return Status::OK(); +} + +template +T RPCFnCall::interpret_eval(ExprContext* context, TupleRow* row) { + PFunctionCallResponse response; + Status st = _eval_children(context, row, &response); + WARN_IF_ERROR(st, "call rpc udf error"); + if (!st.ok() || (response.result().has_null() && response.result().null_map(0))) { + return T::null(); + } + T res_val; + // TODO(yangzhg) deal with udtf and udaf + const PValues& result = response.result(); + if constexpr (std::is_same_v) { + DCHECK(result.type().id() == PGenericType::INT8); + res_val.val = static_cast(result.int32_value(0)); + } else if constexpr (std::is_same_v) { + DCHECK(result.type().id() == PGenericType::INT16); + res_val.val = static_cast(result.int32_value(0)); + } else if constexpr (std::is_same_v) { + DCHECK(result.type().id() == PGenericType::INT32); + res_val.val = result.int32_value(0); + } else if constexpr (std::is_same_v) { + DCHECK(result.type().id() == PGenericType::INT64); + res_val.val = result.int64_value(0); + } else if constexpr (std::is_same_v) { + DCHECK(result.type().id() == PGenericType::FLOAT); + res_val.val = result.float_value(0); + } else if constexpr (std::is_same_v) { + DCHECK(result.type().id() == PGenericType::DOUBLE); + res_val.val = result.double_value(0); + } else if constexpr (std::is_same_v) { + DCHECK(result.type().id() == PGenericType::STRING); + FunctionContext* fn_ctx = context->fn_context(_fn_context_index); + StringVal val(fn_ctx, result.string_value(0).size()); + res_val = val.copy_from(fn_ctx, + reinterpret_cast(result.string_value(0).c_str()), + result.string_value(0).size()); + } else if constexpr 
(std::is_same_v) { + DCHECK(result.type().id() == PGenericType::INT128); + memcpy(&(res_val.val), result.bytes_value(0).data(), sizeof(__int128_t)); + } else if constexpr (std::is_same_v) { + DCHECK(result.type().id() == PGenericType::DATE || + result.type().id() == PGenericType::DATETIME); + DateTimeValue value; + value.set_time(result.datetime_value(0).year(), result.datetime_value(0).month(), + result.datetime_value(0).day(), result.datetime_value(0).hour(), + result.datetime_value(0).minute(), result.datetime_value(0).second(), + result.datetime_value(0).microsecond()); + if (result.type().id() == PGenericType::DATE) { + value.set_type(TimeType::TIME_DATE); + } else if (result.type().id() == PGenericType::DATETIME) { + if (result.datetime_value(0).has_year()) { + value.set_type(TimeType::TIME_DATETIME); + } else + value.set_type(TimeType::TIME_TIME); + } + value.to_datetime_val(&res_val); + } else if constexpr (std::is_same_v) { + DCHECK(result.type().id() == PGenericType::DECIMAL128); + memcpy(&(res_val.val), result.bytes_value(0).data(), sizeof(__int128_t)); + } + return res_val; +} // namespace doris + +doris_udf::IntVal RPCFnCall::get_int_val(ExprContext* context, TupleRow* row) { + return interpret_eval(context, row); +} + +doris_udf::BooleanVal RPCFnCall::get_boolean_val(ExprContext* context, TupleRow* row) { + return interpret_eval(context, row); +} + +doris_udf::TinyIntVal RPCFnCall::get_tiny_int_val(ExprContext* context, TupleRow* row) { + return interpret_eval(context, row); +} + +doris_udf::SmallIntVal RPCFnCall::get_small_int_val(ExprContext* context, TupleRow* row) { + return interpret_eval(context, row); +} + +doris_udf::BigIntVal RPCFnCall::get_big_int_val(ExprContext* context, TupleRow* row) { + return interpret_eval(context, row); +} + +doris_udf::FloatVal RPCFnCall::get_float_val(ExprContext* context, TupleRow* row) { + return interpret_eval(context, row); +} + +doris_udf::DoubleVal RPCFnCall::get_double_val(ExprContext* context, TupleRow* 
row) { + return interpret_eval(context, row); +} + +doris_udf::StringVal RPCFnCall::get_string_val(ExprContext* context, TupleRow* row) { + return interpret_eval(context, row); +} + +doris_udf::LargeIntVal RPCFnCall::get_large_int_val(ExprContext* context, TupleRow* row) { + return interpret_eval(context, row); +} + +doris_udf::DateTimeVal RPCFnCall::get_datetime_val(ExprContext* context, TupleRow* row) { + return interpret_eval(context, row); +} + +doris_udf::DecimalV2Val RPCFnCall::get_decimalv2_val(ExprContext* context, TupleRow* row) { + return interpret_eval(context, row); +} +doris_udf::CollectionVal RPCFnCall::get_array_val(ExprContext* context, TupleRow* row) { + return interpret_eval(context, row); +} + +} // namespace doris diff --git a/be/src/exprs/rpc_fn_call.h b/be/src/exprs/rpc_fn_call.h new file mode 100644 index 0000000000..c04e2ec081 --- /dev/null +++ b/be/src/exprs/rpc_fn_call.h @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include "common/object_pool.h" +#include "exprs/expr.h" +#include "udf/udf.h" + +namespace doris { +class TExprNode; +class PFunctionService_Stub; +class PFunctionCallResponse; + +class RPCFnCall : public Expr { +public: + RPCFnCall(const TExprNode& node); + + virtual Status prepare(RuntimeState* state, const RowDescriptor& desc, + ExprContext* context) override; + virtual Status open(RuntimeState* state, ExprContext* context, + FunctionContext::FunctionStateScope scope) override; + virtual void close(RuntimeState* state, ExprContext* context, + FunctionContext::FunctionStateScope scope) override; + virtual Expr* clone(ObjectPool* pool) const override { return pool->add(new RPCFnCall(*this)); } + + virtual doris_udf::BooleanVal get_boolean_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::TinyIntVal get_tiny_int_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::SmallIntVal get_small_int_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::IntVal get_int_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::BigIntVal get_big_int_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::LargeIntVal get_large_int_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::FloatVal get_float_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::DoubleVal get_double_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::StringVal get_string_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::DateTimeVal get_datetime_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow*) override; + virtual doris_udf::CollectionVal get_array_val(ExprContext* context, TupleRow*) override; + +private: + Status _eval_children(ExprContext* context, TupleRow* row, PFunctionCallResponse* response); + template + RETURN_TYPE interpret_eval(ExprContext* context, TupleRow* row); 
+ + std::shared_ptr _client = nullptr; + int _fn_context_index; + std::string _rpc_function_symbol; +}; +}; // namespace doris \ No newline at end of file diff --git a/be/src/exprs/runtime_filter_rpc.cpp b/be/src/exprs/runtime_filter_rpc.cpp index c20779d63a..764dcf9092 100644 --- a/be/src/exprs/runtime_filter_rpc.cpp +++ b/be/src/exprs/runtime_filter_rpc.cpp @@ -25,7 +25,7 @@ #include "gen_cpp/PlanNodes_types.h" #include "gen_cpp/internal_service.pb.h" #include "service/brpc.h" -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" namespace doris { @@ -40,7 +40,7 @@ Status IRuntimeFilter::push_to_remote(RuntimeState* state, const TNetworkAddress DCHECK(is_producer()); DCHECK(_rpc_context == nullptr); std::shared_ptr stub( - state->exec_env()->brpc_stub_cache()->get_stub(*addr)); + state->exec_env()->brpc_internal_client_cache()->get_client(*addr)); if (!stub) { std::string msg = fmt::format("Get rpc stub failed, host={}, port=", addr->hostname, addr->port); @@ -94,7 +94,8 @@ Status IRuntimeFilter::join_rpc() { if (_rpc_context->cntl.Failed()) { LOG(WARNING) << "runtimefilter rpc err:" << _rpc_context->cntl.ErrorText(); // reset stub cache - ExecEnv::GetInstance()->brpc_stub_cache()->erase(_rpc_context->cntl.remote_side()); + ExecEnv::GetInstance()->brpc_internal_client_cache()->erase( + _rpc_context->cntl.remote_side()); } } return Status::OK(); diff --git a/be/src/gen_cpp/CMakeLists.txt b/be/src/gen_cpp/CMakeLists.txt index cc6d52b08d..22aa8c9cfe 100644 --- a/be/src/gen_cpp/CMakeLists.txt +++ b/be/src/gen_cpp/CMakeLists.txt @@ -84,8 +84,8 @@ set(SRC_FILES ${GEN_CPP_DIR}/data.pb.cc ${GEN_CPP_DIR}/descriptors.pb.cc ${GEN_CPP_DIR}/internal_service.pb.cc + ${GEN_CPP_DIR}/function_service.pb.cc ${GEN_CPP_DIR}/types.pb.cc - ${GEN_CPP_DIR}/status.pb.cc ${GEN_CPP_DIR}/segment_v2.pb.cc #$${GEN_CPP_DIR}/opcode/functions.cc #$${GEN_CPP_DIR}/opcode/vector-functions.cc diff --git a/be/src/http/action/check_rpc_channel_action.cpp 
b/be/src/http/action/check_rpc_channel_action.cpp index a26031f837..6a688e836d 100644 --- a/be/src/http/action/check_rpc_channel_action.cpp +++ b/be/src/http/action/check_rpc_channel_action.cpp @@ -24,7 +24,7 @@ #include "http/http_request.h" #include "runtime/exec_env.h" #include "service/brpc.h" -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include "util/md5.h" namespace doris { @@ -71,7 +71,7 @@ void CheckRPCChannelAction::handle(HttpRequest* req) { digest.digest(); request.set_md5(digest.hex()); std::shared_ptr stub( - _exec_env->brpc_stub_cache()->get_stub(req_ip, port)); + _exec_env->brpc_internal_client_cache()->get_client(req_ip, port)); if (!stub) { HttpChannel::send_reply( req, HttpStatus::INTERNAL_SERVER_ERROR, diff --git a/be/src/http/action/reset_rpc_channel_action.cpp b/be/src/http/action/reset_rpc_channel_action.cpp index 38e4a7e8d0..242bfe7a05 100644 --- a/be/src/http/action/reset_rpc_channel_action.cpp +++ b/be/src/http/action/reset_rpc_channel_action.cpp @@ -22,7 +22,7 @@ #include "http/http_channel.h" #include "http/http_request.h" #include "runtime/exec_env.h" -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include "util/string_util.h" namespace doris { @@ -30,11 +30,11 @@ ResetRPCChannelAction::ResetRPCChannelAction(ExecEnv* exec_env) : _exec_env(exec void ResetRPCChannelAction::handle(HttpRequest* req) { std::string endpoints = req->param("endpoints"); if (iequal(endpoints, "all")) { - int size = _exec_env->brpc_stub_cache()->size(); + int size = _exec_env->brpc_internal_client_cache()->size(); if (size > 0) { std::vector endpoints; - _exec_env->brpc_stub_cache()->get_all(&endpoints); - _exec_env->brpc_stub_cache()->clear(); + _exec_env->brpc_internal_client_cache()->get_all(&endpoints); + _exec_env->brpc_internal_client_cache()->clear(); HttpChannel::send_reply(req, HttpStatus::OK, fmt::format("reseted: {0}", join(endpoints, ","))); return; @@ -45,14 +45,14 @@ void 
ResetRPCChannelAction::handle(HttpRequest* req) { } else { std::vector reseted; for (const std::string& endpoint : split(endpoints, ",")) { - if (!_exec_env->brpc_stub_cache()->exist(endpoint)) { + if (!_exec_env->brpc_internal_client_cache()->exist(endpoint)) { std::string err = fmt::format("{0}: not found.", endpoint); LOG(WARNING) << err; HttpChannel::send_reply(req, HttpStatus::INTERNAL_SERVER_ERROR, err); return; } - if (_exec_env->brpc_stub_cache()->erase(endpoint)) { + if (_exec_env->brpc_internal_client_cache()->erase(endpoint)) { reseted.push_back(endpoint); } else { std::string err = fmt::format("{0}: reset failed.", endpoint); diff --git a/be/src/plugin/plugin_loader.cpp b/be/src/plugin/plugin_loader.cpp index 1e2876d65e..a0d0674022 100644 --- a/be/src/plugin/plugin_loader.cpp +++ b/be/src/plugin/plugin_loader.cpp @@ -58,13 +58,13 @@ Status DynamicPluginLoader::install() { // no, need download zip install PluginZip zip(_source); - RETURN_IF_ERROR(zip.extract(_install_path, _name)); + RETURN_NOT_OK_STATUS_WITH_WARN(zip.extract(_install_path, _name), "plugin install failed"); } // open plugin - RETURN_IF_ERROR(open_plugin()); + RETURN_NOT_OK_STATUS_WITH_WARN(open_plugin(), "plugin install failed"); - RETURN_IF_ERROR(open_valid()); + RETURN_NOT_OK_STATUS_WITH_WARN(open_valid(), "plugin install failed"); // plugin init // todo: what should be send? 
diff --git a/be/src/runtime/data_stream_sender.cpp b/be/src/runtime/data_stream_sender.cpp index 99a08b021f..681f5fc20b 100644 --- a/be/src/runtime/data_stream_sender.cpp +++ b/be/src/runtime/data_stream_sender.cpp @@ -42,7 +42,7 @@ #include "runtime/tuple_row.h" #include "service/backend_options.h" #include "service/brpc.h" -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include "util/debug_util.h" #include "util/defer_op.h" #include "util/network_util.h" @@ -112,7 +112,7 @@ Status DataStreamSender::Channel::init(RuntimeState* state) { // so the empty channel not need call function close_internal() _need_close = (_fragment_instance_id.hi != -1 && _fragment_instance_id.lo != -1); if (_need_close) { - _brpc_stub = state->exec_env()->brpc_stub_cache()->get_stub(_brpc_dest_addr); + _brpc_stub = state->exec_env()->brpc_internal_client_cache()->get_client(_brpc_dest_addr); if (!_brpc_stub) { std::string msg = fmt::format("Get rpc stub failed, dest_addr={}:{}", _brpc_dest_addr.hostname, _brpc_dest_addr.port); diff --git a/be/src/runtime/exec_env.h b/be/src/runtime/exec_env.h index 0b51e47533..8c8a9fbc2b 100644 --- a/be/src/runtime/exec_env.h +++ b/be/src/runtime/exec_env.h @@ -28,7 +28,10 @@ class VDataStreamMgr; } class BfdParser; class BrokerMgr; -class BrpcStubCache; + +template +class BrpcClientCache; + class BufferPool; class CgroupsMgr; class DataStreamMgr; @@ -61,8 +64,12 @@ class BackendServiceClient; class FrontendServiceClient; class TPaloBrokerServiceClient; class TExtDataSourceServiceClient; +class PBackendService_Stub; +class PFunctionService_Stub; + template class ClientCache; + class HeartbeatFlags; // Execution environment for queries/plan fragments. 
@@ -126,7 +133,12 @@ public: TmpFileMgr* tmp_file_mgr() { return _tmp_file_mgr; } BfdParser* bfd_parser() const { return _bfd_parser; } BrokerMgr* broker_mgr() const { return _broker_mgr; } - BrpcStubCache* brpc_stub_cache() const { return _brpc_stub_cache; } + BrpcClientCache* brpc_internal_client_cache() const { + return _internal_client_cache; + } + BrpcClientCache* brpc_function_client_cache() const { + return _function_client_cache; + } ReservationTracker* buffer_reservation() { return _buffer_reservation; } BufferPool* buffer_pool() { return _buffer_pool; } LoadChannelMgr* load_channel_mgr() { return _load_channel_mgr; } @@ -180,7 +192,7 @@ private: // Scanner threads for common queries will use this thread pool, // and the priority of each scan task is set according to the size of the query. - // _limited_scan_thread_pool is also the thread pool used for scanner. + // _limited_scan_thread_pool is also the thread pool used for scanner. // The difference is that it is no longer a priority queue, but according to the concurrency // set by the user to control the number of threads that can be used by a query. 
@@ -203,7 +215,8 @@ private: BrokerMgr* _broker_mgr = nullptr; LoadChannelMgr* _load_channel_mgr = nullptr; LoadStreamMgr* _load_stream_mgr = nullptr; - BrpcStubCache* _brpc_stub_cache = nullptr; + BrpcClientCache* _internal_client_cache = nullptr; + BrpcClientCache* _function_client_cache = nullptr; ReservationTracker* _buffer_reservation = nullptr; BufferPool* _buffer_pool = nullptr; diff --git a/be/src/runtime/exec_env_init.cpp b/be/src/runtime/exec_env_init.cpp index 35630d05a6..128f52e6ab 100644 --- a/be/src/runtime/exec_env_init.cpp +++ b/be/src/runtime/exec_env_init.cpp @@ -54,7 +54,7 @@ #include "runtime/thread_resource_mgr.h" #include "runtime/tmp_file_mgr.h" #include "util/bfd_parser.h" -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include "util/debug_util.h" #include "util/doris_metrics.h" #include "util/mem_info.h" @@ -125,7 +125,8 @@ Status ExecEnv::_init(const std::vector& store_paths) { _broker_mgr = new BrokerMgr(this); _load_channel_mgr = new LoadChannelMgr(); _load_stream_mgr = new LoadStreamMgr(); - _brpc_stub_cache = new BrpcStubCache(); + _internal_client_cache = new BrpcClientCache(); + _function_client_cache = new BrpcClientCache(); _stream_load_executor = new StreamLoadExecutor(this); _routine_load_task_executor = new RoutineLoadTaskExecutor(this); _small_file_mgr = new SmallFileMgr(this, config::small_file_dir); @@ -285,7 +286,8 @@ void ExecEnv::_destroy() { return; } _deregister_metrics(); - SAFE_DELETE(_brpc_stub_cache); + SAFE_DELETE(_internal_client_cache); + SAFE_DELETE(_function_client_cache); SAFE_DELETE(_load_stream_mgr); SAFE_DELETE(_load_channel_mgr); SAFE_DELETE(_broker_mgr); diff --git a/be/src/runtime/runtime_filter_mgr.cpp b/be/src/runtime/runtime_filter_mgr.cpp index e7b3a0c981..b5302aeace 100644 --- a/be/src/runtime/runtime_filter_mgr.cpp +++ b/be/src/runtime/runtime_filter_mgr.cpp @@ -28,7 +28,7 @@ #include "runtime/runtime_filter_mgr.h" #include "runtime/runtime_state.h" #include "service/brpc.h" 
-#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include "util/time.h" namespace doris { @@ -251,7 +251,7 @@ Status RuntimeFilterMergeControllerEntity::merge(const PMergeFilterRequest* requ request_fragment_id->set_lo(targets[i].target_fragment_instance_id.lo); std::shared_ptr stub( - ExecEnv::GetInstance()->brpc_stub_cache()->get_stub( + ExecEnv::GetInstance()->brpc_internal_client_cache()->get_client( targets[i].target_fragment_instance_addr)); VLOG_NOTICE << "send filter " << rpc_contexts[i]->request.filter_id() << " to:" << targets[i].target_fragment_instance_addr.hostname << ":" diff --git a/be/src/service/internal_service.cpp b/be/src/service/internal_service.cpp index a948db8ed4..7cf7b28d30 100644 --- a/be/src/service/internal_service.cpp +++ b/be/src/service/internal_service.cpp @@ -30,7 +30,7 @@ #include "runtime/routine_load/routine_load_task_executor.h" #include "runtime/runtime_state.h" #include "service/brpc.h" -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include "util/md5.h" #include "util/proto_util.h" #include "util/string_util.h" @@ -476,21 +476,21 @@ void PInternalServiceImpl::reset_rpc_channel(google::protobuf::RpcController* brpc::ClosureGuard closure_guard(done); response->mutable_status()->set_status_code(0); if (request->all()) { - int size = ExecEnv::GetInstance()->brpc_stub_cache()->size(); + int size = ExecEnv::GetInstance()->brpc_internal_client_cache()->size(); if (size > 0) { std::vector endpoints; - ExecEnv::GetInstance()->brpc_stub_cache()->get_all(&endpoints); - ExecEnv::GetInstance()->brpc_stub_cache()->clear(); + ExecEnv::GetInstance()->brpc_internal_client_cache()->get_all(&endpoints); + ExecEnv::GetInstance()->brpc_internal_client_cache()->clear(); *response->mutable_channels() = {endpoints.begin(), endpoints.end()}; } } else { for (const std::string& endpoint : request->endpoints()) { - if (!ExecEnv::GetInstance()->brpc_stub_cache()->exist(endpoint)) { + if 
(!ExecEnv::GetInstance()->brpc_internal_client_cache()->exist(endpoint)) { response->mutable_status()->add_error_msgs(endpoint + ": not found."); continue; } - if (ExecEnv::GetInstance()->brpc_stub_cache()->erase(endpoint)) { + if (ExecEnv::GetInstance()->brpc_internal_client_cache()->erase(endpoint)) { response->add_channels(endpoint); } else { response->mutable_status()->add_error_msgs(endpoint + ": reset failed."); diff --git a/be/src/udf/udf.cpp b/be/src/udf/udf.cpp index ce88fd67e1..b9ec504585 100644 --- a/be/src/udf/udf.cpp +++ b/be/src/udf/udf.cpp @@ -23,6 +23,7 @@ #include #include "common/logging.h" +#include "gen_cpp/types.pb.h" #include "olap/hll.h" #include "runtime/decimalv2_value.h" @@ -196,6 +197,20 @@ FunctionContext* FunctionContextImpl::clone(MemPool* pool) { return new_context; } +// TODO: to be implemented +void FunctionContextImpl::serialize(PFunctionContext* pcontext) const { + // pcontext->set_string_result(_string_result); + // pcontext->set_num_updates(_num_updates); + // pcontext->set_num_removes(_num_removes); + // pcontext->set_num_warnings(_num_warnings); + // pcontext->set_error_msg(_error_msg); + // PUniqueId* query_id = pcontext->mutable_query_id(); + // query_id->set_hi(_context->query_id().hi); + // query_id->set_lo(_context->query_id().lo); +} + +void FunctionContextImpl::derialize(const PFunctionContext& pcontext) {} + } // namespace doris namespace doris_udf { diff --git a/be/src/udf/udf_internal.h b/be/src/udf/udf_internal.h index 085002d8ec..36cf8ad474 100644 --- a/be/src/udf/udf_internal.h +++ b/be/src/udf/udf_internal.h @@ -33,6 +33,7 @@ class FreePool; class MemPool; class RuntimeState; class ColumnPtrWrapper; +class PFunctionContext; // This class actually implements the interface of FunctionContext. This is split to // hide the details from the external header. 
@@ -107,6 +108,9 @@ public: const doris_udf::FunctionContext::TypeDesc& get_return_type() const { return _return_type; } + void serialize(PFunctionContext* pcontext) const; + void derialize(const PFunctionContext& pcontext); + private: friend class doris_udf::FunctionContext; friend class ExprContext; diff --git a/be/src/util/CMakeLists.txt b/be/src/util/CMakeLists.txt index a5f244864a..0582c57b80 100644 --- a/be/src/util/CMakeLists.txt +++ b/be/src/util/CMakeLists.txt @@ -100,7 +100,7 @@ set(UTIL_FILES timezone_utils.cpp easy_json.cc mustache/mustache.cc - brpc_stub_cache.cpp + brpc_client_cache.cpp zlib.cpp pprof_utils.cpp s3_uri.cpp diff --git a/be/src/util/brpc_stub_cache.cpp b/be/src/util/brpc_client_cache.cpp similarity index 64% rename from be/src/util/brpc_stub_cache.cpp rename to be/src/util/brpc_client_cache.cpp index b62f34add2..df89585e73 100644 --- a/be/src/util/brpc_stub_cache.cpp +++ b/be/src/util/brpc_client_cache.cpp @@ -15,17 +15,31 @@ // specific language governing permissions and limitations // under the License. 
-#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" namespace doris { DEFINE_GAUGE_METRIC_PROTOTYPE_2ARG(brpc_endpoint_stub_count, MetricUnit::NOUNIT); -BrpcStubCache::BrpcStubCache() { +DEFINE_GAUGE_METRIC_PROTOTYPE_2ARG(brpc_function_endpoint_stub_count, MetricUnit::NOUNIT); + +template <> +BrpcClientCache::BrpcClientCache() { REGISTER_HOOK_METRIC(brpc_endpoint_stub_count, [this]() { return _stub_map.size(); }); } -BrpcStubCache::~BrpcStubCache() { +template <> +BrpcClientCache::~BrpcClientCache() { DEREGISTER_HOOK_METRIC(brpc_endpoint_stub_count); } + +template <> +BrpcClientCache::BrpcClientCache() { + REGISTER_HOOK_METRIC(brpc_function_endpoint_stub_count, [this]() { return _stub_map.size(); }); +} + +template <> +BrpcClientCache::~BrpcClientCache() { + DEREGISTER_HOOK_METRIC(brpc_function_endpoint_stub_count); +} } // namespace doris diff --git a/be/src/util/brpc_client_cache.h b/be/src/util/brpc_client_cache.h new file mode 100644 index 0000000000..f310cd18f6 --- /dev/null +++ b/be/src/util/brpc_client_cache.h @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include + +#include +#include + +#include "common/config.h" +#include "gen_cpp/Types_types.h" // TNetworkAddress +#include "gen_cpp/function_service.pb.h" +#include "gen_cpp/internal_service.pb.h" +#include "service/brpc.h" +#include "util/doris_metrics.h" + +template +using SubMap = phmap::parallel_flat_hash_map< + std::string, std::shared_ptr, std::hash, std::equal_to, + std::allocator>>, 8, std::mutex>; +namespace doris { + +template +class BrpcClientCache { +public: + BrpcClientCache(); + virtual ~BrpcClientCache(); + + inline std::shared_ptr get_client(const butil::EndPoint& endpoint) { + return get_client(butil::endpoint2str(endpoint).c_str()); + } + +#ifdef BE_TEST + virtual inline std::shared_ptr get_client(const TNetworkAddress& taddr) { + std::string host_port = fmt::format("{}:{}", taddr.hostname, taddr.port); + return get_client(host_port); + } +#else + inline std::shared_ptr get_client(const TNetworkAddress& taddr) { + std::string host_port = fmt::format("{}:{}", taddr.hostname, taddr.port); + return get_client(host_port); + } +#endif + + inline std::shared_ptr get_client(const std::string& host, int port) { + std::string host_port = fmt::format("{}:{}", host, port); + return get_client(host_port); + } + + inline std::shared_ptr get_client(const std::string& host_port) { + auto stub_ptr = _stub_map.find(host_port); + if (LIKELY(stub_ptr != _stub_map.end())) { + return stub_ptr->second; + } + // new one stub and insert into map + brpc::ChannelOptions options; + if constexpr (std::is_same_v) { + options.protocol = config::function_service_protocol; + } + std::unique_ptr channel(new brpc::Channel()); + int ret_code = 0; + if (host_port.find("://") == std::string::npos) { + ret_code = channel->Init(host_port.c_str(), &options); + } else { + ret_code = + channel->Init(host_port.c_str(), config::rpc_load_balancer.c_str(), &options); + } + if (ret_code) { + return nullptr; + } + auto stub = std::make_shared(channel.release(), + 
google::protobuf::Service::STUB_OWNS_CHANNEL); + _stub_map[host_port] = stub; + return stub; + } + + inline size_t size() { return _stub_map.size(); } + + inline void clear() { _stub_map.clear(); } + + inline size_t erase(const std::string& host_port) { return _stub_map.erase(host_port); } + + size_t erase(const std::string& host, int port) { + std::string host_port = fmt::format("{}:{}", host, port); + return erase(host_port); + } + + inline size_t erase(const butil::EndPoint& endpoint) { + return _stub_map.erase(butil::endpoint2str(endpoint).c_str()); + } + + inline bool exist(const std::string& host_port) { + return _stub_map.find(host_port) != _stub_map.end(); + } + + inline void get_all(std::vector* endpoints) { + for (auto it = _stub_map.begin(); it != _stub_map.end(); ++it) { + endpoints->emplace_back(it->first.c_str()); + } + } + + inline bool available(std::shared_ptr stub, const butil::EndPoint& endpoint) { + return available(stub, butil::endpoint2str(endpoint).c_str()); + } + + inline bool available(std::shared_ptr stub, const std::string& host_port) { + if (!stub) { + LOG(WARNING) << "stub is null to: " << host_port; + return false; + } + PHandShakeRequest request; + PHandShakeResponse response; + brpc::Controller cntl; + stub->hand_shake(&cntl, &request, &response, nullptr); + if (!cntl.Failed()) { + return true; + } else { + LOG(WARNING) << "open brpc connection to " << host_port + << " failed: " << cntl.ErrorText(); + return false; + } + } + + inline bool available(std::shared_ptr stub, const std::string& host, int port) { + std::string host_port = fmt::format("{}:{}", host, port); + return available(stub, host_port); + } + +private: + SubMap _stub_map; +}; + +using InternalServiceClientCache = BrpcClientCache; +using FunctionServiceClientCache = BrpcClientCache; +} // namespace doris diff --git a/be/src/util/brpc_stub_cache.h b/be/src/util/brpc_stub_cache.h deleted file mode 100644 index 21800f3588..0000000000 --- a/be/src/util/brpc_stub_cache.h +++ 
/dev/null @@ -1,159 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include - -#include -#include - -#include "common/config.h" -#include "gen_cpp/Types_types.h" // TNetworkAddress -#include "gen_cpp/internal_service.pb.h" -#include "service/brpc.h" -#include "util/doris_metrics.h" - -namespace std { -template <> -struct hash { - std::size_t operator()(butil::EndPoint const& p) const { - return phmap::HashState().combine(0, butil::ip2int(p.ip), p.port); - } -}; -} // namespace std -using SubMap = phmap::parallel_flat_hash_map< - butil::EndPoint, std::shared_ptr, std::hash, - std::equal_to, - std::allocator< - std::pair>>, - 8, std::mutex>; -namespace doris { - -class BrpcStubCache { -public: - BrpcStubCache(); - virtual ~BrpcStubCache(); - - inline std::shared_ptr get_stub(const butil::EndPoint& endpoint) { - auto stub_ptr = _stub_map.find(endpoint); - if (LIKELY(stub_ptr != _stub_map.end())) { - return stub_ptr->second; - } - // new one stub and insert into map - brpc::ChannelOptions options; - std::unique_ptr channel(new brpc::Channel()); - if (channel->Init(endpoint, &options)) { - return nullptr; - } - auto stub = std::make_shared( - channel.release(), 
google::protobuf::Service::STUB_OWNS_CHANNEL); - _stub_map[endpoint] = stub; - return stub; - } - - virtual std::shared_ptr get_stub(const TNetworkAddress& taddr) { - butil::EndPoint endpoint; - if (str2endpoint(taddr.hostname.c_str(), taddr.port, &endpoint)) { - LOG(WARNING) << "unknown endpoint, hostname=" << taddr.hostname - << ", port=" << taddr.port; - return nullptr; - } - return get_stub(endpoint); - } - - inline std::shared_ptr get_stub(const std::string& host, int port) { - butil::EndPoint endpoint; - if (str2endpoint(host.c_str(), port, &endpoint)) { - LOG(WARNING) << "unknown endpoint, hostname=" << host << ", port=" << port; - return nullptr; - } - return get_stub(endpoint); - } - - inline size_t size() { return _stub_map.size(); } - - inline void clear() { _stub_map.clear(); } - - inline size_t erase(const std::string& host_port) { - butil::EndPoint endpoint; - if (str2endpoint(host_port.c_str(), &endpoint)) { - LOG(WARNING) << "unknown endpoint: " << host_port; - return 0; - } - return erase(endpoint); - } - - size_t erase(const std::string& host, int port) { - butil::EndPoint endpoint; - if (str2endpoint(host.c_str(), port, &endpoint)) { - LOG(WARNING) << "unknown endpoint, hostname=" << host << ", port=" << port; - return 0; - } - return erase(endpoint); - } - - inline size_t erase(const butil::EndPoint& endpoint) { return _stub_map.erase(endpoint); } - - inline bool exist(const std::string& host_port) { - butil::EndPoint endpoint; - if (str2endpoint(host_port.c_str(), &endpoint)) { - LOG(WARNING) << "unknown endpoint: " << host_port; - return false; - } - return _stub_map.find(endpoint) != _stub_map.end(); - } - - inline void get_all(std::vector* endpoints) { - for (SubMap::const_iterator it = _stub_map.begin(); it != _stub_map.end(); ++it) { - endpoints->emplace_back(endpoint2str(it->first).c_str()); - } - } - - inline bool available(std::shared_ptr stub, - const butil::EndPoint& endpoint) { - if (!stub) { - return false; - } - PHandShakeRequest 
request; - PHandShakeResponse response; - brpc::Controller cntl; - stub->hand_shake(&cntl, &request, &response, nullptr); - if (!cntl.Failed()) { - return true; - } else { - LOG(WARNING) << "open brpc connection to " << endpoint2str(endpoint).c_str() - << " failed: " << cntl.ErrorText(); - return false; - } - } - - inline bool available(std::shared_ptr stub, const std::string& host, - int port) { - butil::EndPoint endpoint; - if (str2endpoint(host.c_str(), port, &endpoint)) { - LOG(WARNING) << "unknown endpoint, hostname=" << host; - return false; - } - return available(stub, endpoint); - } - -private: - SubMap _stub_map; -}; - -} // namespace doris diff --git a/be/src/util/doris_metrics.h b/be/src/util/doris_metrics.h index 67d60a3e67..8015dcaefa 100644 --- a/be/src/util/doris_metrics.h +++ b/be/src/util/doris_metrics.h @@ -28,18 +28,19 @@ namespace doris { -#define REGISTER_ENTITY_HOOK_METRIC(entity, owner, metric, func) \ - owner->metric = (UIntGauge*)(entity->register_metric(&METRIC_##metric)); \ +#define REGISTER_ENTITY_HOOK_METRIC(entity, owner, metric, func) \ + owner->metric = (UIntGauge*)(entity->register_metric(&METRIC_##metric)); \ entity->register_hook(#metric, [&]() { owner->metric->set_value(func()); }); -#define REGISTER_HOOK_METRIC(metric, func) \ - REGISTER_ENTITY_HOOK_METRIC(DorisMetrics::instance()->server_entity(), DorisMetrics::instance(), metric, func) +#define REGISTER_HOOK_METRIC(metric, func) \ + REGISTER_ENTITY_HOOK_METRIC(DorisMetrics::instance()->server_entity(), \ + DorisMetrics::instance(), metric, func) -#define DEREGISTER_ENTITY_HOOK_METRIC(entity, name) \ - entity->deregister_metric(&METRIC_##name); \ +#define DEREGISTER_ENTITY_HOOK_METRIC(entity, name) \ + entity->deregister_metric(&METRIC_##name); \ entity->deregister_hook(#name); -#define DEREGISTER_HOOK_METRIC(name) \ +#define DEREGISTER_HOOK_METRIC(name) \ DEREGISTER_ENTITY_HOOK_METRIC(DorisMetrics::instance()->server_entity(), name) class DorisMetrics { @@ -177,6 +178,7 @@ 
public: UIntGauge* small_file_cache_count; UIntGauge* stream_load_pipe_count; UIntGauge* brpc_endpoint_stub_count; + UIntGauge* brpc_function_endpoint_stub_count; UIntGauge* tablet_writer_count; UIntGauge* compaction_mem_consumption; diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt index 09e86e2f7d..6201c67a44 100644 --- a/be/src/vec/CMakeLists.txt +++ b/be/src/vec/CMakeLists.txt @@ -141,6 +141,7 @@ set(VEC_FILES functions/function_date_or_datetime_to_string.cpp functions/function_datetime_string_to_string.cpp functions/function_grouping.cpp + functions/function_rpc.cpp olap/vgeneric_iterators.cpp olap/vcollect_iterator.cpp olap/block_reader.cpp diff --git a/be/src/vec/columns/column_decimal.h b/be/src/vec/columns/column_decimal.h index 017d891e4f..b4b4a68b09 100644 --- a/be/src/vec/columns/column_decimal.h +++ b/be/src/vec/columns/column_decimal.h @@ -25,8 +25,8 @@ #include "vec/columns/column.h" #include "vec/columns/column_impl.h" #include "vec/columns/column_vector_helper.h" -#include "vec/common/typeid_cast.h" #include "vec/common/assert_cast.h" +#include "vec/common/typeid_cast.h" #include "vec/core/field.h" namespace doris::vectorized { @@ -97,7 +97,8 @@ public: data.push_back(static_cast(src).get_data()[n]); } - void insert_indices_from(const IColumn& src, const int* indices_begin, const int* indices_end) override { + void insert_indices_from(const IColumn& src, const int* indices_begin, + const int* indices_end) override { const Self& src_vec = assert_cast(src); data.reserve(size() + (indices_end - indices_begin)); for (auto x = indices_begin; x != indices_end; ++x) { @@ -226,4 +227,8 @@ ColumnPtr ColumnDecimal::index_impl(const PaddedPODArray& indexes, size return res; } +using ColumnDecimal32 = ColumnDecimal; +using ColumnDecimal64 = ColumnDecimal; +using ColumnDecimal128 = ColumnDecimal; + } // namespace doris::vectorized diff --git a/be/src/vec/exprs/vectorized_fn_call.cpp b/be/src/vec/exprs/vectorized_fn_call.cpp index 
deecc16103..6f01a126d6 100644 --- a/be/src/vec/exprs/vectorized_fn_call.cpp +++ b/be/src/vec/exprs/vectorized_fn_call.cpp @@ -25,6 +25,7 @@ #include "udf/udf_internal.h" #include "vec/data_types/data_type_nullable.h" #include "vec/data_types/data_type_number.h" +#include "vec/functions/function_rpc.h" #include "vec/functions/simple_function_factory.h" namespace doris::vectorized { @@ -42,8 +43,13 @@ doris::Status VectorizedFnCall::prepare(doris::RuntimeState* state, argument_template.emplace_back(std::move(column), child->data_type(), child->expr_name()); child_expr_name.emplace_back(child->expr_name()); } - _function = SimpleFunctionFactory::instance().get_function(_fn.name.function_name, - argument_template, _data_type); + if (_fn.binary_type == TFunctionBinaryType::RPC) { + _function = RPCFnCall::create(_fn.name.function_name, _fn.hdfs_location, argument_template, + _data_type); + } else { + _function = SimpleFunctionFactory::instance().get_function(_fn.name.function_name, + argument_template, _data_type); + } if (_function == nullptr) { return Status::InternalError( fmt::format("Function {} is not implemented", _fn.name.function_name)); diff --git a/be/src/vec/functions/function_rpc.cpp b/be/src/vec/functions/function_rpc.cpp new file mode 100644 index 0000000000..43d5a694e1 --- /dev/null +++ b/be/src/vec/functions/function_rpc.cpp @@ -0,0 +1,527 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "vec/functions/function_rpc.h" + +#include + +#include + +#include "gen_cpp/function_service.pb.h" +#include "runtime/exec_env.h" +#include "runtime/user_function_cache.h" +#include "service/brpc.h" +#include "util/brpc_client_cache.h" +#include "vec/columns/column_vector.h" +#include "vec/core/block.h" +#include "vec/data_types/data_type_bitmap.h" +#include "vec/data_types/data_type_date.h" +#include "vec/data_types/data_type_date_time.h" +#include "vec/data_types/data_type_decimal.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/data_types/data_type_number.h" +#include "vec/data_types/data_type_string.h" + +namespace doris::vectorized { +RPCFnCall::RPCFnCall(const std::string& symbol, const std::string& server, + const DataTypes& argument_types, const DataTypePtr& return_type) + : _symbol(symbol), + _server(server), + _name(fmt::format("{}/{}", server, symbol)), + _argument_types(argument_types), + _return_type(return_type) {} +Status RPCFnCall::prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope) { + _client = ExecEnv::GetInstance()->brpc_function_client_cache()->get_client(_server); + + if (_client == nullptr) { + return Status::InternalError("rpc env init error"); + } + return Status::OK(); +} + +template +void convert_col_to_pvalue(const ColumnPtr& column, const DataTypePtr& data_type, PValues* arg, + size_t row_count) { + PGenericType* ptype = arg->mutable_type(); + switch (data_type->get_type_id()) { + case TypeIndex::UInt8: { + ptype->set_id(PGenericType::UINT8); + auto* values = 
arg->mutable_bool_value(); + values->Reserve(row_count); + const auto* col = check_and_get_column(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case TypeIndex::UInt16: { + ptype->set_id(PGenericType::UINT16); + auto* values = arg->mutable_uint32_value(); + values->Reserve(row_count); + const auto* col = check_and_get_column(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case TypeIndex::UInt32: { + ptype->set_id(PGenericType::UINT32); + auto* values = arg->mutable_uint32_value(); + values->Reserve(row_count); + const auto* col = check_and_get_column(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case TypeIndex::UInt64: { + ptype->set_id(PGenericType::UINT64); + auto* values = arg->mutable_uint64_value(); + values->Reserve(row_count); + const auto* col = check_and_get_column(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case TypeIndex::UInt128: { + ptype->set_id(PGenericType::UINT128); + arg->mutable_bytes_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + if constexpr (nullable) { + if (column->is_null_at(row_num)) { + arg->add_bytes_value(nullptr); + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } + break; + } + case TypeIndex::Int8: { + ptype->set_id(PGenericType::INT8); + auto* values = arg->mutable_int32_value(); + values->Reserve(row_count); + const auto* col = check_and_get_column(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case TypeIndex::Int16: { + ptype->set_id(PGenericType::INT16); + auto* values = arg->mutable_int32_value(); 
+ values->Reserve(row_count); + const auto* col = check_and_get_column(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case TypeIndex::Int32: { + ptype->set_id(PGenericType::INT32); + auto* values = arg->mutable_int32_value(); + values->Reserve(row_count); + const auto* col = check_and_get_column(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case TypeIndex::Int64: { + ptype->set_id(PGenericType::INT64); + auto* values = arg->mutable_int64_value(); + values->Reserve(row_count); + const auto* col = check_and_get_column(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case TypeIndex::Int128: { + ptype->set_id(PGenericType::INT128); + arg->mutable_bytes_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + if constexpr (nullable) { + if (column->is_null_at(row_num)) { + arg->add_bytes_value(nullptr); + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } + break; + } + case TypeIndex::Float32: { + ptype->set_id(PGenericType::FLOAT); + auto* values = arg->mutable_float_value(); + values->Reserve(row_count); + const auto* col = check_and_get_column(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + + case TypeIndex::Float64: { + ptype->set_id(PGenericType::DOUBLE); + auto* values = arg->mutable_double_value(); + values->Reserve(row_count); + const auto* col = check_and_get_column(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case TypeIndex::Decimal128: { + ptype->set_id(PGenericType::DECIMAL128); + auto dec_type = std::reinterpret_pointer_cast>(data_type); + 
ptype->mutable_decimal_type()->set_precision(dec_type->get_precision()); + ptype->mutable_decimal_type()->set_scale(dec_type->get_scale()); + arg->mutable_bytes_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + if constexpr (nullable) { + if (column->is_null_at(row_num)) { + arg->add_bytes_value(nullptr); + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } + break; + } + case TypeIndex::String: { + ptype->set_id(PGenericType::STRING); + arg->mutable_bytes_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + if constexpr (nullable) { + if (column->is_null_at(row_num)) { + arg->add_string_value(nullptr); + } else { + StringRef data = column->get_data_at(row_num); + arg->add_string_value(data.to_string()); + } + } else { + StringRef data = column->get_data_at(row_num); + arg->add_string_value(data.to_string()); + } + } + break; + } + case TypeIndex::Date: { + ptype->set_id(PGenericType::DATE); + arg->mutable_datetime_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + PDateTime* date_time = arg->add_datetime_value(); + if constexpr (nullable) { + if (!column->is_null_at(row_num)) { + VecDateTimeValue v = VecDateTimeValue(column->get_int(row_num)); + date_time->set_day(v.day()); + date_time->set_month(v.month()); + date_time->set_year(v.year()); + } + } else { + VecDateTimeValue v = VecDateTimeValue(column->get_int(row_num)); + date_time->set_day(v.day()); + date_time->set_month(v.month()); + date_time->set_year(v.year()); + } + } + break; + } + case TypeIndex::DateTime: { + ptype->set_id(PGenericType::DATETIME); + arg->mutable_datetime_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + PDateTime* date_time = arg->add_datetime_value(); + if 
constexpr (nullable) { + if (!column->is_null_at(row_num)) { + VecDateTimeValue v = VecDateTimeValue(column->get_int(row_num)); + date_time->set_day(v.day()); + date_time->set_month(v.month()); + date_time->set_year(v.year()); + date_time->set_hour(v.hour()); + date_time->set_minute(v.minute()); + date_time->set_second(v.second()); + } + } else { + VecDateTimeValue v = VecDateTimeValue(column->get_int(row_num)); + date_time->set_day(v.day()); + date_time->set_month(v.month()); + date_time->set_year(v.year()); + date_time->set_hour(v.hour()); + date_time->set_minute(v.minute()); + date_time->set_second(v.second()); + } + } + break; + } + case TypeIndex::BitMap: { + ptype->set_id(PGenericType::BITMAP); + arg->mutable_bytes_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + if constexpr (nullable) { + if (column->is_null_at(row_num)) { + arg->add_bytes_value(nullptr); + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } + break; + } + default: + LOG(INFO) << "unknown type: " << data_type->get_name(); + ptype->set_id(PGenericType::UNKNOWN); + break; + } +} + +void convert_nullable_col_to_pvalue(const ColumnPtr& column, const DataTypePtr& data_type, + const ColumnUInt8& null_col, PValues* arg, size_t row_count) { + if (column->has_null(row_count)) { + auto* null_map = arg->mutable_null_map(); + null_map->Reserve(row_count); + const auto* col = check_and_get_column(null_col); + auto& data = col->get_data(); + null_map->Add(data.begin(), data.begin() + row_count); + convert_col_to_pvalue(column, data_type, arg, row_count); + } else { + convert_col_to_pvalue(column, data_type, arg, row_count); + } +} + +void convert_block_to_proto(Block& block, const ColumnNumbers& arguments, size_t input_rows_count, + PFunctionCallRequest* request) { + size_t row_count = 
std::min(block.rows(), input_rows_count); + for (size_t col_idx : arguments) { + PValues* arg = request->add_args(); + ColumnWithTypeAndName& column = block.get_by_position(col_idx); + arg->set_has_null(column.column->has_null(row_count)); + auto col = column.column->convert_to_full_column_if_const(); + if (auto* nullable = check_and_get_column(*col)) { + auto data_col = nullable->get_nested_column_ptr(); + auto& null_col = nullable->get_null_map_column(); + auto data_type = std::reinterpret_pointer_cast(column.type); + convert_nullable_col_to_pvalue(data_col->convert_to_full_column_if_const(), + data_type->get_nested_type(), null_col, arg, row_count); + } else { + convert_col_to_pvalue(col, column.type, arg, row_count); + } + } +} + +template +void convert_to_column(MutableColumnPtr& column, const PValues& result) { + switch (result.type().id()) { + case PGenericType::UINT8: { + column->reserve(result.uint32_value_size()); + column->resize(result.uint32_value_size()); + auto& data = reinterpret_cast(column.get())->get_data(); + for (int i = 0; i < result.uint32_value_size(); ++i) { + data[i] = result.uint32_value(i); + } + break; + } + case PGenericType::UINT16: { + column->reserve(result.uint32_value_size()); + column->resize(result.uint32_value_size()); + auto& data = reinterpret_cast(column.get())->get_data(); + for (int i = 0; i < result.uint32_value_size(); ++i) { + data[i] = result.uint32_value(i); + } + break; + } + case PGenericType::UINT32: { + column->reserve(result.uint32_value_size()); + column->resize(result.uint32_value_size()); + auto& data = reinterpret_cast(column.get())->get_data(); + for (int i = 0; i < result.uint32_value_size(); ++i) { + data[i] = result.uint32_value(i); + } + break; + } + case PGenericType::UINT64: { + column->reserve(result.uint64_value_size()); + column->resize(result.uint64_value_size()); + auto& data = reinterpret_cast(column.get())->get_data(); + for (int i = 0; i < result.uint64_value_size(); ++i) { + data[i] = 
result.uint64_value(i);
+        }
+        break;
+    }
+    case PGenericType::INT8: {
+        column->reserve(result.int32_value_size());
+        column->resize(result.int32_value_size());
+        auto& data = reinterpret_cast(column.get())->get_data();
+        for (int i = 0; i < result.int32_value_size(); ++i) {
+            data[i] = result.int32_value(i);
+        }
+        break;
+    }
+    case PGenericType::INT16: {
+        column->reserve(result.int32_value_size());
+        column->resize(result.int32_value_size());
+        auto& data = reinterpret_cast(column.get())->get_data();
+        for (int i = 0; i < result.int32_value_size(); ++i) {
+            data[i] = result.int32_value(i);
+        }
+        break;
+    }
+    case PGenericType::INT32: {
+        column->reserve(result.int32_value_size());
+        column->resize(result.int32_value_size());
+        auto& data = reinterpret_cast(column.get())->get_data();
+        for (int i = 0; i < result.int32_value_size(); ++i) {
+            data[i] = result.int32_value(i);
+        }
+        break;
+    }
+    case PGenericType::INT64: {
+        column->reserve(result.int64_value_size());
+        column->resize(result.int64_value_size());
+        auto& data = reinterpret_cast(column.get())->get_data();
+        for (int i = 0; i < result.int64_value_size(); ++i) {
+            data[i] = result.int64_value(i);
+        }
+        break;
+    }
+    case PGenericType::DATE:
+    case PGenericType::DATETIME: { // both are decoded from datetime_value
+        column->reserve(result.datetime_value_size());
+        column->resize(result.datetime_value_size());
+        auto& data = reinterpret_cast(column.get())->get_data();
+        for (int i = 0; i < result.datetime_value_size(); ++i) {
+            VecDateTimeValue v; // rebuilt field by field from the PDateTime
+            PDateTime pv = result.datetime_value(i);
+            v.set_time(pv.year(), pv.month(), pv.day(), pv.hour(), pv.minute(), pv.second());
+            data[i] = binary_cast(v);
+        }
+        break;
+    }
+    case PGenericType::FLOAT: {
+        column->reserve(result.float_value_size());
+        column->resize(result.float_value_size());
+        auto& data = reinterpret_cast(column.get())->get_data();
+        for (int i = 0; i < result.float_value_size(); ++i) {
+            data[i] = result.float_value(i);
+        }
+        break;
+    }
+    case PGenericType::DOUBLE: {
+
column->reserve(result.double_value_size()); + column->resize(result.double_value_size()); + auto& data = reinterpret_cast(column.get())->get_data(); + for (int i = 0; i < result.double_value_size(); ++i) { + data[i] = result.double_value(i); + } + break; + } + case PGenericType::INT128: { + column->reserve(result.bytes_value_size()); + column->resize(result.bytes_value_size()); + auto& data = reinterpret_cast(column.get())->get_data(); + for (int i = 0; i < result.bytes_value_size(); ++i) { + data[i] = *(int128_t*)(result.bytes_value(i).c_str()); + } + break; + } + case PGenericType::STRING: { + column->reserve(result.string_value_size()); + for (int i = 0; i < result.string_value_size(); ++i) { + column->insert_data(result.string_value(i).c_str(), result.string_value(i).size()); + } + break; + } + case PGenericType::DECIMAL128: { + column->reserve(result.bytes_value_size()); + column->resize(result.bytes_value_size()); + auto& data = reinterpret_cast(column.get())->get_data(); + for (int i = 0; i < result.bytes_value_size(); ++i) { + data[i] = *(int128_t*)(result.bytes_value(i).c_str()); + } + break; + } + case PGenericType::BITMAP: { + column->reserve(result.bytes_value_size()); + for (int i = 0; i < result.bytes_value_size(); ++i) { + column->insert_data(result.bytes_value(i).c_str(), result.bytes_value(i).size()); + } + break; + } + default: { + LOG(WARNING) << "unknown PGenericType: " << result.type().DebugString(); + break; + } + } +} + +void convert_to_block(Block& block, const PValues& result, size_t pos) { + auto data_type = block.get_data_type(pos); + if (data_type->is_nullable()) { + auto null_type = std::reinterpret_pointer_cast(data_type); + auto data_col = null_type->get_nested_type()->create_column(); + convert_to_column(data_col, result); + auto null_col = ColumnUInt8::create(data_col->size(), 0); + auto& null_map_data = null_col->get_data(); + null_col->reserve(data_col->size()); + null_col->resize(data_col->size()); + if (result.has_null()) { + 
for (int i = 0; i < data_col->size(); ++i) {
+                null_map_data[i] = i < result.null_map_size() && result.null_map(i); // bounds-checked
+            }
+        } else {
+            for (int i = 0; i < data_col->size(); ++i) {
+                null_map_data[i] = false;
+            }
+        }
+        block.replace_by_position(
+                pos, std::move(ColumnNullable::create(std::move(data_col), std::move(null_col))));
+    } else {
+        auto column = data_type->create_column();
+        convert_to_column(column, result);
+        block.replace_by_position(pos, std::move(column));
+    }
+}
+
+Status RPCFnCall::execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
+                          size_t result, size_t input_rows_count, bool dry_run) {
+    PFunctionCallRequest request;
+    PFunctionCallResponse response;
+    request.set_function_name(_symbol);
+    convert_block_to_proto(block, arguments, input_rows_count, &request);
+    brpc::Controller cntl;
+    _client->fn_call(&cntl, &request, &response, nullptr);
+    if (cntl.Failed()) {
+        return Status::InternalError(
+                fmt::format("call to rpc function {} failed: {}", _symbol, cntl.ErrorText())
+                        .c_str());
+    }
+    if (response.status().status_code() != 0) {
+        return Status::InternalError(fmt::format("call to rpc function {} failed: {}", _symbol,
+                                                 response.status().DebugString()));
+    }
+    convert_to_block(block, response.result(), result);
+    return Status::OK();
+}
+} // namespace doris::vectorized
diff --git a/be/src/vec/functions/function_rpc.h b/be/src/vec/functions/function_rpc.h
new file mode 100644
index 0000000000..2c7535adfc
--- /dev/null
+++ b/be/src/vec/functions/function_rpc.h
@@ -0,0 +1,81 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "vec/functions/function.h"
+
+namespace doris {
+class PFunctionService_Stub;
+
+namespace vectorized {
+// Function whose evaluation is delegated to a function service over RPC:
+// execute() serializes the argument columns into a request, calls fn_call on
+// the service stub and writes the returned values into the result column.
+class RPCFnCall : public IFunctionBase {
+public:
+    // `symbol` is the function name sent in each request, `server` identifies
+    // the service; get_name() reports "<server>/<symbol>".
+    RPCFnCall(const std::string& symbol, const std::string& server, const DataTypes& argument_types,
+              const DataTypePtr& return_type);
+
+    // Builds an RPCFnCall from argument descriptors; only their types are kept.
+    static FunctionBasePtr create(const std::string& symbol, const std::string& server,
+                                  const ColumnsWithTypeAndName& argument_types,
+                                  const DataTypePtr& return_type) {
+        DataTypes data_types(argument_types.size());
+        for (size_t i = 0; i < argument_types.size(); ++i) {
+            data_types[i] = argument_types[i].type;
+        }
+        return std::make_shared<RPCFnCall>(symbol, server, data_types, return_type);
+    }
+
+    /// Get the main function name.
+    String get_name() const override { return _name; }
+
+    const DataTypes& get_argument_types() const override { return _argument_types; }
+    const DataTypePtr& get_return_type() const override { return _return_type; }
+
+    // No prepared function is produced; this overload always returns nullptr.
+    PreparedFunctionPtr prepare(FunctionContext* context, const Block& sample_block,
+                                const ColumnNumbers& arguments, size_t result) const override {
+        return nullptr;
+    }
+
+    // Obtains the service stub for _server; returns an error if none is available.
+    Status prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope) override;
+
+    // Performs the remote call. Returns InternalError when the RPC itself fails
+    // or when the service reports a non-zero status code.
+    Status execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
+                   size_t result, size_t input_rows_count, bool dry_run = false) override;
+
+    bool is_deterministic() const override { return false; }
+
+    bool is_deterministic_in_scope_of_query() const override { return false; }
+
+private:
+    std::string _symbol;
+    std::string _server;
+    std::string _name;
+    DataTypes _argument_types;
+    DataTypePtr _return_type;
+    std::shared_ptr<PFunctionService_Stub> _client = nullptr;
+};
+
+} // namespace vectorized
+} // namespace doris
diff --git a/be/src/vec/sink/vdata_stream_sender.cpp b/be/src/vec/sink/vdata_stream_sender.cpp
index e3e081df0a..295891ec84 100644
--- a/be/src/vec/sink/vdata_stream_sender.cpp
+++ b/be/src/vec/sink/vdata_stream_sender.cpp
@@ -54,13 +54,13 @@ Status VDataStreamSender::Channel::init(RuntimeState* state) {
     _brpc_request.set_be_number(_be_number);
 
     _brpc_timeout_ms = std::min(3600, state->query_options().query_timeout) * 1000;
-    _brpc_stub = state->exec_env()->brpc_stub_cache()->get_stub(_brpc_dest_addr);
+    _brpc_stub = state->exec_env()->brpc_internal_client_cache()->get_client(_brpc_dest_addr);
 
     if (_brpc_dest_addr.hostname == BackendOptions::get_localhost()) {
-        _brpc_stub =
-                state->exec_env()->brpc_stub_cache()->get_stub("127.0.0.1", _brpc_dest_addr.port);
+        _brpc_stub = state->exec_env()->brpc_internal_client_cache()->get_client(
+                "127.0.0.1", _brpc_dest_addr.port);
     } else {
-        _brpc_stub = state->exec_env()->brpc_stub_cache()->get_stub(_brpc_dest_addr);
+
_brpc_stub = state->exec_env()->brpc_internal_client_cache()->get_client(_brpc_dest_addr); } // In bucket shuffle join will set fragment_instance_id (-1, -1) diff --git a/be/src/vec/sink/vdata_stream_sender.h b/be/src/vec/sink/vdata_stream_sender.h index 223bf28f22..6ed99bd53f 100644 --- a/be/src/vec/sink/vdata_stream_sender.h +++ b/be/src/vec/sink/vdata_stream_sender.h @@ -25,7 +25,7 @@ #include "runtime/descriptors.h" #include "service/backend_options.h" #include "service/brpc.h" -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include "util/network_util.h" #include "util/ref_count_closure.h" #include "util/uid_util.h" @@ -81,7 +81,8 @@ private: } template - Status channel_add_rows(Channels& channels, int num_channels, const HashVals& hash_vals, int rows, Block* block); + Status channel_add_rows(Channels& channels, int num_channels, const HashVals& hash_vals, + int rows, Block* block); struct hash_128 { uint64_t high; @@ -159,13 +160,14 @@ public: _brpc_dest_addr(brpc_dest), _is_transfer_chain(is_transfer_chain), _send_query_statistics_with_every_batch(send_query_statistics_with_every_batch) { - std::string localhost = BackendOptions::get_localhost(); - _is_local = (_brpc_dest_addr.hostname == localhost) && (_brpc_dest_addr.port == config::brpc_port); - if (_is_local) { - LOG(INFO) << "will use local Exchange, dest_node_id is : "<<_dest_node_id; - } - } - + std::string localhost = BackendOptions::get_localhost(); + _is_local = (_brpc_dest_addr.hostname == localhost) && + (_brpc_dest_addr.port == config::brpc_port); + if (_is_local) { + LOG(INFO) << "will use local Exchange, dest_node_id is : " << _dest_node_id; + } + } + virtual ~Channel() { if (_closure != nullptr && _closure->unref()) { delete _closure; @@ -235,7 +237,6 @@ private: return Status::OK(); } - private: // Serialize _batch into _thrift_batch and send via send_batch(). // Returns send_batch() status. 
@@ -276,7 +277,8 @@ private: }; template -Status VDataStreamSender::channel_add_rows(Channels& channels, int num_channels, const HashVals& hash_vals, int rows, Block* block) { +Status VDataStreamSender::channel_add_rows(Channels& channels, int num_channels, + const HashVals& hash_vals, int rows, Block* block) { std::vector channel2rows[num_channels]; for (int i = 0; i < rows; i++) { diff --git a/be/test/exec/tablet_sink_test.cpp b/be/test/exec/tablet_sink_test.cpp index 44c5fbd68e..3d55699a6f 100644 --- a/be/test/exec/tablet_sink_test.cpp +++ b/be/test/exec/tablet_sink_test.cpp @@ -34,7 +34,7 @@ #include "runtime/types.h" #include "runtime/tuple_row.h" #include "service/brpc.h" -#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include "util/cpu_info.h" #include "util/debug/leakcheck_disabler.h" #include "util/proto_util.h" @@ -54,7 +54,8 @@ public: _env->_thread_mgr = new ThreadResourceMgr(); _env->_master_info = new TMasterInfo(); _env->_load_stream_mgr = new LoadStreamMgr(); - _env->_brpc_stub_cache = new BrpcStubCache(); + _env->_internal_client_cache = new BrpcClientCache(); + _env->_function_client_cache = new BrpcClientCache(); _env->_buffer_reservation = new ReservationTracker(); ThreadPoolBuilder("SendBatchThreadPool") .set_min_threads(1) @@ -66,7 +67,8 @@ public: } void TearDown() override { - SAFE_DELETE(_env->_brpc_stub_cache); + SAFE_DELETE(_env->_internal_client_cache); + SAFE_DELETE(_env->_function_client_cache); SAFE_DELETE(_env->_load_stream_mgr); SAFE_DELETE(_env->_master_info); SAFE_DELETE(_env->_thread_mgr); diff --git a/be/test/http/stream_load_test.cpp b/be/test/http/stream_load_test.cpp index 0ea97a09aa..fc3435f67b 100644 --- a/be/test/http/stream_load_test.cpp +++ b/be/test/http/stream_load_test.cpp @@ -30,7 +30,7 @@ #include "runtime/stream_load/load_stream_mgr.h" #include "runtime/stream_load/stream_load_executor.h" #include "runtime/thread_resource_mgr.h" -#include "util/brpc_stub_cache.h" +#include 
"util/brpc_client_cache.h" #include "util/cpu_info.h" class mg_connection; @@ -74,14 +74,17 @@ public: _env._thread_mgr = new ThreadResourceMgr(); _env._master_info = new TMasterInfo(); _env._load_stream_mgr = new LoadStreamMgr(); - _env._brpc_stub_cache = new BrpcStubCache(); + _env._internal_client_cache = new BrpcClientCache(); + _env._function_client_cache = new BrpcClientCache(); _env._stream_load_executor = new StreamLoadExecutor(&_env); _evhttp_req = evhttp_request_new(nullptr, nullptr); } void TearDown() override { - delete _env._brpc_stub_cache; - _env._brpc_stub_cache = nullptr; + delete _env._internal_client_cache; + _env._internal_client_cache = nullptr; + delete _env._function_client_cache; + _env._function_client_cache = nullptr; delete _env._load_stream_mgr; _env._load_stream_mgr = nullptr; delete _env._master_info; diff --git a/be/test/util/CMakeLists.txt b/be/test/util/CMakeLists.txt index e9f75a1462..afa332c0da 100644 --- a/be/test/util/CMakeLists.txt +++ b/be/test/util/CMakeLists.txt @@ -19,7 +19,7 @@ set(EXECUTABLE_OUTPUT_PATH "${BUILD_DIR}/test/util") ADD_BE_TEST(bit_util_test) -ADD_BE_TEST(brpc_stub_cache_test) +ADD_BE_TEST(brpc_client_cache_test) ADD_BE_TEST(path_trie_test) ADD_BE_TEST(coding_test) ADD_BE_TEST(crc32c_test) diff --git a/be/test/util/brpc_stub_cache_test.cpp b/be/test/util/brpc_client_cache_test.cpp similarity index 73% rename from be/test/util/brpc_stub_cache_test.cpp rename to be/test/util/brpc_client_cache_test.cpp index cf68cc1356..c6ece74175 100644 --- a/be/test/util/brpc_stub_cache_test.cpp +++ b/be/test/util/brpc_client_cache_test.cpp @@ -15,40 +15,40 @@ // specific language governing permissions and limitations // under the License. 
-#include "util/brpc_stub_cache.h" +#include "util/brpc_client_cache.h" #include namespace doris { -class BrpcStubCacheTest : public testing::Test { +class BrpcClientCacheTest : public testing::Test { public: - BrpcStubCacheTest() {} - virtual ~BrpcStubCacheTest() {} + BrpcClientCacheTest() {} + virtual ~BrpcClientCacheTest() {} }; -TEST_F(BrpcStubCacheTest, normal) { - BrpcStubCache cache; +TEST_F(BrpcClientCacheTest, normal) { + BrpcClientCache cache; TNetworkAddress address; address.hostname = "127.0.0.1"; address.port = 123; - auto stub1 = cache.get_stub(address); + auto stub1 = cache.get_client(address); ASSERT_NE(nullptr, stub1); address.port = 124; - auto stub2 = cache.get_stub(address); + auto stub2 = cache.get_client(address); ASSERT_NE(nullptr, stub2); ASSERT_NE(stub1, stub2); address.port = 123; - auto stub3 = cache.get_stub(address); + auto stub3 = cache.get_client(address); ASSERT_EQ(stub1, stub3); } -TEST_F(BrpcStubCacheTest, invalid) { - BrpcStubCache cache; +TEST_F(BrpcClientCacheTest, invalid) { + BrpcClientCache cache; TNetworkAddress address; address.hostname = "invalid.cm.invalid"; address.port = 123; - auto stub1 = cache.get_stub(address); + auto stub1 = cache.get_client(address); ASSERT_EQ(nullptr, stub1); } diff --git a/be/test/vec/runtime/vdata_stream_test.cpp b/be/test/vec/runtime/vdata_stream_test.cpp index cc4d429718..5fef7617fd 100644 --- a/be/test/vec/runtime/vdata_stream_test.cpp +++ b/be/test/vec/runtime/vdata_stream_test.cpp @@ -65,18 +65,19 @@ private: std::unique_ptr _service; }; -class MockBrpcStubCache : public BrpcStubCache { +template +class MockBrpcClientCache : public BrpcClientCache { public: - MockBrpcStubCache(google::protobuf::RpcChannel* channel) { + MockBrpcClientCache(google::protobuf::RpcChannel* channel) { _channel.reset(channel); - _stub.reset(new PBackendService_Stub(channel)); + _stub.reset(new T(channel)); } - virtual ~MockBrpcStubCache() = default; - virtual std::shared_ptr get_stub(const TNetworkAddress&) { 
return _stub; } + virtual ~MockBrpcClientCache() = default; + virtual std::shared_ptr get_client(const TNetworkAddress&) { return _stub; } private: std::unique_ptr _channel; - std::shared_ptr _stub; + std::shared_ptr _stub; }; class VDataStreamTest : public testing::Test { @@ -107,8 +108,8 @@ TEST_F(VDataStreamTest, BasicTest) { mock_service->stream_mgr = &_instance; MockChannel* channel = new MockChannel(std::move(mock_service)); - runtime_stat._exec_env->_brpc_stub_cache = - _object_pool.add(new MockBrpcStubCache(std::move(channel))); + runtime_stat._exec_env->_internal_client_cache = + _object_pool.add(new MockBrpcClientCache(std::move(channel))); TUniqueId uid; PlanNodeId nid = 1; diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java index 6e376ad02c..a0e2ccb931 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java @@ -22,21 +22,30 @@ import org.apache.doris.catalog.AliasFunction; import org.apache.doris.catalog.Catalog; import org.apache.doris.catalog.Function; import org.apache.doris.catalog.ScalarFunction; +import org.apache.doris.catalog.ScalarType; import org.apache.doris.catalog.Type; import org.apache.doris.common.AnalysisException; +import org.apache.doris.common.Config; import org.apache.doris.common.ErrorCode; import org.apache.doris.common.ErrorReport; import org.apache.doris.common.FeConstants; import org.apache.doris.common.UserException; import org.apache.doris.common.util.Util; import org.apache.doris.mysql.privilege.PrivPredicate; +import org.apache.doris.proto.FunctionService; +import org.apache.doris.proto.PFunctionServiceGrpc; +import org.apache.doris.proto.Types; import org.apache.doris.qe.ConnectContext; +import org.apache.doris.thrift.TFunctionBinaryType; import com.google.common.base.Strings; import 
com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSortedMap; import org.apache.commons.codec.binary.Hex; +import org.apache.commons.lang3.StringUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import java.io.IOException; import java.io.InputStream; @@ -45,8 +54,12 @@ import java.security.NoSuchAlgorithmException; import java.util.List; import java.util.Map; +import io.grpc.ManagedChannel; +import io.grpc.netty.NettyChannelBuilder; + // create a user define function public class CreateFunctionStmt extends DdlStmt { + private final static Logger LOG = LogManager.getLogger(CreateFunctionStmt.class); public static final String OBJECT_FILE_KEY = "object_file"; public static final String SYMBOL_KEY = "symbol"; public static final String PREPARE_SYMBOL_KEY = "prepare_fn"; @@ -59,6 +72,7 @@ public class CreateFunctionStmt extends DdlStmt { public static final String FINALIZE_KEY = "finalize_fn"; public static final String GET_VALUE_KEY = "get_value_fn"; public static final String REMOVE_KEY = "remove_fn"; + public static final String BINARY_TYPE = "type"; private final FunctionName functionName; private final boolean isAggregate; @@ -69,11 +83,12 @@ public class CreateFunctionStmt extends DdlStmt { private final Map properties; private final List parameters; private final Expr originFunction; + TFunctionBinaryType binaryType = TFunctionBinaryType.NATIVE; // needed item set after analyzed private String objectFile; private Function function; - private String checksum; + private String checksum = ""; // timeout for both connection and read. 10 seconds is long enough. 
private static final int HTTP_TIMEOUT_MS = 10000; @@ -111,8 +126,13 @@ public class CreateFunctionStmt extends DdlStmt { this.properties = ImmutableSortedMap.of(); } - public FunctionName getFunctionName() { return functionName; } - public Function getFunction() { return function; } + public FunctionName getFunctionName() { + return functionName; + } + + public Function getFunction() { + return function; + } public Expr getOriginFunction() { return originFunction; @@ -156,26 +176,32 @@ public class CreateFunctionStmt extends DdlStmt { intermediateType = returnType; } + String type = properties.getOrDefault(BINARY_TYPE, "NATIVE"); + binaryType = getFunctionBinaryType(type); + if (binaryType == null) { + throw new AnalysisException("unknown function type"); + } + objectFile = properties.get(OBJECT_FILE_KEY); if (Strings.isNullOrEmpty(objectFile)) { throw new AnalysisException("No 'object_file' in properties"); } - try { - computeObjectChecksum(); - } catch (IOException | NoSuchAlgorithmException e) { - throw new AnalysisException("cannot to compute object's checksum"); - } - - String md5sum = properties.get(MD5_CHECKSUM); - if (md5sum != null && !md5sum.equalsIgnoreCase(checksum)) { - throw new AnalysisException("library's checksum is not equal with input, checksum=" + checksum); + if (binaryType != TFunctionBinaryType.RPC) { + try { + computeObjectChecksum(); + } catch (IOException | NoSuchAlgorithmException e) { + throw new AnalysisException("cannot to compute object's checksum"); + } + String md5sum = properties.get(MD5_CHECKSUM); + if (md5sum != null && !md5sum.equalsIgnoreCase(checksum)) { + throw new AnalysisException("library's checksum is not equal with input, checksum=" + checksum); + } } } private void computeObjectChecksum() throws IOException, NoSuchAlgorithmException { if (FeConstants.runningUnitTest) { // skip checking checksum when running ut - checksum = ""; return; } @@ -196,6 +222,9 @@ public class CreateFunctionStmt extends DdlStmt { } private void 
analyzeUda() throws AnalysisException {
+        if (binaryType == TFunctionBinaryType.RPC) {
+            throw new AnalysisException("RPC UDAF is not supported.");
+        }
         AggregateFunction.AggregateFunctionBuilder builder = AggregateFunction.AggregateFunctionBuilder.createUdfBuilder();
 
         builder.name(functionName).argsType(argsDef.getArgTypes()).retType(returnType.getType()).
@@ -227,13 +256,146 @@ public class CreateFunctionStmt extends DdlStmt {
         }
         String prepareFnSymbol = properties.get(PREPARE_SYMBOL_KEY);
         String closeFnSymbol = properties.get(CLOSE_SYMBOL_KEY);
-        function = ScalarFunction.createUdf(
+        // TODO(yangzhg) support check function in FE when function service behind load balancer
+        // the format for load balance can ref https://github.com/apache/incubator-brpc/blob/master/docs/en/client.md#connect-to-a-cluster
+        if (binaryType == TFunctionBinaryType.RPC && !objectFile.contains("://")) {
+            if (StringUtils.isNotBlank(prepareFnSymbol) || StringUtils.isNotBlank(closeFnSymbol)) {
+                throw new AnalysisException(" prepare and close in RPC UDF are not supported.");
+            }
+            String[] url = objectFile.split(":");
+            if (url.length != 2) {
+                throw new AnalysisException("function server address invalid.");
+            }
+            String host = url[0];
+            int port;
+            try {
+                port = Integer.parseInt(url[1]);
+            } catch (NumberFormatException e) {
+                // a malformed port is a user error; report it like any other bad address
+                throw new AnalysisException("function server address invalid.");
+            }
+            ManagedChannel channel = NettyChannelBuilder.forAddress(host, port)
+                    .flowControlWindow(Config.grpc_max_message_size_bytes)
+                    .maxInboundMessageSize(Config.grpc_max_message_size_bytes)
+                    .enableRetry().maxRetryAttempts(3)
+                    .usePlaintext().build();
+            try {
+                // bound the check so an unresponsive server cannot block the DDL statement forever
+                PFunctionServiceGrpc.PFunctionServiceBlockingStub stub = PFunctionServiceGrpc.newBlockingStub(channel)
+                        .withDeadlineAfter(HTTP_TIMEOUT_MS, java.util.concurrent.TimeUnit.MILLISECONDS);
+                FunctionService.PCheckFunctionRequest.Builder builder = FunctionService.PCheckFunctionRequest.newBuilder();
+                builder.getFunctionBuilder().setFunctionName(functionName.getFunction());
+                for (Type arg : argsDef.getArgTypes()) {
+                    builder.getFunctionBuilder().addInputs(convertToPParameterType(arg));
+                }
+                builder.getFunctionBuilder().setOutput(convertToPParameterType(returnType.getType()));
+                FunctionService.PCheckFunctionResponse response = stub.checkFn(builder.build());
+                if (response.getStatus().getStatusCode() != 0) {
+                    throw new AnalysisException("cannot access function server:" + response.getStatus());
+                }
+            } catch (io.grpc.StatusRuntimeException e) {
+                // an unreachable server or an expired deadline is raised as an unchecked gRPC error
+                throw new AnalysisException("cannot access function server:" + e.getMessage());
+            } finally {
+                // the channel only serves this one-off check; release its threads and sockets
+                channel.shutdownNow();
+            }
+        }
+        function = ScalarFunction.createUdf(binaryType,
                 functionName, argsDef.getArgTypes(),
                 returnType.getType(), argsDef.isVariadic(),
                 objectFile, symbol, prepareFnSymbol, closeFnSymbol);
         function.setChecksum(checksum);
     }
 
+    /**
+     * Converts a Doris FE type into the protobuf PGenericType sent to the function service.
+     * CHAR and VARCHAR both map to STRING; DATETIME and TIME both map to DATETIME;
+     * DECIMALV2 maps to DECIMAL128 and carries the precision and scale of the scalar type.
+     *
+     * @param arg the argument or return type declared in CREATE FUNCTION
+     * @return the protobuf type descriptor for {@code arg}
+     * @throws AnalysisException if the primitive type has no mapping in the switch below
+     */
+    private Types.PGenericType convertToPParameterType(Type arg) throws AnalysisException {
+        Types.PGenericType.Builder typeBuilder = Types.PGenericType.newBuilder();
+        switch (arg.getPrimitiveType()) {
+            case INVALID_TYPE:
+                typeBuilder.setId(Types.PGenericType.TypeId.UNKNOWN);
+                break;
+            case BOOLEAN:
+                typeBuilder.setId(Types.PGenericType.TypeId.BOOLEAN);
+                break;
+            case SMALLINT:
+                typeBuilder.setId(Types.PGenericType.TypeId.INT16);
+                break;
+            case TINYINT:
+                typeBuilder.setId(Types.PGenericType.TypeId.INT8);
+                break;
+            case INT:
+                typeBuilder.setId(Types.PGenericType.TypeId.INT32);
+                break;
+            case BIGINT:
+                typeBuilder.setId(Types.PGenericType.TypeId.INT64);
+                break;
+            case FLOAT:
+                typeBuilder.setId(Types.PGenericType.TypeId.FLOAT);
+                break;
+            case DOUBLE:
+                typeBuilder.setId(Types.PGenericType.TypeId.DOUBLE);
+                break;
+            case CHAR:
+            case VARCHAR:
+                typeBuilder.setId(Types.PGenericType.TypeId.STRING);
+                break;
+            case HLL:
+                typeBuilder.setId(Types.PGenericType.TypeId.HLL);
+                break;
+            case BITMAP:
+                typeBuilder.setId(Types.PGenericType.TypeId.BITMAP);
+                break;
+            case DATE:
+                typeBuilder.setId(Types.PGenericType.TypeId.DATE);
+                break;
+            case DATETIME:
+            case TIME:
+                typeBuilder.setId(Types.PGenericType.TypeId.DATETIME);
+                break;
+            case DECIMALV2:
+                typeBuilder.setId(Types.PGenericType.TypeId.DECIMAL128)
+                        .getDecimalTypeBuilder()
+                        .setPrecision(((ScalarType) arg).getScalarPrecision())
+                        .setScale(((ScalarType) arg).getScalarScale());
+                break;
+            case LARGEINT:
+                typeBuilder.setId(Types.PGenericType.TypeId.INT128);
+                break;
+            default:
+                throw new AnalysisException("type " + arg.getPrimitiveType().toString() + " is not supported");
+        }
+        return typeBuilder.build();
+    }
+
+    /**
+     * Parses the value of the "type" property into a TFunctionBinaryType.
+     * Matching is case-insensitive and ignores surrounding whitespace.
+     *
+     * @param type the raw property value
+     * @return the matching enum constant, or null if {@code type} is null or names no constant
+     */
+    private TFunctionBinaryType getFunctionBinaryType(String type) {
+        if (type == null) {
+            return null;
+        }
+        try {
+            // property values are user input: accept "rpc" as well as "RPC"
+            return TFunctionBinaryType.valueOf(type.trim().toUpperCase(java.util.Locale.ROOT));
+        } catch (IllegalArgumentException e) {
+            // not a TFunctionBinaryType name; the caller turns null into an AnalysisException
+            return null;
+        }
+    }
+
     private void analyzeAliasFunction() throws AnalysisException {
         function = AliasFunction.createFunction(functionName, argsDef.getArgTypes(),
                 Type.VARCHAR, argsDef.isVariadic(), parameters, originFunction);
@@ -279,8 +441,8 @@ public class CreateFunctionStmt extends DdlStmt {
         }
         return stringBuilder.toString();
     }
-
-    @Override
+
+    @Override
     public RedirectStatus getRedirectStatus() {
         return RedirectStatus.FORWARD_WITH_SYNC;
     }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarFunction.java
index 5f216d3a98..308bda052f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarFunction.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarFunction.java
@@ -312,11 +312,12 @@ public class ScalarFunction extends Function {
     }
 
     public static ScalarFunction createUdf(
+            TFunctionBinaryType binaryType,
             FunctionName name, Type[] args,
             Type returnType, boolean isVariadic,
             String objectFile, String symbol, String prepareFnSymbol, String closeFnSymbol) {
-        ScalarFunction fn = new ScalarFunction(name, Arrays.asList(args), returnType, isVariadic,
-                TFunctionBinaryType.NATIVE, true, false);
+        ScalarFunction fn = new ScalarFunction(name, Arrays.asList(args), returnType, isVariadic, binaryType,
+                true, false);
         fn.symbolName = symbol;
         fn.prepareFnSymbol = prepareFnSymbol;
         fn.closeFnSymbol = closeFnSymbol;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/common/Status.java
b/fe/fe-core/src/main/java/org/apache/doris/common/Status.java index 7d6b7c609c..1104cc4c3c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/common/Status.java +++ b/fe/fe-core/src/main/java/org/apache/doris/common/Status.java @@ -17,7 +17,7 @@ package org.apache.doris.common; -import org.apache.doris.proto.Status.PStatus; +import org.apache.doris.proto.Types.PStatus; import org.apache.doris.thrift.TStatus; import org.apache.doris.thrift.TStatusCode; diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java b/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java index 407bbea040..27b072db8b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java @@ -2018,7 +2018,7 @@ public class Coordinator { public InternalService.PExecPlanFragmentResult get() { InternalService.PExecPlanFragmentResult result = InternalService.PExecPlanFragmentResult .newBuilder() - .setStatus(org.apache.doris.proto.Status.PStatus.newBuilder() + .setStatus(org.apache.doris.proto.Types.PStatus.newBuilder() .addErrorMsgs(e.getMessage()) .setStatusCode(TStatusCode.THRIFT_RPC_ERROR.getValue()) .build()) diff --git a/fe/fe-core/src/test/java/org/apache/doris/load/sync/canal/CanalSyncDataTest.java b/fe/fe-core/src/test/java/org/apache/doris/load/sync/canal/CanalSyncDataTest.java index 70815a3b4a..a3051c65dd 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/load/sync/canal/CanalSyncDataTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/load/sync/canal/CanalSyncDataTest.java @@ -24,7 +24,6 @@ import org.apache.doris.common.AnalysisException; import org.apache.doris.common.Config; import org.apache.doris.planner.StreamLoadPlanner; import org.apache.doris.proto.InternalService; -import org.apache.doris.proto.Status; import org.apache.doris.proto.Types; import org.apache.doris.resource.Tag; import org.apache.doris.rpc.BackendServiceProxy; @@ -97,22 +96,22 @@ public class 
CanalSyncDataTest { SystemInfoService systemInfoService; InternalService.PExecPlanFragmentResult beginOkResult = InternalService.PExecPlanFragmentResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(0).build()).build(); // begin txn OK + .setStatus(Types.PStatus.newBuilder().setStatusCode(0).build()).build(); // begin txn OK InternalService.PExecPlanFragmentResult beginFailResult = InternalService.PExecPlanFragmentResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(1).build()).build(); // begin txn CANCELLED + .setStatus(Types.PStatus.newBuilder().setStatusCode(1).build()).build(); // begin txn CANCELLED InternalService.PCommitResult commitOkResult = InternalService.PCommitResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(0).build()).build(); // commit txn OK + .setStatus(Types.PStatus.newBuilder().setStatusCode(0).build()).build(); // commit txn OK InternalService.PCommitResult commitFailResult = InternalService.PCommitResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(1).build()).build(); // commit txn CANCELLED + .setStatus(Types.PStatus.newBuilder().setStatusCode(1).build()).build(); // commit txn CANCELLED InternalService.PRollbackResult abortOKResult = InternalService.PRollbackResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(0).build()).build(); // abort txn OK + .setStatus(Types.PStatus.newBuilder().setStatusCode(0).build()).build(); // abort txn OK InternalService.PSendDataResult sendDataOKResult = InternalService.PSendDataResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(0).build()).build(); // send data OK + .setStatus(Types.PStatus.newBuilder().setStatusCode(0).build()).build(); // send data OK @Before public void setUp() throws Exception { diff --git a/fe/fe-core/src/test/java/org/apache/doris/utframe/MockedBackendFactory.java b/fe/fe-core/src/test/java/org/apache/doris/utframe/MockedBackendFactory.java index 
b04b54a97c..42dab10411 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/utframe/MockedBackendFactory.java +++ b/fe/fe-core/src/test/java/org/apache/doris/utframe/MockedBackendFactory.java @@ -21,7 +21,7 @@ import org.apache.doris.common.ClientPool; import org.apache.doris.proto.Data; import org.apache.doris.proto.InternalService; import org.apache.doris.proto.PBackendServiceGrpc; -import org.apache.doris.proto.Status; +import org.apache.doris.proto.Types; import org.apache.doris.thrift.BackendService; import org.apache.doris.thrift.FrontendService; import org.apache.doris.thrift.HeartbeatService; @@ -326,7 +326,7 @@ public class MockedBackendFactory { @Override public void transmitData(InternalService.PTransmitDataParams request, StreamObserver responseObserver) { responseObserver.onNext(InternalService.PTransmitDataResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(0)).build()); + .setStatus(Types.PStatus.newBuilder().setStatusCode(0)).build()); responseObserver.onCompleted(); } @@ -334,7 +334,7 @@ public class MockedBackendFactory { public void execPlanFragment(InternalService.PExecPlanFragmentRequest request, StreamObserver responseObserver) { System.out.println("get exec_plan_fragment request"); responseObserver.onNext(InternalService.PExecPlanFragmentResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(0)).build()); + .setStatus(Types.PStatus.newBuilder().setStatusCode(0)).build()); responseObserver.onCompleted(); } @@ -342,7 +342,7 @@ public class MockedBackendFactory { public void cancelPlanFragment(InternalService.PCancelPlanFragmentRequest request, StreamObserver responseObserver) { System.out.println("get cancel_plan_fragment request"); responseObserver.onNext(InternalService.PCancelPlanFragmentResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(0)).build()); + .setStatus(Types.PStatus.newBuilder().setStatusCode(0)).build()); responseObserver.onCompleted(); } @@ -350,7 +350,7 @@ 
public class MockedBackendFactory { public void fetchData(InternalService.PFetchDataRequest request, StreamObserver responseObserver) { System.out.println("get fetch_data request"); responseObserver.onNext(InternalService.PFetchDataResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(0)) + .setStatus(Types.PStatus.newBuilder().setStatusCode(0)) .setQueryStatistics(Data.PQueryStatistics.newBuilder() .setScanRows(0L) .setScanBytes(0L)) @@ -382,7 +382,7 @@ public class MockedBackendFactory { public void getInfo(InternalService.PProxyRequest request, StreamObserver responseObserver) { System.out.println("get get_info request"); responseObserver.onNext(InternalService.PProxyResult.newBuilder() - .setStatus(Status.PStatus.newBuilder().setStatusCode(0)).build()); + .setStatus(Types.PStatus.newBuilder().setStatusCode(0)).build()); responseObserver.onCompleted(); } diff --git a/gensrc/proto/function_service.proto b/gensrc/proto/function_service.proto new file mode 100644 index 0000000000..561be9f887 --- /dev/null +++ b/gensrc/proto/function_service.proto @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +syntax="proto2"; + +package doris; +option java_package = "org.apache.doris.proto"; +option cc_generic_services = true; + +import "types.proto"; + +message PRequestContext { + optional string id = 1; + optional PFunctionContext function_context = 2; +} + +message PFunctionCallRequest { + optional string function_name = 1; + repeated PValues args = 2; + optional PRequestContext context = 3; +} + +message PFunctionCallResponse { + optional PValues result = 1; + optional PStatus status = 2; +} + +message PCheckFunctionRequest { + enum MatchType { + IDENTICAL = 0; + INDISTINGUISHABLE = 1; + SUPERTYPE_OF = 2; + NONSTRICT_SUPERTYPE_OF = 3; + MATCHABLE = 4; + } + optional PFunction function = 1; + optional MatchType match_type = 2; +} + +message PCheckFunctionResponse { + optional PStatus status = 1; +} + +service PFunctionService { + rpc fn_call(PFunctionCallRequest) returns (PFunctionCallResponse); + rpc check_fn(PCheckFunctionRequest) returns (PCheckFunctionResponse); + rpc hand_shake(PHandShakeRequest) returns (PHandShakeResponse); +} + diff --git a/gensrc/proto/internal_service.proto b/gensrc/proto/internal_service.proto index d01a5fed72..41a0dce4bb 100644 --- a/gensrc/proto/internal_service.proto +++ b/gensrc/proto/internal_service.proto @@ -22,7 +22,6 @@ option java_package = "org.apache.doris.proto"; import "data.proto"; import "descriptors.proto"; -import "status.proto"; import "types.proto"; option cc_generic_services = true; @@ -430,15 +429,6 @@ message PResetRPCChannelResponse { repeated string channels = 2; }; -message PHandShakeRequest { - optional string hello = 1; -} - -message PHandShakeResponse { - optional PStatus status = 1; - optional string hello = 2; -} - service PBackendService { rpc transmit_data(PTransmitDataParams) returns (PTransmitDataResult); rpc exec_plan_fragment(PExecPlanFragmentRequest) returns (PExecPlanFragmentResult); diff --git a/gensrc/proto/status.proto b/gensrc/proto/status.proto deleted file mode 100644 index 
d1e9e7dda0..0000000000 --- a/gensrc/proto/status.proto +++ /dev/null @@ -1,27 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -syntax="proto2"; - -package doris; -option java_package = "org.apache.doris.proto"; - -message PStatus { - required int32 status_code = 1; - repeated string error_msgs = 2; -}; - diff --git a/gensrc/proto/types.proto b/gensrc/proto/types.proto index 762229df1e..f7bff5c878 100644 --- a/gensrc/proto/types.proto +++ b/gensrc/proto/types.proto @@ -20,6 +20,10 @@ syntax="proto2"; package doris; option java_package = "org.apache.doris.proto"; +message PStatus { + required int32 status_code = 1; + repeated string error_msgs = 2; +}; message PScalarType { // TPrimitiveType, use int32 to avoid redefine Enum required int32 type = 1; @@ -63,3 +67,150 @@ message PUniqueId { required int64 lo = 2; }; +message PGenericType { + enum TypeId { + UINT8 = 0; + UINT16 = 1; + UINT32 = 2; + UINT64 = 3; + UINT128 = 4; + UINT256 = 5; + INT8 = 6; + INT16 = 7; + INT32 = 8; + INT64 = 9; + INT128 = 10; + INT256 = 11; + FLOAT = 12; + DOUBLE = 13; + BOOLEAN = 14; + DATE = 15; + DATETIME = 16; + HLL = 17; + BITMAP = 18; + LIST = 19; + MAP = 20; + STRUCT =21; + STRING = 22; + DECIMAL32 = 23; + DECIMAL64 = 24; + 
DECIMAL128 = 25; + BYTES = 26; + NOTHING = 27; + UNKNOWN = 999; + } + required TypeId id = 2; + optional PList list_type = 11; + optional PMap map_type = 12; + optional PStruct struct_type = 13; + optional PDecimal decimal_type = 14; +} + +message PList { + required PGenericType element_type = 1; +} + +message PMap { + required PGenericType key_type = 1; + required PGenericType value_type = 2; +} + +message PField { + required PGenericType type = 1; + optional string name = 2; + optional string comment = 3; +} + +message PStruct { + repeated PField fields = 1; + required string name = 2; +} + +message PDecimal { + required uint32 precision = 1; + required uint32 scale = 2; +} + +message PDateTime { + optional int32 year = 1; + optional int32 month = 2; + optional int32 day = 3; + optional int32 hour = 4; + optional int32 minute = 5; + optional int32 second = 6; + optional int32 microsecond = 7; +} + +message PValue { + required PGenericType type = 1; + optional bool is_null = 2 [default = false]; + optional double double_value = 3; + optional float float_value = 4; + optional int32 int32_value = 5; + optional int64 int64_value = 6; + optional uint32 uint32_value = 7; + optional uint64 uint64_value = 8; + optional bool bool_value = 9; + optional string string_value = 10; + optional bytes bytes_value = 11; + optional PDateTime datetime_value = 12; +} + +message PValues { + required PGenericType type = 1; + optional bool has_null = 2 [default = false]; + repeated bool null_map = 3; + repeated double double_value = 4; + repeated float float_value = 5; + repeated int32 int32_value = 6; + repeated int64 int64_value = 7; + repeated uint32 uint32_value = 8; + repeated uint64 uint64_value = 9; + repeated bool bool_value = 10; + repeated string string_value = 11; + repeated bytes bytes_value = 12; + repeated PDateTime datetime_value = 13; +} + +// this message may not be used for now +message PFunction { + enum FunctionType { + UDF = 0; + // not supported now + UDAF = 1; + UDTF
= 2; + } + message Property { + required string key = 1; + required string val = 2; + }; + required string function_name = 1; + repeated PGenericType inputs = 2; + optional PGenericType output = 3; + optional FunctionType type = 4 [default = UDF]; + optional bool variadic = 5; + repeated Property properties = 6; +} + +message PFunctionContext { + optional string version = 1 [default = "V2_0"]; + repeated PValue staging_input_vals = 2; + repeated PValue constant_args = 3; + optional string error_msg = 4; + optional PUniqueId query_id = 5; + optional bytes thread_local_fn_state = 6; + optional bytes fragment_local_fn_state = 7; + optional string string_result = 8; + optional int64 num_updates = 9; + optional int64 num_removes = 10; + optional int64 num_warnings = 11; +} + +message PHandShakeRequest { + optional string hello = 1; +} + +message PHandShakeResponse { + optional PStatus status = 1; + optional string hello = 2; +} diff --git a/gensrc/thrift/Types.thrift b/gensrc/thrift/Types.thrift index 7dc77f3573..c1c8487ab1 100644 --- a/gensrc/thrift/Types.thrift +++ b/gensrc/thrift/Types.thrift @@ -254,7 +254,7 @@ enum TFunctionType { } enum TFunctionBinaryType { - // Palo builtin. We can either run this interpreted or via codegen + // Doris builtin. We can either run this interpreted or via codegen // depending on the query option. BUILTIN, @@ -266,6 +266,9 @@ enum TFunctionBinaryType { // Native-interface, precompiled to IR; loaded from *.ll IR, + + // call udfs by rpc service + RPC, } // Represents a fully qualified function name. diff --git a/run-be-ut.sh b/run-be-ut.sh index 51f27f5d0e..904d197d44 100755 --- a/run-be-ut.sh +++ b/run-be-ut.sh @@ -135,6 +135,7 @@ ${CMAKE_CMD} -G "${GENERATOR}" \ -DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}" \ -DMAKE_TEST=ON \ -DGLIBC_COMPATIBILITY="${GLIBC_COMPATIBILITY}" \ + -DBUILD_META_TOOL=OFF \ -DWITH_MYSQL=OFF \ ${CMAKE_USE_CCACHE} ../ ${BUILD_SYSTEM} -j ${PARALLEL} $RUN_FILE