From ae680b4248826aa2266811240fc0d526fa41ac16 Mon Sep 17 00:00:00 2001 From: Zhengguo Yang Date: Thu, 21 Apr 2022 17:38:58 +0800 Subject: [PATCH] [UDF] support RPC udaf part 1: support create RPC udaf in fe (#8510) --- be/src/exec/partitioned_aggregation_node.cc | 8 +- be/src/exprs/CMakeLists.txt | 3 +- be/src/exprs/{agg_fn.cc => agg_fn.cpp} | 114 ++- be/src/exprs/agg_fn.h | 69 +- be/src/exprs/expr_context.h | 2 +- be/src/exprs/new_agg_fn_evaluator.cc | 23 +- be/src/exprs/new_agg_fn_evaluator_ir.cc | 31 - be/src/exprs/rpc_fn.cpp | 770 ++++++++++++++++++ be/src/exprs/rpc_fn.h | 136 ++++ be/src/exprs/rpc_fn_call.cpp | 273 +------ be/src/exprs/rpc_fn_call.h | 20 +- be/src/udf/udf.cpp | 14 - be/src/udf/udf_internal.h | 4 - be/src/vec/core/block.cpp | 12 +- be/src/vec/core/block.h | 2 +- be/src/vec/exprs/vectorized_fn_call.cpp | 4 +- be/src/vec/functions/function_rpc.cpp | 540 +----------- be/src/vec/functions/function_rpc.h | 25 +- contrib/udf/CMakeLists.txt | 16 - .../udf/native-user-defined-function.md | 2 - .../Data Definition/create-function.md | 11 +- .../udf/native-user-defined-function.md | 2 - .../Data Definition/create-function.md | 56 +- .../doris/analysis/CreateFunctionStmt.java | 86 +- .../doris/catalog/AggregateFunction.java | 4 + .../org/apache/doris/common/util/URI.java | 5 + gensrc/proto/function_service.proto | 3 +- gensrc/proto/types.proto | 12 +- .../cpp_function_service_demo.cpp | 36 +- .../apache/doris/udf/FunctionServiceImpl.java | 2 +- .../function_server_demo.py | 2 +- 31 files changed, 1220 insertions(+), 1067 deletions(-) rename be/src/exprs/{agg_fn.cc => agg_fn.cpp} (58%) delete mode 100644 be/src/exprs/new_agg_fn_evaluator_ir.cc create mode 100644 be/src/exprs/rpc_fn.cpp create mode 100644 be/src/exprs/rpc_fn.h diff --git a/be/src/exec/partitioned_aggregation_node.cc b/be/src/exec/partitioned_aggregation_node.cc index 80b8d49764..a99b0da23b 100644 --- a/be/src/exec/partitioned_aggregation_node.cc +++ b/be/src/exec/partitioned_aggregation_node.cc @@ -174,10 +174,10 @@ Status PartitionedAggregationNode::init(const TPlanNode& tnode, RuntimeState* st SlotDescriptor* intermediate_slot_desc = intermediate_tuple_desc_->slots()[j]; SlotDescriptor* output_slot_desc = output_tuple_desc_->slots()[j]; AggFn* agg_fn; - RETURN_IF_ERROR(AggFn::Create(tnode.agg_node.aggregate_functions[i], row_desc, + RETURN_IF_ERROR(AggFn::create(tnode.agg_node.aggregate_functions[i], row_desc, *intermediate_slot_desc, *output_slot_desc, state, &agg_fn)); agg_fns_.push_back(agg_fn); - needs_serialize_ |= agg_fn->SupportsSerialize(); + needs_serialize_ |= agg_fn->supports_serialize(); } return Status::OK(); } @@ -719,7 +719,7 @@ Status PartitionedAggregationNode::close(RuntimeState* state) { } Expr::close(grouping_exprs_); Expr::close(build_exprs_); - AggFn::Close(agg_fns_); + AggFn::close(agg_fns_); return ExecNode::close(state); } @@ -1105,7 +1105,7 @@ void PartitionedAggregationNode::DebugString(int indentation_level, stringstream << "intermediate_tuple_id=" << intermediate_tuple_id_ << " output_tuple_id=" << output_tuple_id_ << " needs_finalize=" << needs_finalize_ << " grouping_exprs=" << Expr::debug_string(grouping_exprs_) - << " agg_exprs=" << AggFn::DebugString(agg_fns_); + << " agg_exprs=" << AggFn::debug_string(agg_fns_); ExecNode::debug_string(indentation_level, out); *out << ")"; } diff --git a/be/src/exprs/CMakeLists.txt b/be/src/exprs/CMakeLists.txt index 7ac1af85df..ff0a8037bb 100644 --- a/be/src/exprs/CMakeLists.txt +++ b/be/src/exprs/CMakeLists.txt @@ -52,6 +52,7 @@ add_library(Exprs math_functions.cpp null_literal.cpp scalar_fn_call.cpp + rpc_fn.cpp rpc_fn_call.cpp slot_ref.cpp string_functions.cpp @@ -64,7 +65,7 @@ add_library(Exprs json_functions.cpp operators.cpp hll_hash_function.cpp - agg_fn.cc + agg_fn.cpp new_agg_fn_evaluator.cc bitmap_function.cpp hll_function.cpp diff --git a/be/src/exprs/agg_fn.cc b/be/src/exprs/agg_fn.cpp similarity index 58% rename from be/src/exprs/agg_fn.cc rename to be/src/exprs/agg_fn.cpp index ca6d41f967..a04ef1ba86 100644 --- a/be/src/exprs/agg_fn.cc +++ b/be/src/exprs/agg_fn.cpp @@ -21,6 +21,7 @@ #include "exprs/agg_fn.h" #include "exprs/anyval_util.h" +#include "exprs/rpc_fn.h" #include "runtime/descriptors.h" #include "runtime/runtime_state.h" #include "runtime/user_function_cache.h" @@ -67,7 +68,7 @@ AggFn::AggFn(const TExprNode& tnode, const SlotDescriptor& intermediate_slot_des } } -Status AggFn::Init(const RowDescriptor& row_desc, RuntimeState* state) { +Status AggFn::init(const RowDescriptor& row_desc, RuntimeState* state) { // TODO chenhao , calling expr's prepare in NewAggFnEvaluator create // Initialize all children (i.e. input exprs to this aggregate expr). //for (Expr* input_expr : children()) { @@ -89,45 +90,74 @@ Status AggFn::Init(const RowDescriptor& row_desc, RuntimeState* state) { ss << "Function " << _fn.name.function_name << " is not implemented."; return Status::InternalError(ss.str()); } - - RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr( - _fn.id, aggregate_fn.init_fn_symbol, _fn.hdfs_location, _fn.checksum, &init_fn_, - &_cache_entry)); - RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr( - _fn.id, aggregate_fn.update_fn_symbol, _fn.hdfs_location, _fn.checksum, &update_fn_, - &_cache_entry)); - - // Merge() is not defined for purely analytic function. - if (!aggregate_fn.is_analytic_only_fn) { + if (_fn.binary_type == TFunctionBinaryType::NATIVE || + _fn.binary_type == TFunctionBinaryType::BUILTIN || + _fn.binary_type == TFunctionBinaryType::HIVE) { RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr( - _fn.id, aggregate_fn.merge_fn_symbol, _fn.hdfs_location, _fn.checksum, &merge_fn_, + _fn.id, aggregate_fn.init_fn_symbol, _fn.hdfs_location, _fn.checksum, &_init_fn, &_cache_entry)); - } - // Serialize(), GetValue(), Remove() and Finalize() are optional - if (!aggregate_fn.serialize_fn_symbol.empty()) { RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr( - _fn.id, aggregate_fn.serialize_fn_symbol, _fn.hdfs_location, _fn.checksum, - &serialize_fn_, &_cache_entry)); - } - if (!aggregate_fn.get_value_fn_symbol.empty()) { - RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr( - _fn.id, aggregate_fn.get_value_fn_symbol, _fn.hdfs_location, _fn.checksum, - &get_value_fn_, &_cache_entry)); - } - if (!aggregate_fn.remove_fn_symbol.empty()) { - RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr( - _fn.id, aggregate_fn.remove_fn_symbol, _fn.hdfs_location, _fn.checksum, &remove_fn_, + _fn.id, aggregate_fn.update_fn_symbol, _fn.hdfs_location, _fn.checksum, &_update_fn, &_cache_entry)); - } - if (!aggregate_fn.finalize_fn_symbol.empty()) { - RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr( - _fn.id, _fn.aggregate_fn.finalize_fn_symbol, _fn.hdfs_location, _fn.checksum, - &finalize_fn_, &_cache_entry)); + + // Merge() is not defined for purely analytic function. + if (!aggregate_fn.is_analytic_only_fn) { + RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr( + _fn.id, aggregate_fn.merge_fn_symbol, _fn.hdfs_location, _fn.checksum, + &_merge_fn, &_cache_entry)); + } + // Serialize(), GetValue(), Remove() and Finalize() are optional + if (!aggregate_fn.serialize_fn_symbol.empty()) { + RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr( + _fn.id, aggregate_fn.serialize_fn_symbol, _fn.hdfs_location, _fn.checksum, + &_serialize_fn, &_cache_entry)); + } + if (!aggregate_fn.get_value_fn_symbol.empty()) { + RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr( + _fn.id, aggregate_fn.get_value_fn_symbol, _fn.hdfs_location, _fn.checksum, + &_get_value_fn, &_cache_entry)); + } + if (!aggregate_fn.remove_fn_symbol.empty()) { + RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr( + _fn.id, aggregate_fn.remove_fn_symbol, _fn.hdfs_location, _fn.checksum, + &_remove_fn, &_cache_entry)); + } + if (!aggregate_fn.finalize_fn_symbol.empty()) { + RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr( + _fn.id, _fn.aggregate_fn.finalize_fn_symbol, _fn.hdfs_location, _fn.checksum, + &_finalize_fn, &_cache_entry)); + } + } else if (_fn.binary_type == TFunctionBinaryType::RPC) { + _rpc_init = std::make_unique(state, _fn, RPCFn::AggregationStep::INIT, true); + _rpc_update = std::make_unique(state, _fn, RPCFn::AggregationStep::UPDATE, true); + + // Merge() is not defined for purely analytic function. + if (!aggregate_fn.is_analytic_only_fn) { + _rpc_merge = std::make_unique(state, _fn, RPCFn::AggregationStep::MERGE, true); + } + // Serialize(), GetValue(), Remove() and Finalize() are optional + if (!aggregate_fn.serialize_fn_symbol.empty()) { + _rpc_serialize = + std::make_unique(state, _fn, RPCFn::AggregationStep::SERIALIZE, true); + } + if (!aggregate_fn.get_value_fn_symbol.empty()) { + _rpc_get_value = + std::make_unique(state, _fn, RPCFn::AggregationStep::GET_VALUE, true); + } + if (!aggregate_fn.remove_fn_symbol.empty()) { + _rpc_remove = std::make_unique(state, _fn, RPCFn::AggregationStep::REMOVE, true); + } + if (!aggregate_fn.finalize_fn_symbol.empty()) { + _rpc_finalize = + std::make_unique(state, _fn, RPCFn::AggregationStep::FINALIZE, true); + } + } else { + return Status::NotSupported(fmt::format("Not supported BinaryType: {}", _fn.binary_type)); } return Status::OK(); } -Status AggFn::Create(const TExpr& texpr, const RowDescriptor& row_desc, +Status AggFn::create(const TExpr& texpr, const RowDescriptor& row_desc, const SlotDescriptor& intermediate_slot_desc, const SlotDescriptor& output_slot_desc, RuntimeState* state, AggFn** agg_fn) { *agg_fn = nullptr; @@ -140,9 +170,9 @@ Status AggFn::Create(const TExpr& texpr, const RowDescriptor& row_desc, } AggFn* new_agg_fn = pool->add(new AggFn(texpr_node, intermediate_slot_desc, output_slot_desc)); RETURN_IF_ERROR(Expr::create_tree(texpr, pool, new_agg_fn)); - Status status = new_agg_fn->Init(row_desc, state); + Status status = new_agg_fn->init(row_desc, state); if (UNLIKELY(!status.ok())) { - new_agg_fn->Close(); + new_agg_fn->close(); return status; } for (Expr* input_expr : new_agg_fn->children()) { @@ -153,24 +183,24 @@ Status AggFn::Create(const TExpr& texpr, const RowDescriptor& row_desc, return Status::OK(); } -FunctionContext::TypeDesc AggFn::GetIntermediateTypeDesc() const { +FunctionContext::TypeDesc AggFn::get_intermediate_type_desc() const { return AnyValUtil::column_type_to_type_desc(intermediate_slot_desc_.type()); } -FunctionContext::TypeDesc AggFn::GetOutputTypeDesc() const { +FunctionContext::TypeDesc AggFn::get_output_type_desc() const { return AnyValUtil::column_type_to_type_desc(output_slot_desc_.type()); } -void AggFn::Close() { +void AggFn::close() { // This also closes all the input expressions. Expr::close(); } -void AggFn::Close(const std::vector& exprs) { - for (AggFn* expr : exprs) expr->Close(); +void AggFn::close(const std::vector& exprs) { + for (AggFn* expr : exprs) expr->close(); } -std::string AggFn::DebugString() const { +std::string AggFn::debug_string() const { std::stringstream out; out << "AggFn(op=" << agg_op_; for (Expr* input_expr : children()) { @@ -180,11 +210,11 @@ std::string AggFn::DebugString() const { return out.str(); } -std::string AggFn::DebugString(const std::vector& agg_fns) { +std::string AggFn::debug_string(const std::vector& agg_fns) { std::stringstream out; out << "["; for (int i = 0; i < agg_fns.size(); ++i) { - out << (i == 0 ? "" : " ") << agg_fns[i]->DebugString(); + out << (i == 0 ? "" : " ") << agg_fns[i]->debug_string(); } out << "]"; return out.str(); diff --git a/be/src/exprs/agg_fn.h b/be/src/exprs/agg_fn.h index 87083c9484..28342f974a 100644 --- a/be/src/exprs/agg_fn.h +++ b/be/src/exprs/agg_fn.h @@ -35,6 +35,7 @@ class RuntimeState; class Tuple; class TupleRow; class TExprNode; +class RPCFn; /// --- AggFn overview /// @@ -52,33 +53,33 @@ class TExprNode; /// AggFnEvaluator is the interface for evaluating aggregate functions against input /// tuple rows. It invokes the following functions at different phases of the aggregation: /// -/// init_fn_ : An initialization function that initializes the aggregate value. +/// _init_fn : An initialization function that initializes the aggregate value. /// -/// update_fn_ : An update function that processes the arguments for each row in the +/// _update_fn : An update function that processes the arguments for each row in the /// query result set and accumulates an intermediate result. For example, /// this function might increment a counter, append to a string buffer or /// add the input to a cumulative sum. /// -/// merge_fn_ : A merge function that combines multiple intermediate results into a +/// _merge_fn : A merge function that combines multiple intermediate results into a /// single value. /// -/// serialize_fn_: A serialization function that flattens any intermediate values +/// _serialize_fn: A serialization function that flattens any intermediate values /// containing pointers, and frees any memory allocated during the init, /// update and merge phases. /// -/// finalize_fn_ : A finalize function that either passes through the combined result +/// _finalize_fn : A finalize function that either passes through the combined result /// unchanged, or does one final transformation. Also frees the resources /// allocated during init, update and merge phases. /// -/// get_value_fn_: Used by AnalyticEval node to obtain the current intermediate value. +/// _get_value_fn: Used by AnalyticEval node to obtain the current intermediate value. /// -/// remove_fn_ : Used by AnalyticEval node to undo the update to the intermediate value +/// _remove_fn : Used by AnalyticEval node to undo the update to the intermediate value /// by an input row as it falls out of a sliding window. /// class AggFn : public Expr { public: /// Override the base class' implementation. - virtual bool IsAggFn() const { return true; } + virtual bool is_agg_fn() const { return true; } /// Enum for some built-in aggregation ops. enum AggregationOp { @@ -99,7 +100,7 @@ public: /// the row descriptor of the input tuple row; 'intermediate_slot_desc' is the slot /// descriptor of the intermediate value; 'output_slot_desc' is the slot descriptor /// of the output value. On failure, returns error status and sets 'agg_fn' to nullptr. - static Status Create(const TExpr& texpr, const RowDescriptor& row_desc, + static Status create(const TExpr& texpr, const RowDescriptor& row_desc, const SlotDescriptor& intermediate_slot_desc, const SlotDescriptor& output_slot_desc, RuntimeState* state, AggFn** agg_fn) WARN_UNUSED_RESULT; @@ -115,25 +116,25 @@ public: const SlotDescriptor& intermediate_slot_desc() const { return intermediate_slot_desc_; } // Output type is the same as Expr::type(). const SlotDescriptor& output_slot_desc() const { return output_slot_desc_; } - void* remove_fn() const { return remove_fn_; } - void* merge_or_update_fn() const { return is_merge_ ? merge_fn_ : update_fn_; } - void* serialize_fn() const { return serialize_fn_; } - void* get_value_fn() const { return get_value_fn_; } - void* finalize_fn() const { return finalize_fn_; } - bool SupportsRemove() const { return remove_fn_ != nullptr; } - bool SupportsSerialize() const { return serialize_fn_ != nullptr; } - FunctionContext::TypeDesc GetIntermediateTypeDesc() const; - FunctionContext::TypeDesc GetOutputTypeDesc() const; + void* remove_fn() const { return _remove_fn; } + void* merge_or_update_fn() const { return is_merge_ ? _merge_fn : _update_fn; } + void* serialize_fn() const { return _serialize_fn; } + void* get_value_fn() const { return _get_value_fn; } + void* finalize_fn() const { return _finalize_fn; } + bool supports_remove() const { return _remove_fn != nullptr; } + bool supports_serialize() const { return _serialize_fn != nullptr; } + FunctionContext::TypeDesc get_intermediate_type_desc() const; + FunctionContext::TypeDesc get_output_type_desc() const; const std::vector& arg_type_descs() const { return arg_type_descs_; } /// Releases all cache entries to libCache for all nodes in the expr tree. - virtual void Close(); - static void Close(const std::vector& exprs); + virtual void close(); + static void close(const std::vector& exprs); Expr* clone(ObjectPool* pool) const { return nullptr; } - virtual std::string DebugString() const; - static std::string DebugString(const std::vector& exprs); + virtual std::string debug_string() const; + static std::string debug_string(const std::vector& exprs); const int get_vararg_start_idx() const { return _vararg_start_idx; } @@ -158,22 +159,30 @@ private: AggregationOp agg_op_; /// Function pointers for the different phases of the aggregate function. - void* init_fn_ = nullptr; - void* update_fn_ = nullptr; - void* remove_fn_ = nullptr; - void* merge_fn_ = nullptr; - void* serialize_fn_ = nullptr; - void* get_value_fn_ = nullptr; - void* finalize_fn_ = nullptr; + void* _init_fn = nullptr; + void* _update_fn = nullptr; + void* _remove_fn = nullptr; + void* _merge_fn = nullptr; + void* _serialize_fn = nullptr; + void* _get_value_fn = nullptr; + void* _finalize_fn = nullptr; int _vararg_start_idx; + std::unique_ptr _rpc_init; + std::unique_ptr _rpc_update; + std::unique_ptr _rpc_remove; + std::unique_ptr _rpc_merge; + std::unique_ptr _rpc_serialize; + std::unique_ptr _rpc_get_value; + std::unique_ptr _rpc_finalize; + AggFn(const TExprNode& node, const SlotDescriptor& intermediate_slot_desc, const SlotDescriptor& output_slot_desc); /// Initializes the AggFn and its input expressions. May load the UDAF from LibCache /// if necessary. - virtual Status Init(const RowDescriptor& desc, RuntimeState* state) WARN_UNUSED_RESULT; + virtual Status init(const RowDescriptor& desc, RuntimeState* state) WARN_UNUSED_RESULT; }; } // namespace doris diff --git a/be/src/exprs/expr_context.h b/be/src/exprs/expr_context.h index 1b6edc6c31..3a79a1be54 100644 --- a/be/src/exprs/expr_context.h +++ b/be/src/exprs/expr_context.h @@ -156,7 +156,7 @@ public: private: friend class Expr; friend class ScalarFnCall; - friend class RPCFnCall; + friend class RPCFn; friend class InPredicate; friend class RuntimePredicateWrapper; friend class BloomFilterPredicate; diff --git a/be/src/exprs/new_agg_fn_evaluator.cc b/be/src/exprs/new_agg_fn_evaluator.cc index a81205240b..0c09a78ed5 100644 --- a/be/src/exprs/new_agg_fn_evaluator.cc +++ b/be/src/exprs/new_agg_fn_evaluator.cc @@ -116,11 +116,10 @@ Status NewAggFnEvaluator::Create(const AggFn& agg_fn, RuntimeState* state, Objec *result = nullptr; // Create a new AggFn evaluator. - NewAggFnEvaluator* agg_fn_eval = - pool->add(new NewAggFnEvaluator(agg_fn, mem_pool, false)); + NewAggFnEvaluator* agg_fn_eval = pool->add(new NewAggFnEvaluator(agg_fn, mem_pool, false)); agg_fn_eval->agg_fn_ctx_.reset(FunctionContextImpl::create_context( - state, mem_pool, agg_fn.GetIntermediateTypeDesc(), agg_fn.GetOutputTypeDesc(), + state, mem_pool, agg_fn.get_intermediate_type_desc(), agg_fn.get_output_type_desc(), agg_fn.arg_type_descs(), 0, false)); Status status; @@ -284,7 +283,7 @@ void NewAggFnEvaluator::SetDstSlot(const AnyVal* src, const SlotDescriptor& dst_ // This function would be replaced in codegen. void NewAggFnEvaluator::Init(Tuple* dst) { DCHECK(opened_); - DCHECK(agg_fn_.init_fn_ != nullptr); + DCHECK(agg_fn_._init_fn != nullptr); for (ExprContext* input_eval : input_evals_) { DCHECK(input_eval->opened()); } @@ -301,7 +300,7 @@ void NewAggFnEvaluator::Init(Tuple* dst) { sv->ptr = reinterpret_cast(slot); sv->len = type.len; } - reinterpret_cast(agg_fn_.init_fn_)(agg_fn_ctx_.get(), staging_intermediate_val_); + reinterpret_cast(agg_fn_._init_fn)(agg_fn_ctx_.get(), staging_intermediate_val_); SetDstSlot(staging_intermediate_val_, slot_desc, dst); agg_fn_ctx_->impl()->set_num_updates(0); agg_fn_ctx_->impl()->set_num_removes(0); @@ -519,12 +518,12 @@ void NewAggFnEvaluator::Update(const TupleRow* row, Tuple* dst, void* fn) { } void NewAggFnEvaluator::Merge(Tuple* src, Tuple* dst) { - DCHECK(agg_fn_.merge_fn_ != nullptr); + DCHECK(agg_fn_._merge_fn != nullptr); const SlotDescriptor& slot_desc = intermediate_slot_desc(); SetAnyVal(slot_desc, dst, staging_intermediate_val_); SetAnyVal(slot_desc, src, staging_merge_input_val_); // The merge fn always takes one input argument. - reinterpret_cast(agg_fn_.merge_fn_)(agg_fn_ctx_.get(), *staging_merge_input_val_, + reinterpret_cast(agg_fn_._merge_fn)(agg_fn_ctx_.get(), *staging_merge_input_val_, staging_intermediate_val_); SetDstSlot(staging_intermediate_val_, slot_desc, dst); } @@ -650,13 +649,3 @@ void NewAggFnEvaluator::ShallowClone(ObjectPool* pool, MemPool* mem_pool, cloned_evals->push_back(cloned_eval); } } - -// -//void NewAggFnEvaluator::FreeLocalAllocations() { -// ExprContext::FreeLocalAllocations(input_evals_); -// agg_fn_ctx_->impl()->FreeLocalAllocations(); -//} - -//void NewAggFnEvaluator::FreeLocalAllocations(const vector& evals) { -// for (NewAggFnEvaluator* eval : evals) eval->FreeLocalAllocations(); -//} diff --git a/be/src/exprs/new_agg_fn_evaluator_ir.cc b/be/src/exprs/new_agg_fn_evaluator_ir.cc deleted file mode 100644 index 21014e7f40..0000000000 --- a/be/src/exprs/new_agg_fn_evaluator_ir.cc +++ /dev/null @@ -1,31 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. -// This file is copied from -// https://github.com/apache/impala/blob/branch-2.10.0/be/src/exprs/agg-fn-evaluator-ir.cc -// and modified by Doris - -#include "exprs/new_agg_fn_evaluator.h" - -using namespace doris; - -FunctionContext* NewAggFnEvaluator::agg_fn_ctx() const { - return agg_fn_ctx_.get(); -} - -ExprContext* const* NewAggFnEvaluator::input_evals() const { - return input_evals_.data(); -} diff --git a/be/src/exprs/rpc_fn.cpp b/be/src/exprs/rpc_fn.cpp new file mode 100644 index 0000000000..63ddc93dfc --- /dev/null +++ b/be/src/exprs/rpc_fn.cpp @@ -0,0 +1,770 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "exprs/rpc_fn.h" + +#include + +#include "runtime/fragment_mgr.h" +#include "runtime/user_function_cache.h" +#include "service/brpc.h" +#include "util/brpc_client_cache.h" +#include "vec/columns/column.h" +#include "vec/columns/column_vector.h" +#include "vec/core/block.h" +#include "vec/core/column_numbers.h" +#include "vec/data_types/data_type_bitmap.h" +#include "vec/data_types/data_type_date.h" +#include "vec/data_types/data_type_date_time.h" +#include "vec/data_types/data_type_decimal.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/data_types/data_type_number.h" +#include "vec/data_types/data_type_string.h" + +namespace doris { + +RPCFn::RPCFn(RuntimeState* state, const TFunction& fn, int fn_ctx_id, bool is_agg) + : _state(state), _fn(fn), _fn_ctx_id(fn_ctx_id), _is_agg(is_agg) { + _client = ExecEnv::GetInstance()->brpc_function_client_cache()->get_client(_server_addr); + if (!_is_agg) { + _function_name = _fn.scalar_fn.symbol; + _server_addr = _fn.hdfs_location; + _signature = fmt::format("{}: [{}/{}]", _fn.name.function_name, _fn.hdfs_location, + _fn.scalar_fn.symbol); + } +} + +RPCFn::RPCFn(const TFunction& fn, bool is_agg) : RPCFn(nullptr, fn, -1, is_agg) {} + +RPCFn::RPCFn(RuntimeState* state, const TFunction& fn, AggregationStep step, bool is_agg) + : RPCFn(nullptr, fn, -1, is_agg) { + _step = step; + DCHECK(is_agg) << "Only used for agg fns"; + switch (_step) { + case INIT: { + _function_name = _fn.aggregate_fn.init_fn_symbol; + _server_addr = _fn.hdfs_location; + _signature = fmt::format("{}: [{}/{}]", _fn.name.function_name, _fn.hdfs_location, + _fn.aggregate_fn.init_fn_symbol); + break; + } + case UPDATE: { + _function_name = _fn.aggregate_fn.init_fn_symbol; + break; + } + case MERGE: { + _function_name = _fn.aggregate_fn.merge_fn_symbol; + break; + } + case SERIALIZE: { + _function_name = _fn.aggregate_fn.serialize_fn_symbol; + break; + } + case GET_VALUE: { + _function_name = _fn.aggregate_fn.get_value_fn_symbol; + break; + } + case FINALIZE: { + _function_name = _fn.aggregate_fn.finalize_fn_symbol; + break; + } + case REMOVE: { + _function_name = _fn.aggregate_fn.remove_fn_symbol; + break; + } + + default: + CHECK(false) << "invalid AggregationStep: " << _step; + break; + } + _server_addr = _fn.hdfs_location; + _signature = fmt::format("{}: [{}/{}]", _fn.name.function_name, _server_addr, _function_name); +} + +Status RPCFn::call_internal(ExprContext* context, TupleRow* row, PFunctionCallResponse* response, + const std::vector& exprs) { + FunctionContext* fn_ctx = context->fn_context(_fn_ctx_id); + PFunctionCallRequest request; + request.set_function_name(_function_name); + for (int i = 0; i < exprs.size(); ++i) { + PValues* arg = request.add_args(); + void* src_slot = context->get_value(exprs[i], row); + PGenericType* ptype = arg->mutable_type(); + if (src_slot == nullptr) { + arg->set_has_null(true); + arg->add_null_map(true); + } else { + arg->set_has_null(false); + } + switch (exprs[i]->type().type) { + case TYPE_BOOLEAN: { + ptype->set_id(PGenericType::BOOLEAN); + arg->add_bool_value(*(bool*)src_slot); + break; + } + case TYPE_TINYINT: { + ptype->set_id(PGenericType::INT8); + arg->add_int32_value(*(int8_t*)src_slot); + break; + } + case TYPE_SMALLINT: { + ptype->set_id(PGenericType::INT16); + arg->add_int32_value(*(int16_t*)src_slot); + break; + } + case TYPE_INT: { + ptype->set_id(PGenericType::INT32); + arg->add_int32_value(*(int*)src_slot); + break; + } + case TYPE_BIGINT: { + ptype->set_id(PGenericType::INT64); + arg->add_int64_value(*(int64_t*)src_slot); + break; + } + case TYPE_LARGEINT: { + ptype->set_id(PGenericType::INT128); + char buffer[sizeof(__int128)]; + memcpy(buffer, src_slot, sizeof(__int128)); + arg->add_bytes_value(buffer, sizeof(__int128)); + break; + } + case TYPE_DOUBLE: { + ptype->set_id(PGenericType::DOUBLE); + arg->add_double_value(*(double*)src_slot); + break; + } + case TYPE_FLOAT: { + ptype->set_id(PGenericType::FLOAT); + arg->add_float_value(*(float*)src_slot); + break; + } + case TYPE_VARCHAR: + case TYPE_STRING: + case TYPE_CHAR: { + ptype->set_id(PGenericType::STRING); + StringValue value = *reinterpret_cast(src_slot); + arg->add_string_value(value.ptr, value.len); + break; + } + case TYPE_HLL: { + ptype->set_id(PGenericType::HLL); + StringValue value = *reinterpret_cast(src_slot); + arg->add_string_value(value.ptr, value.len); + break; + } + case TYPE_OBJECT: { + ptype->set_id(PGenericType::BITMAP); + StringValue value = *reinterpret_cast(src_slot); + arg->add_string_value(value.ptr, value.len); + break; + } + case TYPE_DECIMALV2: { + ptype->set_id(PGenericType::DECIMAL128); + ptype->mutable_decimal_type()->set_precision(exprs[i]->type().precision); + ptype->mutable_decimal_type()->set_scale(exprs[i]->type().scale); + char buffer[sizeof(__int128)]; + memcpy(buffer, src_slot, sizeof(__int128)); + arg->add_bytes_value(buffer, sizeof(__int128)); + break; + } + case TYPE_DATE: { + ptype->set_id(PGenericType::DATE); + const auto* time_val = (const DateTimeValue*)(src_slot); + PDateTime* date_time = arg->add_datetime_value(); + date_time->set_day(time_val->day()); + date_time->set_month(time_val->month()); + date_time->set_year(time_val->year()); + break; + } + case TYPE_DATETIME: { + ptype->set_id(PGenericType::DATETIME); + const auto* time_val = (const DateTimeValue*)(src_slot); + PDateTime* date_time = arg->add_datetime_value(); + date_time->set_day(time_val->day()); + date_time->set_month(time_val->month()); + date_time->set_year(time_val->year()); + date_time->set_hour(time_val->hour()); + date_time->set_minute(time_val->minute()); + date_time->set_second(time_val->second()); + date_time->set_microsecond(time_val->microsecond()); + break; + } + case TYPE_TIME: { + ptype->set_id(PGenericType::DATETIME); + const auto* time_val = (const DateTimeValue*)(src_slot); + PDateTime* date_time = arg->add_datetime_value(); + date_time->set_hour(time_val->hour()); + date_time->set_minute(time_val->minute()); + date_time->set_second(time_val->second()); + date_time->set_microsecond(time_val->microsecond()); + break; + } + default: { + std::string error_msg = + fmt::format("data time not supported: {}", exprs[i]->type().type); + fn_ctx->set_error(error_msg.c_str()); + cancel(error_msg); + break; + } + } + } + + brpc::Controller cntl; + _client->fn_call(&cntl, &request, response, nullptr); + if (cntl.Failed()) { + std::string error_msg = + fmt::format("call rpc function {} failed: {}", _signature, cntl.ErrorText()); + fn_ctx->set_error(error_msg.c_str()); + cancel(error_msg); + return Status::InternalError(error_msg); + } + if (!response->has_status() || response->result_size() == 0) { + std::string error_msg = + fmt::format("call rpc function {} failed: status or result is not set: {}", + _signature, response->status().DebugString()); + fn_ctx->set_error(error_msg.c_str()); + cancel(error_msg); + return Status::InternalError(error_msg); + } + if (response->status().status_code() != 0) { + std::string error_msg = fmt::format("call rpc function {} failed: {}", _signature, + response->status().DebugString()); + fn_ctx->set_error(error_msg.c_str()); + cancel(error_msg); + return Status::InternalError(error_msg); + } + return Status::OK(); +} + +void RPCFn::cancel(const std::string& msg) { + _state->exec_env()->fragment_mgr()->cancel(_state->fragment_instance_id(), + PPlanFragmentCancelReason::CALL_RPC_ERROR, msg); +} + +template +void convert_col_to_pvalue(const vectorized::ColumnPtr& column, + const vectorized::DataTypePtr& data_type, PValues* arg, + size_t row_count) { + PGenericType* ptype = arg->mutable_type(); + switch (data_type->get_type_id()) { + case vectorized::TypeIndex::UInt8: { + ptype->set_id(PGenericType::UINT8); + auto* values = arg->mutable_bool_value(); + values->Reserve(row_count); + const auto* col = vectorized::check_and_get_column(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case vectorized::TypeIndex::UInt16: { + ptype->set_id(PGenericType::UINT16); + auto* values = arg->mutable_uint32_value(); + values->Reserve(row_count); + const auto* col = vectorized::check_and_get_column(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case vectorized::TypeIndex::UInt32: { + ptype->set_id(PGenericType::UINT32); + auto* values = arg->mutable_uint32_value(); + values->Reserve(row_count); + const auto* col = vectorized::check_and_get_column(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case vectorized::TypeIndex::UInt64: { + ptype->set_id(PGenericType::UINT64); + auto* values = arg->mutable_uint64_value(); + values->Reserve(row_count); + const auto* col = vectorized::check_and_get_column(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case vectorized::TypeIndex::UInt128: { + ptype->set_id(PGenericType::UINT128); + arg->mutable_bytes_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + if constexpr (nullable) { + if (column->is_null_at(row_num)) { + arg->add_bytes_value(nullptr); + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } + break; + } + case vectorized::TypeIndex::Int8: { + ptype->set_id(PGenericType::INT8); + auto* values = arg->mutable_int32_value(); + values->Reserve(row_count); + const auto* col = vectorized::check_and_get_column(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case vectorized::TypeIndex::Int16: { + ptype->set_id(PGenericType::INT16); + auto* values = arg->mutable_int32_value(); + values->Reserve(row_count); + const auto* col = vectorized::check_and_get_column(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case vectorized::TypeIndex::Int32: { + ptype->set_id(PGenericType::INT32); + auto* values = arg->mutable_int32_value(); + values->Reserve(row_count); + const auto* col = vectorized::check_and_get_column(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case vectorized::TypeIndex::Int64: { + ptype->set_id(PGenericType::INT64); + auto* values = arg->mutable_int64_value(); + values->Reserve(row_count); + const auto* col = vectorized::check_and_get_column(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case vectorized::TypeIndex::Int128: { + ptype->set_id(PGenericType::INT128); + arg->mutable_bytes_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + if constexpr (nullable) { + if (column->is_null_at(row_num)) { + arg->add_bytes_value(nullptr); + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } + break; + } + case vectorized::TypeIndex::Float32: { + ptype->set_id(PGenericType::FLOAT); + auto* values = arg->mutable_float_value(); + values->Reserve(row_count); + const auto* col = vectorized::check_and_get_column(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + + case vectorized::TypeIndex::Float64: { + ptype->set_id(PGenericType::DOUBLE); + auto* values = arg->mutable_double_value(); + values->Reserve(row_count); + const auto* col = vectorized::check_and_get_column(column); + auto& data = col->get_data(); + values->Add(data.begin(), data.begin() + row_count); + break; + } + case vectorized::TypeIndex::Decimal128: { + ptype->set_id(PGenericType::DECIMAL128); + auto dec_type = std::reinterpret_pointer_cast< + const vectorized::DataTypeDecimal>(data_type); + ptype->mutable_decimal_type()->set_precision(dec_type->get_precision()); + ptype->mutable_decimal_type()->set_scale(dec_type->get_scale()); + arg->mutable_bytes_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + if constexpr (nullable) { + if (column->is_null_at(row_num)) { + arg->add_bytes_value(nullptr); + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } + break; + } + case vectorized::TypeIndex::String: { + ptype->set_id(PGenericType::STRING); + arg->mutable_bytes_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + if constexpr (nullable) { + if (column->is_null_at(row_num)) { + arg->add_string_value(nullptr); + } else { + StringRef data = column->get_data_at(row_num); + arg->add_string_value(data.to_string()); + } + } else { + StringRef data = column->get_data_at(row_num); + arg->add_string_value(data.to_string()); + } + } + break; + } + case vectorized::TypeIndex::Date: { + ptype->set_id(PGenericType::DATE); + arg->mutable_datetime_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + PDateTime* date_time = arg->add_datetime_value(); + if constexpr (nullable) { + if (!column->is_null_at(row_num)) { + vectorized::VecDateTimeValue v = + vectorized::VecDateTimeValue::create_from_olap_date(column->get_int(row_num)); + date_time->set_day(v.day()); + date_time->set_month(v.month()); + date_time->set_year(v.year()); + } + } else { + vectorized::VecDateTimeValue v = + vectorized::VecDateTimeValue::create_from_olap_date(column->get_int(row_num)); + date_time->set_day(v.day()); + date_time->set_month(v.month()); + date_time->set_year(v.year()); + } + } + break; + } + case vectorized::TypeIndex::DateTime: { + ptype->set_id(PGenericType::DATETIME); + arg->mutable_datetime_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + PDateTime* date_time = arg->add_datetime_value(); + if constexpr (nullable) { + if (!column->is_null_at(row_num)) { + vectorized::VecDateTimeValue v = + vectorized::VecDateTimeValue::create_from_olap_datetime(column->get_int(row_num)); + date_time->set_day(v.day()); + date_time->set_month(v.month()); + date_time->set_year(v.year()); + date_time->set_hour(v.hour()); + date_time->set_minute(v.minute()); + date_time->set_second(v.second()); + } + } else { + vectorized::VecDateTimeValue v = + vectorized::VecDateTimeValue::create_from_olap_datetime(column->get_int(row_num)); + date_time->set_day(v.day()); + date_time->set_month(v.month()); + date_time->set_year(v.year()); + date_time->set_hour(v.hour()); + date_time->set_minute(v.minute()); + date_time->set_second(v.second()); + } + } + break; + } + case vectorized::TypeIndex::BitMap: { + ptype->set_id(PGenericType::BITMAP); + arg->mutable_bytes_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + if constexpr (nullable) { + if (column->is_null_at(row_num)) { + arg->add_bytes_value(nullptr); + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } + break; + } + case vectorized::TypeIndex::HLL: { + ptype->set_id(PGenericType::HLL); + arg->mutable_bytes_value()->Reserve(row_count); + for (size_t row_num = 0; row_num < row_count; ++row_num) { + if constexpr (nullable) { + if (column->is_null_at(row_num)) { + arg->add_bytes_value(nullptr); + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } else { + StringRef data = column->get_data_at(row_num); + arg->add_bytes_value(data.data, data.size); + } + } + break; + } + default: + LOG(INFO) << "unknown type: " << data_type->get_name(); + ptype->set_id(PGenericType::UNKNOWN); + break; + } +} + +void convert_nullable_col_to_pvalue(const vectorized::ColumnPtr& column, + const vectorized::DataTypePtr& data_type, + const vectorized::ColumnUInt8& null_col, PValues* arg, + size_t row_count) { + if (column->has_null(row_count)) { + auto* null_map = arg->mutable_null_map(); + null_map->Reserve(row_count); + const auto* col = vectorized::check_and_get_column(null_col); + auto& data = col->get_data(); + null_map->Add(data.begin(), data.begin() + row_count); + convert_col_to_pvalue(column, data_type, arg, row_count); + } else { + convert_col_to_pvalue(column, data_type, arg, row_count); + } +} + +void convert_block_to_proto(vectorized::Block& block, const vectorized::ColumnNumbers& arguments, + size_t input_rows_count, PFunctionCallRequest* request) { + size_t row_count = std::min(block.rows(), input_rows_count); + for (size_t col_idx : arguments) { + PValues* arg = request->add_args(); + vectorized::ColumnWithTypeAndName& column = block.get_by_position(col_idx); + arg->set_has_null(column.column->has_null(row_count)); + auto col = column.column->convert_to_full_column_if_const(); + if (auto* nullable = + vectorized::check_and_get_column(*col)) { + auto data_col = nullable->get_nested_column_ptr(); + auto& null_col = nullable->get_null_map_column(); + auto data_type = + std::reinterpret_pointer_cast(column.type); + convert_nullable_col_to_pvalue(data_col->convert_to_full_column_if_const(), + data_type->get_nested_type(), null_col, arg, row_count); + } else { + convert_col_to_pvalue(col, column.type, arg, row_count); + } + } +} + +template +void convert_to_column(vectorized::MutableColumnPtr& column, const PValues& result) { + switch (result.type().id()) { + case PGenericType::UINT8: { + column->reserve(result.uint32_value_size()); + column->resize(result.uint32_value_size()); + auto& data = reinterpret_cast(column.get())->get_data(); + for (int i = 0; i < result.uint32_value_size(); ++i) { + data[i] = result.uint32_value(i); + } + break; + } + case PGenericType::UINT16: { + column->reserve(result.uint32_value_size()); + column->resize(result.uint32_value_size()); + auto& data = reinterpret_cast(column.get())->get_data(); + for (int i = 0; i < result.uint32_value_size(); ++i) { + data[i] = result.uint32_value(i); + } + break; + } + case PGenericType::UINT32: { + column->reserve(result.uint32_value_size()); + column->resize(result.uint32_value_size()); + auto& data = reinterpret_cast(column.get())->get_data(); + for (int i = 0; i < result.uint32_value_size(); ++i) { + data[i] = result.uint32_value(i); + } + break; + } + case PGenericType::UINT64: { + column->reserve(result.uint64_value_size()); + column->resize(result.uint64_value_size()); + auto& data = reinterpret_cast(column.get())->get_data(); + for (int i = 0; i < result.uint64_value_size(); ++i) { + data[i] = result.uint64_value(i); + } + break; + } + case PGenericType::INT8: { + column->reserve(result.int32_value_size()); + column->resize(result.int32_value_size()); + auto& data = reinterpret_cast(column.get())->get_data(); + for (int i = 0; i < result.int32_value_size(); ++i) { + data[i] = result.int32_value(i); + } + break; + } + case PGenericType::INT16: { + column->reserve(result.int32_value_size()); + column->resize(result.int32_value_size()); + auto& data = reinterpret_cast(column.get())->get_data(); + for (int i = 0; i < result.int32_value_size(); ++i) { + data[i] = result.int32_value(i); + } + break; + } + case PGenericType::INT32: { + column->reserve(result.int32_value_size()); + column->resize(result.int32_value_size()); + auto& data = reinterpret_cast(column.get())->get_data(); + for (int i = 0; i < result.int32_value_size(); ++i) { + data[i] = result.int32_value(i); + } + break; + } + case PGenericType::INT64: { + column->reserve(result.int64_value_size()); + column->resize(result.int64_value_size()); + auto& data = reinterpret_cast(column.get())->get_data(); + for (int i = 0; i < result.int64_value_size(); ++i) { + data[i] = result.int64_value(i); + } + break; + } + case PGenericType::DATE: + case PGenericType::DATETIME: { + column->reserve(result.datetime_value_size()); + column->resize(result.datetime_value_size()); + auto& data = reinterpret_cast(column.get())->get_data(); + for (int i = 0; i < result.datetime_value_size(); ++i) { + vectorized::VecDateTimeValue v; + PDateTime pv = result.datetime_value(i); + v.set_time(pv.year(), pv.month(), pv.day(), pv.hour(), pv.minute(), pv.minute()); + data[i] = binary_cast(v); + } + break; + } + case PGenericType::FLOAT: { + column->reserve(result.float_value_size()); + column->resize(result.float_value_size()); + auto& data = reinterpret_cast(column.get())->get_data(); + for (int i = 0; i < result.float_value_size(); ++i) { + data[i] = result.float_value(i); + } + break; + } + case PGenericType::DOUBLE: { + column->reserve(result.double_value_size()); + column->resize(result.double_value_size()); + auto& data = reinterpret_cast(column.get())->get_data(); + for (int i = 0; i < result.double_value_size(); ++i) { + data[i] = result.double_value(i); + } + break; + } + case PGenericType::INT128: { + column->reserve(result.bytes_value_size()); + column->resize(result.bytes_value_size()); + auto& data = reinterpret_cast(column.get())->get_data(); + for (int i = 0; i < result.bytes_value_size(); ++i) { + data[i] = *(int128_t*)(result.bytes_value(i).c_str()); + } + break; + } + case PGenericType::STRING: { + column->reserve(result.string_value_size()); + for (int i = 0; i < result.string_value_size(); ++i) { + column->insert_data(result.string_value(i).c_str(), result.string_value(i).size()); + } + break; + } + case PGenericType::DECIMAL128: { + column->reserve(result.bytes_value_size()); + column->resize(result.bytes_value_size()); + auto& data = reinterpret_cast(column.get())->get_data(); + for (int i = 0; i < result.bytes_value_size(); ++i) { + data[i] = *(int128_t*)(result.bytes_value(i).c_str()); + } + break; + } + case PGenericType::BITMAP: { + column->reserve(result.bytes_value_size()); + for (int i = 0; i < result.bytes_value_size(); ++i) { + column->insert_data(result.bytes_value(i).c_str(), result.bytes_value(i).size()); + } + break; + } + case PGenericType::HLL: { + column->reserve(result.bytes_value_size()); + for (int i = 0; i < result.bytes_value_size(); ++i) { + column->insert_data(result.bytes_value(i).c_str(), result.bytes_value(i).size()); + } + break; + } + default: { + LOG(WARNING) << "unknown PGenericType: " << result.type().DebugString(); + break; + } + } +} + +void convert_to_block(vectorized::Block& block, const PValues& result, size_t pos) { + auto data_type = block.get_data_type(pos); + if (data_type->is_nullable()) { + auto null_type = + std::reinterpret_pointer_cast(data_type); + auto data_col = null_type->get_nested_type()->create_column(); + convert_to_column(data_col, result); + auto null_col = vectorized::ColumnUInt8::create(data_col->size(), 0); + auto& null_map_data = null_col->get_data(); + null_col->reserve(data_col->size()); + null_col->resize(data_col->size()); + if (result.has_null()) { + for (int i = 0; i < data_col->size(); ++i) { + null_map_data[i] = result.null_map(i); + } + } else { + for (int i = 0; i < data_col->size(); ++i) { + null_map_data[i] = false; + } + } + block.replace_by_position( + pos, vectorized::ColumnNullable::create(std::move(data_col), std::move(null_col))); + } else { + auto column = data_type->create_column(); + convert_to_column(column, result); + block.replace_by_position(pos, std::move(column)); + } +} +Status RPCFn::vec_call(FunctionContext* context, vectorized::Block& block, + const vectorized::ColumnNumbers& arguments, size_t result, + size_t input_rows_count) { + PFunctionCallRequest request; + PFunctionCallResponse response; + request.set_function_name(_function_name); + convert_block_to_proto(block, arguments, input_rows_count, &request); + brpc::Controller cntl; + _client->fn_call(&cntl, &request, &response, nullptr); + if (cntl.Failed()) { + return Status::InternalError( + fmt::format("call to rpc function {} failed: {}", _signature, cntl.ErrorText()) + .c_str()); + } + if (!response.has_status() || response.result_size() == 0) { + return Status::InternalError(fmt::format( + "call rpc function {} failed: status or result is not set.", _signature)); + } + if (response.status().status_code() != 0) { + return Status::InternalError(fmt::format("call to rpc function {} failed: {}", _signature, + response.status().DebugString())); + } + convert_to_block(block, response.result(0), result); + return Status::OK(); +} +} // namespace doris diff --git a/be/src/exprs/rpc_fn.h b/be/src/exprs/rpc_fn.h new file mode 100644 index 0000000000..154f158640 --- /dev/null +++ b/be/src/exprs/rpc_fn.h @@ -0,0 +1,136 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "common/status.h" +#include "exprs/expr.h" +#include "exprs/expr_context.h" +#include "gen_cpp/function_service.pb.h" +#include "runtime/runtime_state.h" +#include "udf/udf.h" + +namespace doris { +namespace vectorized { +class Block; +} // namespace vectorized + +class RPCFn { +public: + enum AggregationStep { + INIT = 0, + UPDATE = 1, + MERGE = 2, + REMOVE = 3, + SERIALIZE = 4, + GET_VALUE = 5, + FINALIZE = 6, + INVALID = 999, + }; + + RPCFn(RuntimeState* state, const TFunction& fn, int fn_ctx_id, bool is_agg); + RPCFn(const TFunction& fn, bool is_agg); + RPCFn(RuntimeState* state, const TFunction& fn, AggregationStep step, bool is_agg); + ~RPCFn() {} + template + T call(ExprContext* context, TupleRow* row, const std::vector& exprs); + Status vec_call(FunctionContext* context, vectorized::Block& block, + const std::vector& arguments, size_t result, size_t input_rows_count); + bool avliable() { return _client != nullptr; } + +private: + Status call_internal(ExprContext* context, TupleRow* row, PFunctionCallResponse* response, + const std::vector& exprs); + void cancel(const std::string& msg); + + std::shared_ptr _client; + RuntimeState* _state; + std::string _function_name; + std::string _server_addr; + std::string _signature; + TFunction _fn; + int _fn_ctx_id; + bool _is_agg; + AggregationStep _step = AggregationStep::INVALID; +}; + +template +T RPCFn::call(ExprContext* context, TupleRow* row, const std::vector& exprs) { + PFunctionCallResponse response; + Status st = call_internal(context, row, &response, exprs); + WARN_IF_ERROR(st, "call rpc udf error"); + if (!st.ok() || (response.result(0).has_null() && response.result(0).null_map(0))) { + return T::null(); + } + T res_val; + // TODO(yangzhg) deal with udtf and udaf + const PValues& result = response.result(0); + if constexpr (std::is_same_v) { + DCHECK(result.type().id() == PGenericType::INT8); + res_val.val = static_cast(result.int32_value(0)); + } else if constexpr (std::is_same_v) { + DCHECK(result.type().id() == PGenericType::INT16); + res_val.val = static_cast(result.int32_value(0)); + } else if constexpr (std::is_same_v) { + DCHECK(result.type().id() == PGenericType::INT32); + res_val.val = result.int32_value(0); + } else if constexpr (std::is_same_v) { + DCHECK(result.type().id() == PGenericType::INT64); + res_val.val = result.int64_value(0); + } else if constexpr (std::is_same_v) { + DCHECK(result.type().id() == PGenericType::FLOAT); + res_val.val = result.float_value(0); + } else if constexpr (std::is_same_v) { + DCHECK(result.type().id() == PGenericType::DOUBLE); + res_val.val = result.double_value(0); + } else if constexpr (std::is_same_v) { + DCHECK(result.type().id() == PGenericType::STRING); + auto* fn_ctx = context->fn_context(_fn_ctx_id); + StringVal val(fn_ctx, result.string_value(0).size()); + res_val = val.copy_from(fn_ctx, + reinterpret_cast(result.string_value(0).c_str()), + result.string_value(0).size()); + } else if constexpr (std::is_same_v) { + DCHECK(result.type().id() == PGenericType::INT128); + memcpy(&(res_val.val), result.bytes_value(0).data(), sizeof(__int128_t)); + } else if constexpr (std::is_same_v) { + DCHECK(result.type().id() == PGenericType::DATE || + result.type().id() == PGenericType::DATETIME); + DateTimeValue value; + value.set_time(result.datetime_value(0).year(), result.datetime_value(0).month(), + result.datetime_value(0).day(), result.datetime_value(0).hour(), + result.datetime_value(0).minute(), result.datetime_value(0).second(), + result.datetime_value(0).microsecond()); + if (result.type().id() == PGenericType::DATE) { + value.set_type(TimeType::TIME_DATE); + } else if (result.type().id() == PGenericType::DATETIME) { + if (result.datetime_value(0).has_year()) { + value.set_type(TimeType::TIME_DATETIME); + } else + value.set_type(TimeType::TIME_TIME); + } + value.to_datetime_val(&res_val); + } else if constexpr (std::is_same_v) { + DCHECK(result.type().id() == PGenericType::DECIMAL128); + memcpy(&(res_val.val), result.bytes_value(0).data(), sizeof(__int128_t)); + } + return res_val; +} +} // namespace doris \ No newline at end of file diff --git a/be/src/exprs/rpc_fn_call.cpp b/be/src/exprs/rpc_fn_call.cpp index b8939af7d4..8d779f0ce8 100644 --- a/be/src/exprs/rpc_fn_call.cpp +++ b/be/src/exprs/rpc_fn_call.cpp @@ -19,20 +19,20 @@ #include "exprs/anyval_util.h" #include "exprs/expr_context.h" +#include "exprs/rpc_fn.h" #include "fmt/format.h" -#include "gen_cpp/function_service.pb.h" -#include "runtime/fragment_mgr.h" +#include "rpc_fn.h" #include "runtime/runtime_state.h" #include "runtime/user_function_cache.h" -#include "service/brpc.h" -#include "util/brpc_client_cache.h" namespace doris { -RPCFnCall::RPCFnCall(const TExprNode& node) : Expr(node), _fn_context_index(-1) { +RPCFnCall::RPCFnCall(const TExprNode& node) : Expr(node), _tnode(node) { DCHECK_EQ(_fn.binary_type, TFunctionBinaryType::RPC); } +RPCFnCall::~RPCFnCall() {} + Status RPCFnCall::prepare(RuntimeState* state, const RowDescriptor& desc, ExprContext* context) { RETURN_IF_ERROR(Expr::prepare(state, desc, context)); DCHECK(!_fn.scalar_fn.symbol.empty()); @@ -44,16 +44,12 @@ Status RPCFnCall::prepare(RuntimeState* state, const RowDescriptor& desc, ExprCo arg_types.push_back(AnyValUtil::column_type_to_type_desc(_children[i]->type())); char_arg = char_arg || (_children[i]->type().type == TYPE_CHAR); } - _fn_context_index = context->register_func(state, return_type, arg_types, 0); + int id = context->register_func(state, return_type, arg_types, 0); - // _fn.scalar_fn.symbol - _rpc_function_symbol = _fn.scalar_fn.symbol; - - _client = state->exec_env()->brpc_function_client_cache()->get_client(_fn.hdfs_location); - - if (_client == nullptr) { + _rpc_fn = std::make_unique(state, _fn, id, false); + if (!_rpc_fn->avliable()) { return Status::InternalError( - fmt::format("rpc env init error: {}/{}", _fn.hdfs_location, _rpc_function_symbol)); + fmt::format("rpc env init error: {}/{}", _fn.hdfs_location, _fn.scalar_fn.symbol)); } return Status::OK(); } @@ -61,7 +57,6 @@ Status RPCFnCall::prepare(RuntimeState* state, const RowDescriptor& desc, ExprCo Status RPCFnCall::open(RuntimeState* state, ExprContext* ctx, FunctionContext::FunctionStateScope scope) { RETURN_IF_ERROR(Expr::open(state, ctx, scope)); - _state = state; return Status::OK(); } @@ -70,276 +65,50 @@ void RPCFnCall::close(RuntimeState* state, ExprContext* context, Expr::close(state, context, scope); } -Status RPCFnCall::call_rpc(ExprContext* context, TupleRow* row, PFunctionCallResponse* response) { - PFunctionCallRequest request; - request.set_function_name(_rpc_function_symbol); - for (int i = 0; i < _children.size(); ++i) { - PValues* arg = request.add_args(); - void* src_slot = context->get_value(_children[i], row); - PGenericType* ptype = arg->mutable_type(); - if (src_slot == nullptr) { - arg->set_has_null(true); - arg->add_null_map(true); - } else { - arg->set_has_null(false); - } - switch (_children[i]->type().type) { - case TYPE_BOOLEAN: { - ptype->set_id(PGenericType::BOOLEAN); - arg->add_bool_value(*(bool*)src_slot); - break; - } - case TYPE_TINYINT: { - ptype->set_id(PGenericType::INT8); - arg->add_int32_value(*(int8_t*)src_slot); - break; - } - case TYPE_SMALLINT: { - ptype->set_id(PGenericType::INT16); - arg->add_int32_value(*(int16_t*)src_slot); - break; - } - case TYPE_INT: { - ptype->set_id(PGenericType::INT32); - arg->add_int32_value(*(int*)src_slot); - break; - } - case TYPE_BIGINT: { - ptype->set_id(PGenericType::INT64); - arg->add_int64_value(*(int64_t*)src_slot); - break; - } - case TYPE_LARGEINT: { - ptype->set_id(PGenericType::INT128); - char buffer[sizeof(__int128)]; - memcpy(buffer, src_slot, sizeof(__int128)); - arg->add_bytes_value(buffer, sizeof(__int128)); - break; - } - case TYPE_DOUBLE: { - ptype->set_id(PGenericType::DOUBLE); - arg->add_double_value(*(double*)src_slot); - break; - } - case TYPE_FLOAT: { - ptype->set_id(PGenericType::FLOAT); - arg->add_float_value(*(float*)src_slot); - break; - } - case TYPE_VARCHAR: - case TYPE_STRING: - case TYPE_CHAR: { - ptype->set_id(PGenericType::STRING); - StringValue value = *reinterpret_cast(src_slot); - arg->add_string_value(value.ptr, value.len); - break; - } - case TYPE_HLL: { - ptype->set_id(PGenericType::HLL); - StringValue value = *reinterpret_cast(src_slot); - arg->add_string_value(value.ptr, value.len); - break; - } - case TYPE_OBJECT: { - ptype->set_id(PGenericType::BITMAP); - StringValue value = *reinterpret_cast(src_slot); - arg->add_string_value(value.ptr, value.len); - break; - } - case TYPE_DECIMALV2: { - ptype->set_id(PGenericType::DECIMAL128); - ptype->mutable_decimal_type()->set_precision(_children[i]->type().precision); - ptype->mutable_decimal_type()->set_scale(_children[i]->type().scale); - char buffer[sizeof(__int128)]; - memcpy(buffer, src_slot, sizeof(__int128)); - arg->add_bytes_value(buffer, sizeof(__int128)); - break; - } - case TYPE_DATE: { - ptype->set_id(PGenericType::DATE); - const auto* time_val = (const DateTimeValue*)(src_slot); - PDateTime* date_time = arg->add_datetime_value(); - date_time->set_day(time_val->day()); - date_time->set_month(time_val->month()); - date_time->set_year(time_val->year()); - break; - } - case TYPE_DATETIME: { - ptype->set_id(PGenericType::DATETIME); - const auto* time_val = (const DateTimeValue*)(src_slot); - PDateTime* date_time = arg->add_datetime_value(); - date_time->set_day(time_val->day()); - date_time->set_month(time_val->month()); - date_time->set_year(time_val->year()); - date_time->set_hour(time_val->hour()); - date_time->set_minute(time_val->minute()); - date_time->set_second(time_val->second()); - date_time->set_microsecond(time_val->microsecond()); - break; - } - case TYPE_TIME: { - ptype->set_id(PGenericType::DATETIME); - const auto* time_val = (const DateTimeValue*)(src_slot); - PDateTime* date_time = arg->add_datetime_value(); - date_time->set_hour(time_val->hour()); - date_time->set_minute(time_val->minute()); - date_time->set_second(time_val->second()); - date_time->set_microsecond(time_val->microsecond()); - break; - } - default: { - FunctionContext* fn_ctx = context->fn_context(_fn_context_index); - std::string error_msg = - fmt::format("data time not supported: {}", _children[i]->type().type); - fn_ctx->set_error(error_msg.c_str()); - cancel(error_msg); - break; - } - } - } - - brpc::Controller cntl; - _client->fn_call(&cntl, &request, response, nullptr); - if (cntl.Failed()) { - FunctionContext* fn_ctx = context->fn_context(_fn_context_index); - std::string error_msg = fmt::format("call rpc function {} failed: {}", _rpc_function_symbol, - cntl.ErrorText()); - fn_ctx->set_error(error_msg.c_str()); - cancel(error_msg); - return Status::InternalError(error_msg); - } - if (!response->has_status() || !response->has_result()) { - FunctionContext* fn_ctx = context->fn_context(_fn_context_index); - std::string error_msg = - fmt::format("call rpc function {} failed: status or result is not set: {}", - _rpc_function_symbol, response->status().DebugString()); - fn_ctx->set_error(error_msg.c_str()); - cancel(error_msg); - return Status::InternalError(error_msg); - } - if (response->status().status_code() != 0) { - FunctionContext* fn_ctx = context->fn_context(_fn_context_index); - std::string error_msg = fmt::format("call rpc function {} failed: {}", _rpc_function_symbol, - response->status().DebugString()); - fn_ctx->set_error(error_msg.c_str()); - cancel(error_msg); - return Status::InternalError(error_msg); - } - return Status::OK(); -} - -template -T RPCFnCall::interpret_eval(ExprContext* context, TupleRow* row) { - PFunctionCallResponse response; - Status st = call_rpc(context, row, &response); - WARN_IF_ERROR(st, "call rpc udf error"); - if (!st.ok() || (response.result().has_null() && response.result().null_map(0))) { - return T::null(); - } - T res_val; - // TODO(yangzhg) deal with udtf and udaf - const PValues& result = response.result(); - if constexpr (std::is_same_v) { - DCHECK(result.type().id() == PGenericType::INT8); - res_val.val = static_cast(result.int32_value(0)); - } else if constexpr (std::is_same_v) { - DCHECK(result.type().id() == PGenericType::INT16); - res_val.val = static_cast(result.int32_value(0)); - } else if constexpr (std::is_same_v) { - DCHECK(result.type().id() == PGenericType::INT32); - res_val.val = result.int32_value(0); - } else if constexpr (std::is_same_v) { - DCHECK(result.type().id() == PGenericType::INT64); - res_val.val = result.int64_value(0); - } else if constexpr (std::is_same_v) { - DCHECK(result.type().id() == PGenericType::FLOAT); - res_val.val = result.float_value(0); - } else if constexpr (std::is_same_v) { - DCHECK(result.type().id() == PGenericType::DOUBLE); - res_val.val = result.double_value(0); - } else if constexpr (std::is_same_v) { - DCHECK(result.type().id() == PGenericType::STRING); - FunctionContext* fn_ctx = context->fn_context(_fn_context_index); - StringVal val(fn_ctx, result.string_value(0).size()); - res_val = val.copy_from(fn_ctx, - reinterpret_cast(result.string_value(0).c_str()), - result.string_value(0).size()); - } else if constexpr (std::is_same_v) { - DCHECK(result.type().id() == PGenericType::INT128); - memcpy(&(res_val.val), result.bytes_value(0).data(), sizeof(__int128_t)); - } else if constexpr (std::is_same_v) { - DCHECK(result.type().id() == PGenericType::DATE || - result.type().id() == PGenericType::DATETIME); - DateTimeValue value; - value.set_time(result.datetime_value(0).year(), result.datetime_value(0).month(), - result.datetime_value(0).day(), result.datetime_value(0).hour(), - result.datetime_value(0).minute(), result.datetime_value(0).second(), - result.datetime_value(0).microsecond()); - if (result.type().id() == PGenericType::DATE) { - value.set_type(TimeType::TIME_DATE); - } else if (result.type().id() == PGenericType::DATETIME) { - if (result.datetime_value(0).has_year()) { - value.set_type(TimeType::TIME_DATETIME); - } else - value.set_type(TimeType::TIME_TIME); - } - value.to_datetime_val(&res_val); - } else if constexpr (std::is_same_v) { - DCHECK(result.type().id() == PGenericType::DECIMAL128); - memcpy(&(res_val.val), result.bytes_value(0).data(), sizeof(__int128_t)); - } - return res_val; -} // namespace doris - doris_udf::IntVal RPCFnCall::get_int_val(ExprContext* context, TupleRow* row) { - return interpret_eval(context, row); + return _rpc_fn->call(context, row, _children); } doris_udf::BooleanVal RPCFnCall::get_boolean_val(ExprContext* context, TupleRow* row) { - return interpret_eval(context, row); + return _rpc_fn->call(context, row, _children); } doris_udf::TinyIntVal RPCFnCall::get_tiny_int_val(ExprContext* context, TupleRow* row) { - return interpret_eval(context, row); + return _rpc_fn->call(context, row, _children); } doris_udf::SmallIntVal RPCFnCall::get_small_int_val(ExprContext* context, TupleRow* row) { - return interpret_eval(context, row); + return _rpc_fn->call(context, row, _children); } doris_udf::BigIntVal RPCFnCall::get_big_int_val(ExprContext* context, TupleRow* row) { - return interpret_eval(context, row); + return _rpc_fn->call(context, row, _children); } doris_udf::FloatVal RPCFnCall::get_float_val(ExprContext* context, TupleRow* row) { - return interpret_eval(context, row); + return _rpc_fn->call(context, row, _children); } doris_udf::DoubleVal RPCFnCall::get_double_val(ExprContext* context, TupleRow* row) { - return interpret_eval(context, row); + return _rpc_fn->call(context, row, _children); } doris_udf::StringVal RPCFnCall::get_string_val(ExprContext* context, TupleRow* row) { - return interpret_eval(context, row); + return _rpc_fn->call(context, row, _children); } doris_udf::LargeIntVal RPCFnCall::get_large_int_val(ExprContext* context, TupleRow* row) { - return interpret_eval(context, row); + return _rpc_fn->call(context, row, _children); } doris_udf::DateTimeVal RPCFnCall::get_datetime_val(ExprContext* context, TupleRow* row) { - return interpret_eval(context, row); + return _rpc_fn->call(context, row, _children); } doris_udf::DecimalV2Val RPCFnCall::get_decimalv2_val(ExprContext* context, TupleRow* row) { - return interpret_eval(context, row); + return _rpc_fn->call(context, row, _children); } doris_udf::CollectionVal RPCFnCall::get_array_val(ExprContext* context, TupleRow* row) { - return interpret_eval(context, row); + return _rpc_fn->call(context, row, _children); } -void RPCFnCall::cancel(const std::string& msg) { - _state->exec_env()->fragment_mgr()->cancel(_state->fragment_instance_id(), - PPlanFragmentCancelReason::CALL_RPC_ERROR, msg); -} - } // namespace doris diff --git a/be/src/exprs/rpc_fn_call.h b/be/src/exprs/rpc_fn_call.h index b534c0c68b..d63fb2db0e 100644 --- a/be/src/exprs/rpc_fn_call.h +++ b/be/src/exprs/rpc_fn_call.h @@ -23,13 +23,12 @@ namespace doris { class TExprNode; -class PFunctionService_Stub; -class PFunctionCallResponse; +class RPCFn; class RPCFnCall : public Expr { public: RPCFnCall(const TExprNode& node); - ~RPCFnCall() = default; + ~RPCFnCall(); virtual Status prepare(RuntimeState* state, const RowDescriptor& desc, ExprContext* context) override; @@ -37,7 +36,9 @@ public: FunctionContext::FunctionStateScope scope) override; virtual void close(RuntimeState* state, ExprContext* context, FunctionContext::FunctionStateScope scope) override; - virtual Expr* clone(ObjectPool* pool) const override { return pool->add(new RPCFnCall(*this)); } + virtual Expr* clone(ObjectPool* pool) const override { + return pool->add(new RPCFnCall(_tnode)); + } virtual doris_udf::BooleanVal get_boolean_val(ExprContext* context, TupleRow*) override; virtual doris_udf::TinyIntVal get_tiny_int_val(ExprContext* context, TupleRow*) override; @@ -53,14 +54,7 @@ public: virtual doris_udf::CollectionVal get_array_val(ExprContext* context, TupleRow*) override; private: - Status call_rpc(ExprContext* context, TupleRow* row, PFunctionCallResponse* response); - template - RETURN_TYPE interpret_eval(ExprContext* context, TupleRow* row); - void cancel(const std::string& msg); - - std::shared_ptr _client = nullptr; - int _fn_context_index; - std::string _rpc_function_symbol; - RuntimeState* _state; + std::unique_ptr _rpc_fn; + const TExprNode& _tnode; }; } // namespace doris diff --git a/be/src/udf/udf.cpp b/be/src/udf/udf.cpp index f3c11ee6cc..b343155cff 100644 --- a/be/src/udf/udf.cpp +++ b/be/src/udf/udf.cpp @@ -201,20 +201,6 @@ FunctionContext* FunctionContextImpl::clone(MemPool* pool) { return new_context; } -// TODO: to be implemented -void FunctionContextImpl::serialize(PFunctionContext* pcontext) const { - // pcontext->set_string_result(_string_result); - // pcontext->set_num_updates(_num_updates); - // pcontext->set_num_removes(_num_removes); - // pcontext->set_num_warnings(_num_warnings); - // pcontext->set_error_msg(_error_msg); - // PUniqueId* query_id = pcontext->mutable_query_id(); - // query_id->set_hi(_context->query_id().hi); - // query_id->set_lo(_context->query_id().lo); -} - -void FunctionContextImpl::derialize(const PFunctionContext& pcontext) {} - } // namespace doris namespace doris_udf { diff --git a/be/src/udf/udf_internal.h b/be/src/udf/udf_internal.h index 903ac069dc..708e45dda6 100644 --- a/be/src/udf/udf_internal.h +++ b/be/src/udf/udf_internal.h @@ -36,7 +36,6 @@ class FreePool; class MemPool; class RuntimeState; struct ColumnPtrWrapper; -class PFunctionContext; // This class actually implements the interface of FunctionContext. This is split to // hide the details from the external header. @@ -111,9 +110,6 @@ public: const doris_udf::FunctionContext::TypeDesc& get_return_type() const { return _return_type; } - void serialize(PFunctionContext* pcontext) const; - void derialize(const PFunctionContext& pcontext); - private: friend class doris_udf::FunctionContext; friend class ExprContext; diff --git a/be/src/vec/core/block.cpp b/be/src/vec/core/block.cpp index e304a1973e..36be1d61c3 100644 --- a/be/src/vec/core/block.cpp +++ b/be/src/vec/core/block.cpp @@ -351,9 +351,6 @@ std::string Block::dump_names() const { } std::string Block::dump_data(size_t begin, size_t row_limit) const { - if (rows() == 0) { - return "empty block."; - } std::vector headers; std::vector headers_size; for (auto it = data.begin(); it != data.end(); ++it) { @@ -379,6 +376,9 @@ std::string Block::dump_data(size_t begin, size_t row_limit) const { out << std::setw(1) << "|" << std::endl; // header bottom line line(); + if (rows() == 0) { + return out.str(); + } // content for (size_t row_num = begin; row_num < rows() && row_num < row_limit + begin; ++row_num) { for (size_t i = 0; i < columns(); ++i) { @@ -875,9 +875,6 @@ Block MutableBlock::to_block(int start_column, int end_column) { } std::string MutableBlock::dump_data(size_t row_limit) const { - if (rows() == 0) { - return "empty block."; - } std::vector headers; std::vector headers_size; for (size_t i = 0; i < columns(); ++i) { @@ -903,6 +900,9 @@ std::string MutableBlock::dump_data(size_t row_limit) const { out << std::setw(1) << "|" << std::endl; // header bottom line line(); + if (rows() == 0) { + return out.str(); + } // content for (size_t row_num = 0; row_num < rows() && row_num < row_limit; ++row_num) { for (size_t i = 0; i < columns(); ++i) { diff --git a/be/src/vec/core/block.h b/be/src/vec/core/block.h index ee032900a3..94dd5ec8e9 100644 --- a/be/src/vec/core/block.h +++ b/be/src/vec/core/block.h @@ -259,7 +259,7 @@ public: std::unique_ptr create_same_struct_block(size_t size) const; - /** Compares (*this) n-th row and rhs m-th row. + /** Compares (*this) n-th row and rhs m-th row. * Returns negative number, 0, or positive number (*this) n-th row is less, equal, greater than rhs m-th row respectively. * Is used in sortings. * diff --git a/be/src/vec/exprs/vectorized_fn_call.cpp b/be/src/vec/exprs/vectorized_fn_call.cpp index 379b892f03..9afacc748b 100644 --- a/be/src/vec/exprs/vectorized_fn_call.cpp +++ b/be/src/vec/exprs/vectorized_fn_call.cpp @@ -20,6 +20,7 @@ #include #include "exprs/anyval_util.h" +#include "exprs/rpc_fn.h" #include "fmt/format.h" #include "fmt/ranges.h" #include "udf/udf_internal.h" @@ -45,8 +46,7 @@ doris::Status VectorizedFnCall::prepare(doris::RuntimeState* state, child_expr_name.emplace_back(child->expr_name()); } if (_fn.binary_type == TFunctionBinaryType::RPC) { - _function = RPCFnCall::create(_fn.name.function_name, _fn.hdfs_location, argument_template, - _data_type); + _function = FunctionRPC::create(_fn, argument_template, _data_type); } else if (_fn.binary_type == TFunctionBinaryType::JAVA_UDF) { #ifdef LIBJVM _function = JavaFunctionCall::create(_fn, argument_template, _data_type); diff --git a/be/src/vec/functions/function_rpc.cpp b/be/src/vec/functions/function_rpc.cpp index 9b2e11d08a..97d31710ca 100644 --- a/be/src/vec/functions/function_rpc.cpp +++ b/be/src/vec/functions/function_rpc.cpp @@ -21,542 +21,24 @@ #include -#include "gen_cpp/function_service.pb.h" -#include "runtime/exec_env.h" -#include "runtime/user_function_cache.h" -#include "service/brpc.h" -#include "util/brpc_client_cache.h" -#include "vec/columns/column_vector.h" -#include "vec/core/block.h" -#include "vec/data_types/data_type_bitmap.h" -#include "vec/data_types/data_type_date.h" -#include "vec/data_types/data_type_date_time.h" -#include "vec/data_types/data_type_decimal.h" -#include "vec/data_types/data_type_nullable.h" -#include "vec/data_types/data_type_number.h" -#include "vec/data_types/data_type_string.h" +#include "exprs/rpc_fn.h" namespace doris::vectorized { -RPCFnCall::RPCFnCall(const std::string& symbol, const std::string& server, - const DataTypes& argument_types, const DataTypePtr& return_type) - : _symbol(symbol), - _server(server), - _name(fmt::format("{}/{}", server, symbol)), - _argument_types(argument_types), - _return_type(return_type) {} -Status RPCFnCall::prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope) { - _client = ExecEnv::GetInstance()->brpc_function_client_cache()->get_client(_server); +FunctionRPC::FunctionRPC(const TFunction& fn, const DataTypes& argument_types, + const DataTypePtr& return_type) + : _argument_types(argument_types), _return_type(return_type), _tfn(fn) {} - if (_client == nullptr) { +Status FunctionRPC::prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope) { + _fn = std::make_unique(_tfn, false); + + if (!_fn->avliable()) { return Status::InternalError("rpc env init error"); } return Status::OK(); } -template -void convert_col_to_pvalue(const ColumnPtr& column, const DataTypePtr& data_type, PValues* arg, - size_t row_count) { - PGenericType* ptype = arg->mutable_type(); - switch (data_type->get_type_id()) { - case TypeIndex::UInt8: { - ptype->set_id(PGenericType::UINT8); - auto* values = arg->mutable_bool_value(); - values->Reserve(row_count); - const auto* col = check_and_get_column(column); - auto& data = col->get_data(); - values->Add(data.begin(), data.begin() + row_count); - break; - } - case TypeIndex::UInt16: { - ptype->set_id(PGenericType::UINT16); - auto* values = arg->mutable_uint32_value(); - values->Reserve(row_count); - const auto* col = check_and_get_column(column); - auto& data = col->get_data(); - values->Add(data.begin(), data.begin() + row_count); - break; - } - case TypeIndex::UInt32: { - ptype->set_id(PGenericType::UINT32); - auto* values = arg->mutable_uint32_value(); - values->Reserve(row_count); - const auto* col = check_and_get_column(column); - auto& data = col->get_data(); - values->Add(data.begin(), data.begin() + row_count); - break; - } - case TypeIndex::UInt64: { - ptype->set_id(PGenericType::UINT64); - auto* values = arg->mutable_uint64_value(); - values->Reserve(row_count); - const auto* col = check_and_get_column(column); - auto& data = col->get_data(); - values->Add(data.begin(), data.begin() + row_count); - break; - } - case TypeIndex::UInt128: { - ptype->set_id(PGenericType::UINT128); - arg->mutable_bytes_value()->Reserve(row_count); - for (size_t row_num = 0; row_num < row_count; ++row_num) { - if constexpr (nullable) { - if (column->is_null_at(row_num)) { - arg->add_bytes_value(nullptr); - } else { - StringRef data = column->get_data_at(row_num); - arg->add_bytes_value(data.data, data.size); - } - } else { - StringRef data = column->get_data_at(row_num); - arg->add_bytes_value(data.data, data.size); - } - } - break; - } - case TypeIndex::Int8: { - ptype->set_id(PGenericType::INT8); - auto* values = arg->mutable_int32_value(); - values->Reserve(row_count); - const auto* col = check_and_get_column(column); - auto& data = col->get_data(); - values->Add(data.begin(), data.begin() + row_count); - break; - } - case TypeIndex::Int16: { - ptype->set_id(PGenericType::INT16); - auto* values = arg->mutable_int32_value(); - values->Reserve(row_count); - const auto* col = check_and_get_column(column); - auto& data = col->get_data(); - values->Add(data.begin(), data.begin() + row_count); - break; - } - case TypeIndex::Int32: { - ptype->set_id(PGenericType::INT32); - auto* values = arg->mutable_int32_value(); - values->Reserve(row_count); - const auto* col = check_and_get_column(column); - auto& data = col->get_data(); - values->Add(data.begin(), data.begin() + row_count); - break; - } - case TypeIndex::Int64: { - ptype->set_id(PGenericType::INT64); - auto* values = arg->mutable_int64_value(); - values->Reserve(row_count); - const auto* col = check_and_get_column(column); - auto& data = col->get_data(); - values->Add(data.begin(), data.begin() + row_count); - break; - } - case TypeIndex::Int128: { - ptype->set_id(PGenericType::INT128); - arg->mutable_bytes_value()->Reserve(row_count); - for (size_t row_num = 0; row_num < row_count; ++row_num) { - if constexpr (nullable) { - if (column->is_null_at(row_num)) { - arg->add_bytes_value(nullptr); - } else { - StringRef data = column->get_data_at(row_num); - arg->add_bytes_value(data.data, data.size); - } - } else { - StringRef data = column->get_data_at(row_num); - arg->add_bytes_value(data.data, data.size); - } - } - break; - } - case TypeIndex::Float32: { - ptype->set_id(PGenericType::FLOAT); - auto* values = arg->mutable_float_value(); - values->Reserve(row_count); - const auto* col = check_and_get_column(column); - auto& data = col->get_data(); - values->Add(data.begin(), data.begin() + row_count); - break; - } - - case TypeIndex::Float64: { - ptype->set_id(PGenericType::DOUBLE); - auto* values = arg->mutable_double_value(); - values->Reserve(row_count); - const auto* col = check_and_get_column(column); - auto& data = col->get_data(); - values->Add(data.begin(), data.begin() + row_count); - break; - } - case TypeIndex::Decimal128: { - ptype->set_id(PGenericType::DECIMAL128); - auto dec_type = std::reinterpret_pointer_cast>(data_type); - ptype->mutable_decimal_type()->set_precision(dec_type->get_precision()); - ptype->mutable_decimal_type()->set_scale(dec_type->get_scale()); - arg->mutable_bytes_value()->Reserve(row_count); - for (size_t row_num = 0; row_num < row_count; ++row_num) { - if constexpr (nullable) { - if (column->is_null_at(row_num)) { - arg->add_bytes_value(nullptr); - } else { - StringRef data = column->get_data_at(row_num); - arg->add_bytes_value(data.data, data.size); - } - } else { - StringRef data = column->get_data_at(row_num); - arg->add_bytes_value(data.data, data.size); - } - } - break; - } - case TypeIndex::String: { - ptype->set_id(PGenericType::STRING); - arg->mutable_bytes_value()->Reserve(row_count); - for (size_t row_num = 0; row_num < row_count; ++row_num) { - if constexpr (nullable) { - if (column->is_null_at(row_num)) { - arg->add_string_value(nullptr); - } else { - StringRef data = column->get_data_at(row_num); - arg->add_string_value(data.to_string()); - } - } else { - StringRef data = column->get_data_at(row_num); - arg->add_string_value(data.to_string()); - } - } - break; - } - case TypeIndex::Date: { - ptype->set_id(PGenericType::DATE); - arg->mutable_datetime_value()->Reserve(row_count); - for (size_t row_num = 0; row_num < row_count; ++row_num) { - PDateTime* date_time = arg->add_datetime_value(); - if constexpr (nullable) { - if (!column->is_null_at(row_num)) { - VecDateTimeValue v = - binary_cast( - column->get_int(row_num)); - date_time->set_day(v.day()); - date_time->set_month(v.month()); - date_time->set_year(v.year()); - } - } else { - VecDateTimeValue v = binary_cast( - column->get_int(row_num)); - date_time->set_day(v.day()); - date_time->set_month(v.month()); - date_time->set_year(v.year()); - } - } - break; - } - case TypeIndex::DateTime: { - ptype->set_id(PGenericType::DATETIME); - arg->mutable_datetime_value()->Reserve(row_count); - for (size_t row_num = 0; row_num < row_count; ++row_num) { - PDateTime* date_time = arg->add_datetime_value(); - if constexpr (nullable) { - if (!column->is_null_at(row_num)) { - VecDateTimeValue v = - binary_cast( - column->get_int(row_num)); - date_time->set_day(v.day()); - date_time->set_month(v.month()); - date_time->set_year(v.year()); - date_time->set_hour(v.hour()); - date_time->set_minute(v.minute()); - date_time->set_second(v.second()); - } - } else { - VecDateTimeValue v = binary_cast( - column->get_int(row_num)); - date_time->set_day(v.day()); - date_time->set_month(v.month()); - date_time->set_year(v.year()); - date_time->set_hour(v.hour()); - date_time->set_minute(v.minute()); - date_time->set_second(v.second()); - } - } - break; - } - case TypeIndex::BitMap: { - ptype->set_id(PGenericType::BITMAP); - arg->mutable_bytes_value()->Reserve(row_count); - for (size_t row_num = 0; row_num < row_count; ++row_num) { - if constexpr (nullable) { - if (column->is_null_at(row_num)) { - arg->add_bytes_value(nullptr); - } else { - StringRef data = column->get_data_at(row_num); - arg->add_bytes_value(data.data, data.size); - } - } else { - StringRef data = column->get_data_at(row_num); - arg->add_bytes_value(data.data, data.size); - } - } - break; - } - case TypeIndex::HLL: { - ptype->set_id(PGenericType::HLL); - arg->mutable_bytes_value()->Reserve(row_count); - for (size_t row_num = 0; row_num < row_count; ++row_num) { - if constexpr (nullable) { - if (column->is_null_at(row_num)) { - arg->add_bytes_value(nullptr); - } else { - StringRef data = column->get_data_at(row_num); - arg->add_bytes_value(data.data, data.size); - } - } else { - StringRef data = column->get_data_at(row_num); - arg->add_bytes_value(data.data, data.size); - } - } - break; - } - default: - LOG(INFO) << "unknown type: " << data_type->get_name(); - ptype->set_id(PGenericType::UNKNOWN); - break; - } -} - -void convert_nullable_col_to_pvalue(const ColumnPtr& column, const DataTypePtr& data_type, - const ColumnUInt8& null_col, PValues* arg, size_t row_count) { - if (column->has_null(row_count)) { - auto* null_map = arg->mutable_null_map(); - null_map->Reserve(row_count); - const auto* col = check_and_get_column(null_col); - auto& data = col->get_data(); - null_map->Add(data.begin(), data.begin() + row_count); - convert_col_to_pvalue(column, data_type, arg, row_count); - } else { - convert_col_to_pvalue(column, data_type, arg, row_count); - } -} - -void convert_block_to_proto(Block& block, const ColumnNumbers& arguments, size_t input_rows_count, - PFunctionCallRequest* request) { - size_t row_count = std::min(block.rows(), input_rows_count); - for (size_t col_idx : arguments) { - PValues* arg = request->add_args(); - ColumnWithTypeAndName& column = block.get_by_position(col_idx); - arg->set_has_null(column.column->has_null(row_count)); - auto col = column.column->convert_to_full_column_if_const(); - if (auto* nullable = check_and_get_column(*col)) { - auto data_col = nullable->get_nested_column_ptr(); - auto& null_col = nullable->get_null_map_column(); - auto data_type = std::reinterpret_pointer_cast(column.type); - convert_nullable_col_to_pvalue(data_col->convert_to_full_column_if_const(), - data_type->get_nested_type(), null_col, arg, row_count); - } else { - convert_col_to_pvalue(col, column.type, arg, row_count); - } - } -} - -template -void convert_to_column(MutableColumnPtr& column, const PValues& result) { - switch (result.type().id()) { - case PGenericType::UINT8: { - column->reserve(result.uint32_value_size()); - column->resize(result.uint32_value_size()); - auto& data = reinterpret_cast(column.get())->get_data(); - for (int i = 0; i < result.uint32_value_size(); ++i) { - data[i] = result.uint32_value(i); - } - break; - } - case PGenericType::UINT16: { - column->reserve(result.uint32_value_size()); - column->resize(result.uint32_value_size()); - auto& data = reinterpret_cast(column.get())->get_data(); - for (int i = 0; i < result.uint32_value_size(); ++i) { - data[i] = result.uint32_value(i); - } - break; - } - case PGenericType::UINT32: { - column->reserve(result.uint32_value_size()); - column->resize(result.uint32_value_size()); - auto& data = reinterpret_cast(column.get())->get_data(); - for (int i = 0; i < result.uint32_value_size(); ++i) { - data[i] = result.uint32_value(i); - } - break; - } - case PGenericType::UINT64: { - column->reserve(result.uint64_value_size()); - column->resize(result.uint64_value_size()); - auto& data = reinterpret_cast(column.get())->get_data(); - for (int i = 0; i < result.uint64_value_size(); ++i) { - data[i] = result.uint64_value(i); - } - break; - } - case PGenericType::INT8: { - column->reserve(result.int32_value_size()); - column->resize(result.int32_value_size()); - auto& data = reinterpret_cast(column.get())->get_data(); - for (int i = 0; i < result.int32_value_size(); ++i) { - data[i] = result.int32_value(i); - } - break; - } - case PGenericType::INT16: { - column->reserve(result.int32_value_size()); - column->resize(result.int32_value_size()); - auto& data = reinterpret_cast(column.get())->get_data(); - for (int i = 0; i < result.int32_value_size(); ++i) { - data[i] = result.int32_value(i); - } - break; - } - case PGenericType::INT32: { - column->reserve(result.int32_value_size()); - column->resize(result.int32_value_size()); - auto& data = reinterpret_cast(column.get())->get_data(); - for (int i = 0; i < result.int32_value_size(); ++i) { - data[i] = result.int32_value(i); - } - break; - } - case PGenericType::INT64: { - column->reserve(result.int64_value_size()); - column->resize(result.int64_value_size()); - auto& data = reinterpret_cast(column.get())->get_data(); - for (int i = 0; i < result.int64_value_size(); ++i) { - data[i] = result.int64_value(i); - } - break; - } - case PGenericType::DATE: - case PGenericType::DATETIME: { - column->reserve(result.datetime_value_size()); - column->resize(result.datetime_value_size()); - auto& data = reinterpret_cast(column.get())->get_data(); - for (int i = 0; i < result.datetime_value_size(); ++i) { - VecDateTimeValue v; - PDateTime pv = result.datetime_value(i); - v.set_time(pv.year(), pv.month(), pv.day(), pv.hour(), pv.minute(), pv.minute()); - data[i] = binary_cast(v); - } - break; - } - case PGenericType::FLOAT: { - column->reserve(result.float_value_size()); - column->resize(result.float_value_size()); - auto& data = reinterpret_cast(column.get())->get_data(); - for (int i = 0; i < result.float_value_size(); ++i) { - data[i] = result.float_value(i); - } - break; - } - case PGenericType::DOUBLE: { - column->reserve(result.double_value_size()); - column->resize(result.double_value_size()); - auto& data = reinterpret_cast(column.get())->get_data(); - for (int i = 0; i < result.double_value_size(); ++i) { - data[i] = result.double_value(i); - } - break; - } - case PGenericType::INT128: { - column->reserve(result.bytes_value_size()); - column->resize(result.bytes_value_size()); - auto& data = reinterpret_cast(column.get())->get_data(); - for (int i = 0; i < result.bytes_value_size(); ++i) { - data[i] = *(int128_t*)(result.bytes_value(i).c_str()); - } - break; - } - case PGenericType::STRING: { - column->reserve(result.string_value_size()); - for (int i = 0; i < result.string_value_size(); ++i) { - column->insert_data(result.string_value(i).c_str(), result.string_value(i).size()); - } - break; - } - case PGenericType::DECIMAL128: { - column->reserve(result.bytes_value_size()); - column->resize(result.bytes_value_size()); - auto& data = reinterpret_cast(column.get())->get_data(); - for (int i = 0; i < result.bytes_value_size(); ++i) { - data[i] = *(int128_t*)(result.bytes_value(i).c_str()); - } - break; - } - case PGenericType::BITMAP: { - column->reserve(result.bytes_value_size()); - for (int i = 0; i < result.bytes_value_size(); ++i) { - column->insert_data(result.bytes_value(i).c_str(), result.bytes_value(i).size()); - } - break; - } - case PGenericType::HLL: { - column->reserve(result.bytes_value_size()); - for (int i = 0; i < result.bytes_value_size(); ++i) { - column->insert_data(result.bytes_value(i).c_str(), result.bytes_value(i).size()); - } - break; - } - default: { - LOG(WARNING) << "unknown PGenericType: " << result.type().DebugString(); - break; - } - } -} - -void convert_to_block(Block& block, const PValues& result, size_t pos) { - auto data_type = block.get_data_type(pos); - if (data_type->is_nullable()) { - auto null_type = std::reinterpret_pointer_cast(data_type); - auto data_col = null_type->get_nested_type()->create_column(); - convert_to_column(data_col, result); - auto null_col = ColumnUInt8::create(data_col->size(), 0); - auto& null_map_data = null_col->get_data(); - null_col->reserve(data_col->size()); - null_col->resize(data_col->size()); - if (result.has_null()) { - for (int i = 0; i < data_col->size(); ++i) { - null_map_data[i] = result.null_map(i); - } - } else { - for (int i = 0; i < data_col->size(); ++i) { - null_map_data[i] = false; - } - } - block.replace_by_position(pos, - ColumnNullable::create(std::move(data_col), std::move(null_col))); - } else { - auto column = data_type->create_column(); - convert_to_column(column, result); - block.replace_by_position(pos, std::move(column)); - } -} - -Status RPCFnCall::execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments, - size_t result, size_t input_rows_count, bool dry_run) { - PFunctionCallRequest request; - PFunctionCallResponse response; - request.set_function_name(_symbol); - convert_block_to_proto(block, arguments, input_rows_count, &request); - brpc::Controller cntl; - _client->fn_call(&cntl, &request, &response, nullptr); - if (cntl.Failed()) { - return Status::InternalError( - fmt::format("call to rpc function {} failed: {}", _symbol, cntl.ErrorText()) - .c_str()); - } - if (!response.has_status() || !response.has_result()) { - return Status::InternalError( - fmt::format("call rpc function {} failed: status or result is not set.", _symbol)); - } - if (response.status().status_code() != 0) { - return Status::InternalError(fmt::format("call to rpc function {} failed: {}", _symbol, - response.status().DebugString())); - } - convert_to_block(block, response.result(), result); - return Status::OK(); +Status FunctionRPC::execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + size_t result, size_t input_rows_count, bool dry_run) { + return _fn->vec_call(context, block, arguments, result, input_rows_count); } } // namespace doris::vectorized diff --git a/be/src/vec/functions/function_rpc.h b/be/src/vec/functions/function_rpc.h index 43bfe3acc2..a4037958dd 100644 --- a/be/src/vec/functions/function_rpc.h +++ b/be/src/vec/functions/function_rpc.h @@ -20,25 +20,28 @@ #include "vec/functions/function.h" namespace doris { -class PFunctionService_Stub; +class RPCFn; namespace vectorized { -class RPCFnCall : public IFunctionBase { +class FunctionRPC : public IFunctionBase { public: - RPCFnCall(const std::string& symbol, const std::string& server, const DataTypes& argument_types, - const DataTypePtr& return_type); - static FunctionBasePtr create(const std::string& symbol, const std::string& server, - const ColumnsWithTypeAndName& argument_types, + FunctionRPC(const TFunction& fn, const DataTypes& argument_types, + const DataTypePtr& return_type); + + static FunctionBasePtr create(const TFunction& fn, const ColumnsWithTypeAndName& argument_types, const DataTypePtr& return_type) { DataTypes data_types(argument_types.size()); for (size_t i = 0; i < argument_types.size(); ++i) { data_types[i] = argument_types[i].type; } - return std::make_shared(symbol, server, data_types, return_type); + return std::make_shared(fn, data_types, return_type); } /// Get the main function name. - String get_name() const override { return _name; }; + String get_name() const override { + return fmt::format("{}: [{}/{}]", _tfn.name.function_name, _tfn.hdfs_location, + _tfn.scalar_fn.symbol); + }; const DataTypes& get_argument_types() const override { return _argument_types; }; const DataTypePtr& get_return_type() const override { return _return_type; }; @@ -58,12 +61,10 @@ public: bool is_deterministic_in_scope_of_query() const override { return false; } private: - std::string _symbol; - std::string _server; - std::string _name; DataTypes _argument_types; DataTypePtr _return_type; - std::shared_ptr _client = nullptr; + TFunction _tfn; + std::unique_ptr _fn; }; } // namespace vectorized diff --git a/contrib/udf/CMakeLists.txt b/contrib/udf/CMakeLists.txt index 6b347f8c6c..66fd4f32cc 100644 --- a/contrib/udf/CMakeLists.txt +++ b/contrib/udf/CMakeLists.txt @@ -34,22 +34,6 @@ set(BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}") set(SRC_DIR "${BASE_DIR}/src/") set(OUTPUT_DIR "${BASE_DIR}/output") -# Check gcc -if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS "7.3.0") - message(FATAL_ERROR "Need GCC version at least 7.3.0") - endif() - - if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "7.3.0") - message(STATUS "GCC version is greater than 7.3.0, disable -Werror. Be careful with compile warnings.") - else() - # -Werror: compile warnings should be errors when using the toolchain compiler. - set(CXX_GCC_FLAGS "${CXX_GCC_FLAGS} -Werror") - endif() -elseif (NOT APPLE) - message(FATAL_ERROR "Compiler should be GNU") -endif() - # Just for clang-tidy: -Wno-expansion-to-defined -Wno-deprecated-declaration SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -g -ggdb -std=c++11 -Wall -Werror -Wno-unused-variable -Wno-expansion-to-defined -Wno-deprecated-declarations -O3") diff --git a/docs/en/extending-doris/udf/native-user-defined-function.md b/docs/en/extending-doris/udf/native-user-defined-function.md index 32311bf57a..c32f17549c 100644 --- a/docs/en/extending-doris/udf/native-user-defined-function.md +++ b/docs/en/extending-doris/udf/native-user-defined-function.md @@ -34,8 +34,6 @@ There are two types of analysis requirements that UDF can meet: UDF and UDAF. UD This document mainly describes how to write a custom UDF function and how to use it in Doris. -If users use the UDF function and extend Doris' function analysis, and want to contribute their own UDF functions back to the Doris community for other users, please see the document [Contribute UDF](./contribute_udf.md). - ## Writing UDF functions Before using UDF, users need to write their own UDF functions under Doris' UDF framework. In the `contrib/udf/src/udf_samples/udf_sample.h|cpp` file is a simple UDF Demo. diff --git a/docs/en/sql-reference/sql-statements/Data Definition/create-function.md b/docs/en/sql-reference/sql-statements/Data Definition/create-function.md index ca581ddf58..11f6bf9681 100644 --- a/docs/en/sql-reference/sql-statements/Data Definition/create-function.md +++ b/docs/en/sql-reference/sql-statements/Data Definition/create-function.md @@ -79,6 +79,8 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name > "prepare_fn": Function signature of the prepare function for finding the entry from the dynamic library. This option is optional for custom functions > > "close_fn": Function signature of the close function for finding the entry from the dynamic library. This option is optional for custom functions +> "type": Function type, RPC for remote udf, NATIVE for c++ native udf + This statement creates a custom function. Executing this command requires that the user have `ADMIN` privileges. @@ -138,6 +140,13 @@ If the `function_name` contains the database name, the custom function will be c CREATE ALIAS FUNCTION string(ALL, INT) WITH PARAMETER(col, length) AS CAST(col AS varchar(length)); ``` - +6. Create a remote UDF + ``` + CREATE FUNCTION rpc_add(INT, INT) RETURNS INT PROPERTIES ( + "SYMBOL"="add_int", + "OBJECT_FILE"="127.0.0.1:9999", + "TYPE"="RPC" + ); + ``` ## keyword CREATE,FUNCTION diff --git a/docs/zh-CN/extending-doris/udf/native-user-defined-function.md b/docs/zh-CN/extending-doris/udf/native-user-defined-function.md index 8f7e56c869..fff1ddbd5d 100644 --- a/docs/zh-CN/extending-doris/udf/native-user-defined-function.md +++ b/docs/zh-CN/extending-doris/udf/native-user-defined-function.md @@ -35,8 +35,6 @@ UDF 能满足的分析需求分为两种:UDF 和 UDAF。本文中的 UDF 指 这篇文档主要讲述了,如何编写自定义的 UDF 函数,以及如何在 Doris 中使用它。 -如果用户使用 UDF 功能并扩展了 Doris 的函数分析,并且希望将自己实现的 UDF 函数贡献回 Doris 社区给其他用户使用,这时候请看文档 [Contribute UDF](./contribute_udf.md)。 - ## 编写 UDF 函数 在使用UDF之前,用户需要先在 Doris 的 UDF 框架下,编写自己的UDF函数。在`contrib/udf/src/udf_samples/udf_sample.h|cpp`文件中是一个简单的 UDF Demo。 diff --git a/docs/zh-CN/sql-reference/sql-statements/Data Definition/create-function.md b/docs/zh-CN/sql-reference/sql-statements/Data Definition/create-function.md index f6b2c7b990..902462664a 100644 --- a/docs/zh-CN/sql-reference/sql-statements/Data Definition/create-function.md +++ b/docs/zh-CN/sql-reference/sql-statements/Data Definition/create-function.md @@ -79,6 +79,7 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name > "prepare_fn": 自定义函数的prepare函数的函数签名,用于从动态库里面找到prepare函数入口。此选项对于自定义函数是可选项 > > "close_fn": 自定义函数的close函数的函数签名,用于从动态库里面找到close函数入口。此选项对于自定义函数是可选项 +> "type": 自定义函数的类型,如果是远程函数就是则填 RPC,C++的原生 UDF 填 NATIVE, 默认 NATIVE 此语句创建一个自定义函数。执行此命令需要用户拥有 `ADMIN` 权限。 @@ -89,35 +90,35 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name 1. 创建一个自定义标量函数 - ``` - CREATE FUNCTION my_add(INT, INT) RETURNS INT PROPERTIES ( - "symbol" = "_ZN9doris_udf6AddUdfEPNS_15FunctionContextERKNS_6IntValES4_", - "object_file" = "http://host:port/libmyadd.so" - ); - ``` + ``` + CREATE FUNCTION my_add(INT, INT) RETURNS INT PROPERTIES ( + "symbol" = "_ZN9doris_udf6AddUdfEPNS_15FunctionContextERKNS_6IntValES4_", + "object_file" = "http://host:port/libmyadd.so" + ); + ``` 2. 创建一个有prepare/close函数的自定义标量函数 - ``` - CREATE FUNCTION my_add(INT, INT) RETURNS INT PROPERTIES ( - "symbol" = "_ZN9doris_udf6AddUdfEPNS_15FunctionContextERKNS_6IntValES4_", - "prepare_fn" = "_ZN9doris_udf14AddUdf_prepareEPNS_15FunctionContextENS0_18FunctionStateScopeE", - "close_fn" = "_ZN9doris_udf12AddUdf_closeEPNS_15FunctionContextENS0_18FunctionStateScopeE", - "object_file" = "http://host:port/libmyadd.so" - ); - ``` + ``` + CREATE FUNCTION my_add(INT, INT) RETURNS INT PROPERTIES ( + "symbol" = "_ZN9doris_udf6AddUdfEPNS_15FunctionContextERKNS_6IntValES4_", + "prepare_fn" = "_ZN9doris_udf14AddUdf_prepareEPNS_15FunctionContextENS0_18FunctionStateScopeE", + "close_fn" = "_ZN9doris_udf12AddUdf_closeEPNS_15FunctionContextENS0_18FunctionStateScopeE", + "object_file" = "http://host:port/libmyadd.so" + ); + ``` 3. 创建一个自定义聚合函数 - ``` - CREATE AGGREGATE FUNCTION my_count (BIGINT) RETURNS BIGINT PROPERTIES ( - "init_fn"="_ZN9doris_udf9CountInitEPNS_15FunctionContextEPNS_9BigIntValE", - "update_fn"="_ZN9doris_udf11CountUpdateEPNS_15FunctionContextERKNS_6IntValEPNS_9BigIntValE", - "merge_fn"="_ZN9doris_udf10CountMergeEPNS_15FunctionContextERKNS_9BigIntValEPS2_", - "finalize_fn"="_ZN9doris_udf13CountFinalizeEPNS_15FunctionContextERKNS_9BigIntValE", - "object_file"="http://host:port/libudasample.so" - ); - ``` + ``` + CREATE AGGREGATE FUNCTION my_count (BIGINT) RETURNS BIGINT PROPERTIES ( + "init_fn"="_ZN9doris_udf9CountInitEPNS_15FunctionContextEPNS_9BigIntValE", + "update_fn"="_ZN9doris_udf11CountUpdateEPNS_15FunctionContextERKNS_6IntValEPNS_9BigIntValE", + "merge_fn"="_ZN9doris_udf10CountMergeEPNS_15FunctionContextERKNS_9BigIntValEPS2_", + "finalize_fn"="_ZN9doris_udf13CountFinalizeEPNS_15FunctionContextERKNS_9BigIntValE", + "object_file"="http://host:port/libudasample.so" + ); + ``` 4. 创建一个变长参数的标量函数 @@ -139,7 +140,14 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name CREATE ALIAS FUNCTION string(ALL, INT) WITH PARAMETER(col, length) AS CAST(col AS varchar(length)); ``` - +6. 创建一个远程自动函数 + ``` + CREATE FUNCTION rpc_add(INT, INT) RETURNS INT PROPERTIES ( + "SYMBOL"="add_int", + "OBJECT_FILE"="127.0.0.1:9999", + "TYPE"="RPC" + ); + ``` ## keyword CREATE,FUNCTION diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java index d2516954f7..2446fb8249 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java @@ -236,9 +236,6 @@ public class CreateFunctionStmt extends DdlStmt { } private void analyzeUda() throws AnalysisException { - if (binaryType == TFunctionBinaryType.RPC) { - throw new AnalysisException("RPC UDAF is not supported."); - } AggregateFunction.AggregateFunctionBuilder builder = AggregateFunction.AggregateFunctionBuilder.createUdfBuilder(); builder.name(functionName).argsType(argsDef.getArgTypes()).retType(returnType.getType()). @@ -255,10 +252,31 @@ public class CreateFunctionStmt extends DdlStmt { if (mergeFnSymbol == null) { throw new AnalysisException("No 'merge_fn' in properties"); } + String serializeFnSymbol = properties.get(SERIALIZE_KEY); + String finalizeFnSymbol = properties.get(FINALIZE_KEY); + String getValueFnSymbol = properties.get(GET_VALUE_KEY); + String removeFnSymbol = properties.get(REMOVE_KEY); + if (binaryType == TFunctionBinaryType.RPC && !userFile.contains("://")) { + checkRPCUdf(initFnSymbol); + checkRPCUdf(updateFnSymbol); + checkRPCUdf(mergeFnSymbol); + if (serializeFnSymbol != null) { + checkRPCUdf(serializeFnSymbol); + } + if (finalizeFnSymbol != null) { + checkRPCUdf(finalizeFnSymbol); + } + if (getValueFnSymbol != null) { + checkRPCUdf(getValueFnSymbol); + } + if (removeFnSymbol != null) { + checkRPCUdf(removeFnSymbol); + } + } function = builder.initFnSymbol(initFnSymbol) .updateFnSymbol(updateFnSymbol).mergeFnSymbol(mergeFnSymbol) - .serializeFnSymbol(properties.get(SERIALIZE_KEY)).finalizeFnSymbol(properties.get(FINALIZE_KEY)) - .getValueFnSymbol(properties.get(GET_VALUE_KEY)).removeFnSymbol(properties.get(REMOVE_KEY)) + .serializeFnSymbol(serializeFnSymbol).finalizeFnSymbol(finalizeFnSymbol) + .getValueFnSymbol(getValueFnSymbol).removeFnSymbol(removeFnSymbol) .build(); function.setChecksum(checksum); } @@ -274,33 +292,9 @@ public class CreateFunctionStmt extends DdlStmt { // the format for load balance can ref https://github.com/apache/incubator-brpc/blob/master/docs/en/client.md#connect-to-a-cluster if (binaryType == TFunctionBinaryType.RPC && !userFile.contains("://")) { if (StringUtils.isNotBlank(prepareFnSymbol) || StringUtils.isNotBlank(closeFnSymbol)) { - throw new AnalysisException(" prepare and close in RPC UDF are not supported."); - } - String[] url = userFile.split(":"); - if (url.length != 2) { - throw new AnalysisException("function server address invalid."); - } - String host = url[0]; - int port = Integer.valueOf(url[1]); - ManagedChannel channel = NettyChannelBuilder.forAddress(host, port) - .flowControlWindow(Config.grpc_max_message_size_bytes) - .maxInboundMessageSize(Config.grpc_max_message_size_bytes) - .enableRetry().maxRetryAttempts(3) - .usePlaintext().build(); - PFunctionServiceGrpc.PFunctionServiceBlockingStub stub = PFunctionServiceGrpc.newBlockingStub(channel); - FunctionService.PCheckFunctionRequest.Builder builder = FunctionService.PCheckFunctionRequest.newBuilder(); - builder.getFunctionBuilder().setFunctionName(symbol); - for (Type arg : argsDef.getArgTypes()) { - builder.getFunctionBuilder().addInputs(convertToPParameterType(arg)); - } - builder.getFunctionBuilder().setOutput(convertToPParameterType(returnType.getType())); - FunctionService.PCheckFunctionResponse response = stub.checkFn(builder.build()); - if (response == null || !response.hasStatus()) { - throw new AnalysisException("cannot access function server"); - } - if (response.getStatus().getStatusCode() != 0) { - throw new AnalysisException("check function [" + symbol + "] failed: " + response.getStatus()); + throw new AnalysisException("prepare and close in RPC UDF are not supported."); } + checkRPCUdf(symbol); } else if (binaryType == TFunctionBinaryType.JAVA_UDF) { analyzeJavaUdf(symbol); } @@ -399,6 +393,36 @@ public class CreateFunctionStmt extends DdlStmt { } } + private void checkRPCUdf(String symbol) throws AnalysisException { + // TODO(yangzhg) support check function in FE when function service behind load balancer + // the format for load balance can ref https://github.com/apache/incubator-brpc/blob/master/docs/en/client.md#connect-to-a-cluster + String[] url = userFile.split(":"); + if (url.length != 2) { + throw new AnalysisException("function server address invalid."); + } + String host = url[0]; + int port = Integer.valueOf(url[1]); + ManagedChannel channel = NettyChannelBuilder.forAddress(host, port) + .flowControlWindow(Config.grpc_max_message_size_bytes) + .maxInboundMessageSize(Config.grpc_max_message_size_bytes) + .enableRetry().maxRetryAttempts(3) + .usePlaintext().build(); + PFunctionServiceGrpc.PFunctionServiceBlockingStub stub = PFunctionServiceGrpc.newBlockingStub(channel); + FunctionService.PCheckFunctionRequest.Builder builder = FunctionService.PCheckFunctionRequest.newBuilder(); + builder.getFunctionBuilder().setFunctionName(symbol); + for (Type arg : argsDef.getArgTypes()) { + builder.getFunctionBuilder().addInputs(convertToPParameterType(arg)); + } + builder.getFunctionBuilder().setOutput(convertToPParameterType(returnType.getType())); + FunctionService.PCheckFunctionResponse response = stub.checkFn(builder.build()); + if (response == null || !response.hasStatus()) { + throw new AnalysisException("cannot access function server"); + } + if (response.getStatus().getStatusCode() != 0) { + throw new AnalysisException("check function [" + symbol + "] failed: " + response.getStatus()); + } + } + private Types.PGenericType convertToPParameterType(Type arg) throws AnalysisException { Types.PGenericType.Builder typeBuilder = Types.PGenericType.newBuilder(); switch (arg.getPrimitiveType()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java index 8f35805523..1820fba43b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java @@ -392,6 +392,10 @@ public class AggregateFunction extends Function { this.removeFnSymbol = symbol; return this; } + public AggregateFunctionBuilder binaryType(TFunctionBinaryType binaryType) { + this.binaryType = binaryType; + return this; + } public AggregateFunction build() { AggregateFunction fn = new AggregateFunction(name, argTypes, retType, hasVarArgs, intermediateType, diff --git a/fe/fe-core/src/main/java/org/apache/doris/common/util/URI.java b/fe/fe-core/src/main/java/org/apache/doris/common/util/URI.java index b0c0f19f4d..4f4409e2a8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/common/util/URI.java +++ b/fe/fe-core/src/main/java/org/apache/doris/common/util/URI.java @@ -201,4 +201,9 @@ public class URI { throw new AnalysisException("Invalid host port: " + hostPort); } } + + @Override + public String toString() { + return location; + } } diff --git a/gensrc/proto/function_service.proto b/gensrc/proto/function_service.proto index 561be9f887..298957b556 100644 --- a/gensrc/proto/function_service.proto +++ b/gensrc/proto/function_service.proto @@ -35,8 +35,9 @@ message PFunctionCallRequest { } message PFunctionCallResponse { - optional PValues result = 1; + repeated PValues result = 1; optional PStatus status = 2; + optional PRequestContext context = 3; } message PCheckFunctionRequest { diff --git a/gensrc/proto/types.proto b/gensrc/proto/types.proto index e1f8445620..a3ac3dc190 100644 --- a/gensrc/proto/types.proto +++ b/gensrc/proto/types.proto @@ -193,17 +193,7 @@ message PFunction { } message PFunctionContext { - optional string version = 1 [default = "V2_0"]; - repeated PValue staging_input_vals = 2; - repeated PValue constant_args = 3; - optional string error_msg = 4; - optional PUniqueId query_id = 5; - optional bytes thread_local_fn_state = 6; - optional bytes fragment_local_fn_state = 7; - optional string string_result = 8; - optional int64 num_updates = 9; - optional int64 num_removes = 10; - optional int64 num_warnings = 11; + optional bytes data = 1; } message PHandShakeRequest { diff --git a/samples/doris-demo/remote-udf-cpp-demo/cpp_function_service_demo.cpp b/samples/doris-demo/remote-udf-cpp-demo/cpp_function_service_demo.cpp index 3e0193f1ed..7e141394d0 100644 --- a/samples/doris-demo/remote-udf-cpp-demo/cpp_function_service_demo.cpp +++ b/samples/doris-demo/remote-udf-cpp-demo/cpp_function_service_demo.cpp @@ -32,20 +32,21 @@ public: ::google::protobuf::Closure* done) override { brpc::ClosureGuard closure_guard(done); std::string fun_name = request->function_name(); + auto* result = response->add_result(); if (fun_name == "int32_add") { - response->mutable_result()->mutable_type()->set_id(PGenericType::INT32); + result->mutable_type()->set_id(PGenericType::INT32); for (size_t i = 0; i < request->args(0).int32_value_size(); ++i) { - response->mutable_result()->add_int32_value(request->args(0).int32_value(i) + - request->args(1).int32_value(i)); + result->add_int32_value(request->args(0).int32_value(i) + + request->args(1).int32_value(i)); } } else if (fun_name == "int64_add") { - response->mutable_result()->mutable_type()->set_id(PGenericType::INT64); + result->mutable_type()->set_id(PGenericType::INT64); for (size_t i = 0; i < request->args(0).int64_value_size(); ++i) { - response->mutable_result()->add_int64_value(request->args(0).int64_value(i) + - request->args(1).int64_value(i)); + result->add_int64_value(request->args(0).int64_value(i) + + request->args(1).int64_value(i)); } } else if (fun_name == "int128_add") { - response->mutable_result()->mutable_type()->set_id(PGenericType::INT128); + result->mutable_type()->set_id(PGenericType::INT128); for (size_t i = 0; i < request->args(0).bytes_value_size(); ++i) { __int128 v1; memcpy(&v1, request->args(0).bytes_value(i).data(), sizeof(__int128)); @@ -54,26 +55,25 @@ public: __int128 v = v1 + v2; char buffer[sizeof(__int128)]; memcpy(buffer, &v, sizeof(__int128)); - response->mutable_result()->add_bytes_value(buffer, sizeof(__int128)); + result->add_bytes_value(buffer, sizeof(__int128)); } } else if (fun_name == "float_add") { - response->mutable_result()->mutable_type()->set_id(PGenericType::FLOAT); + result->mutable_type()->set_id(PGenericType::FLOAT); for (size_t i = 0; i < request->args(0).float_value_size(); ++i) { - response->mutable_result()->add_float_value(request->args(0).float_value(i) + - request->args(1).float_value(i)); + result->add_float_value(request->args(0).float_value(i) + + request->args(1).float_value(i)); } } else if (fun_name == "double_add") { - response->mutable_result()->mutable_type()->set_id(PGenericType::DOUBLE); + result->mutable_type()->set_id(PGenericType::DOUBLE); for (size_t i = 0; i < request->args(0).double_value_size(); ++i) { - response->mutable_result()->add_double_value(request->args(0).double_value(i) + - request->args(1).double_value(i)); + result->add_double_value(request->args(0).double_value(i) + + request->args(1).double_value(i)); } } else if (fun_name == "str_add") { - response->mutable_result()->mutable_type()->set_id(PGenericType::STRING); + result->mutable_type()->set_id(PGenericType::STRING); for (size_t i = 0; i < request->args(0).string_value_size(); ++i) { - response->mutable_result()->add_string_value(request->args(0).string_value(i) + - " + " + - request->args(1).string_value(i)); + result->add_string_value(request->args(0).string_value(i) + " + " + + request->args(1).string_value(i)); } } response->mutable_status()->set_status_code(0); diff --git a/samples/doris-demo/remote-udf-java-demo/src/main/java/org/apache/doris/udf/FunctionServiceImpl.java b/samples/doris-demo/remote-udf-java-demo/src/main/java/org/apache/doris/udf/FunctionServiceImpl.java index 40558193ad..fded80505f 100644 --- a/samples/doris-demo/remote-udf-java-demo/src/main/java/org/apache/doris/udf/FunctionServiceImpl.java +++ b/samples/doris-demo/remote-udf-java-demo/src/main/java/org/apache/doris/udf/FunctionServiceImpl.java @@ -45,7 +45,7 @@ public class FunctionServiceImpl extends PFunctionServiceGrpc.PFunctionServiceIm if ("add_int".equals(functionName)) { res = FunctionService.PFunctionCallResponse.newBuilder() .setStatus(Types.PStatus.newBuilder().setStatusCode(0).build()) - .setResult(Types.PValues.newBuilder().setHasNull(false) + .addResult(Types.PValues.newBuilder().setHasNull(false) .addAllInt32Value(IntStream.range(0, Math.min(request.getArgs(0) .getInt32ValueCount(), request.getArgs(1).getInt32ValueCount())) .mapToObj(i -> request.getArgs(0).getInt32Value(i) + request.getArgs(1) diff --git a/samples/doris-demo/remote-udf-python-demo/function_server_demo.py b/samples/doris-demo/remote-udf-python-demo/function_server_demo.py index 60d6d939c9..d1f2160013 100644 --- a/samples/doris-demo/remote-udf-python-demo/function_server_demo.py +++ b/samples/doris-demo/remote-udf-python-demo/function_server_demo.py @@ -43,7 +43,7 @@ class FunctionServerDemo(function_service_pb2_grpc.PFunctionServiceServicer): result_type.id = types_pb2.PGenericType.INT32 result.type.CopyFrom(result_type) result.int32_value.extend([x + y for x, y in zip(request.args[0].int32_value, request.args[1].int32_value)]) - response.result.CopyFrom(result) + response.result.append(result) return response def check_fn(self, request, context):