[refactor](udf) refactor java-udf execute method by using for loop (#21388)

This commit is contained in:
zhangstar333
2023-07-07 11:43:11 +08:00
committed by GitHub
parent 8272232e21
commit bb985cd9a1
4 changed files with 2324 additions and 154 deletions

View File

@ -17,6 +17,9 @@
#include "vec/functions/function_java_udf.h"
#include <glog/logging.h>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
@ -36,6 +39,7 @@
#include "vec/common/assert_cast.h"
#include "vec/common/string_ref.h"
#include "vec/core/block.h"
#include "vec/data_types/data_type_array.h"
#include "vec/data_types/data_type_nullable.h"
const char* EXECUTOR_CLASS = "org/apache/doris/udf/UdfExecutor";
@ -60,9 +64,18 @@ Status JavaFunctionCall::open(FunctionContext* context, FunctionContext::Functio
jni_env->executor_ctor_id =
env->GetMethodID(jni_env->executor_cl, "<init>", EXECUTOR_CTOR_SIGNATURE);
RETURN_ERROR_IF_EXC(env);
jni_env->executor_evaluate_id =
env->GetMethodID(jni_env->executor_cl, "evaluate", EXECUTOR_EVALUATE_SIGNATURE);
RETURN_ERROR_IF_EXC(env);
jni_env->executor_evaluate_id = env->GetMethodID(
jni_env->executor_cl, "evaluate", "(I[Ljava/lang/Object;)[Ljava/lang/Object;");
jni_env->executor_convert_basic_argument_id = env->GetMethodID(
jni_env->executor_cl, "convertBasicArguments", "(IZIJJJ)[Ljava/lang/Object;");
jni_env->executor_convert_array_argument_id = env->GetMethodID(
jni_env->executor_cl, "convertArrayArguments", "(IZIJJJJJ)[Ljava/lang/Object;");
jni_env->executor_result_basic_batch_id = env->GetMethodID(
jni_env->executor_cl, "copyBatchBasicResult", "(ZI[Ljava/lang/Object;JJJ)V");
jni_env->executor_result_array_batch_id = env->GetMethodID(
jni_env->executor_cl, "copyBatchArrayResult", "(ZI[Ljava/lang/Object;JJJJJ)V");
jni_env->executor_close_id =
env->GetMethodID(jni_env->executor_cl, "close", EXECUTOR_CLOSE_SIGNATURE);
RETURN_ERROR_IF_EXC(env);
@ -130,189 +143,176 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
context->get_function_state(FunctionContext::THREAD_LOCAL));
JniEnv* jni_env =
reinterpret_cast<JniEnv*>(context->get_function_state(FunctionContext::FRAGMENT_LOCAL));
int arg_idx = 0;
ColumnPtr data_cols[arguments.size()];
ColumnPtr null_cols[arguments.size()];
for (size_t col_idx : arguments) {
ColumnWithTypeAndName& column = block.get_by_position(col_idx);
int arg_size = arguments.size();
ColumnPtr data_cols[arg_size];
ColumnPtr null_cols[arg_size];
jclass obj_class = env->FindClass("[Ljava/lang/Object;");
jclass arraylist_class = env->FindClass("Ljava/util/ArrayList;");
jobjectArray arg_objects = env->NewObjectArray(arg_size, obj_class, nullptr);
int64_t nullmap_address = 0;
for (size_t arg_idx = 0; arg_idx < arg_size; ++arg_idx) {
bool arg_column_nullable = false;
// get argument column and type
ColumnWithTypeAndName& column = block.get_by_position(arguments[arg_idx]);
auto column_type = column.type;
data_cols[arg_idx] = column.column->convert_to_full_column_if_const();
if (!_argument_types[arg_idx]->equals(*column.type)) {
return Status::InvalidArgument(strings::Substitute(
"$0-th input column's type $1 does not equal to required type $2", arg_idx,
column.type->get_name(), _argument_types[arg_idx]->get_name()));
}
// check type
DCHECK(_argument_types[arg_idx]->equals(*column_type))
<< " input column's type is " + column_type->get_name()
<< " does not equal to required type " << _argument_types[arg_idx]->get_name();
// get argument null map and nested column
if (auto* nullable = check_and_get_column<const ColumnNullable>(*data_cols[arg_idx])) {
arg_column_nullable = true;
column_type = remove_nullable(column_type);
null_cols[arg_idx] = nullable->get_null_map_column_ptr();
jni_ctx->input_nulls_buffer_ptr.get()[arg_idx] = reinterpret_cast<int64_t>(
data_cols[arg_idx] = nullable->get_nested_column_ptr();
nullmap_address = reinterpret_cast<int64_t>(
check_and_get_column<ColumnVector<UInt8>>(null_cols[arg_idx])
->get_data()
.data());
data_cols[arg_idx] = nullable->get_nested_column_ptr();
} else {
jni_ctx->input_nulls_buffer_ptr.get()[arg_idx] = -1;
}
if (data_cols[arg_idx]->is_column_string()) {
// convert argument column data into java type
jobjectArray arr_obj = nullptr;
if (data_cols[arg_idx]->is_numeric() || data_cols[arg_idx]->is_column_decimal()) {
arr_obj = (jobjectArray)env->CallNonvirtualObjectMethod(
jni_ctx->executor, jni_env->executor_cl,
jni_env->executor_convert_basic_argument_id, arg_idx, arg_column_nullable,
num_rows, nullmap_address,
reinterpret_cast<int64_t>(data_cols[arg_idx]->get_raw_data().data), 0);
} else if (data_cols[arg_idx]->is_column_string()) {
const ColumnString* str_col =
assert_cast<const ColumnString*>(data_cols[arg_idx].get());
jni_ctx->input_values_buffer_ptr.get()[arg_idx] =
reinterpret_cast<int64_t>(str_col->get_chars().data());
jni_ctx->input_offsets_ptrs.get()[arg_idx] =
reinterpret_cast<int64_t>(str_col->get_offsets().data());
} else if (data_cols[arg_idx]->is_numeric() || data_cols[arg_idx]->is_column_decimal()) {
jni_ctx->input_values_buffer_ptr.get()[arg_idx] =
reinterpret_cast<int64_t>(data_cols[arg_idx]->get_raw_data().data);
arr_obj = (jobjectArray)env->CallNonvirtualObjectMethod(
jni_ctx->executor, jni_env->executor_cl,
jni_env->executor_convert_basic_argument_id, arg_idx, arg_column_nullable,
num_rows, nullmap_address,
reinterpret_cast<int64_t>(str_col->get_chars().data()),
reinterpret_cast<int64_t>(str_col->get_offsets().data()));
} else if (data_cols[arg_idx]->is_column_array()) {
const ColumnArray* array_col =
assert_cast<const ColumnArray*>(data_cols[arg_idx].get());
jni_ctx->input_offsets_ptrs.get()[arg_idx] =
reinterpret_cast<int64_t>(array_col->get_offsets_column().get_raw_data().data);
const ColumnNullable& array_nested_nullable =
assert_cast<const ColumnNullable&>(array_col->get_data());
auto data_column_null_map = array_nested_nullable.get_null_map_column_ptr();
auto data_column = array_nested_nullable.get_nested_column_ptr();
jni_ctx->input_array_nulls_buffer_ptr.get()[arg_idx] = reinterpret_cast<int64_t>(
auto offset_address =
reinterpret_cast<int64_t>(array_col->get_offsets_column().get_raw_data().data);
auto nested_nullmap_address = reinterpret_cast<int64_t>(
check_and_get_column<ColumnVector<UInt8>>(data_column_null_map)
->get_data()
.data());
// need to pass to FE: nullmap, offset and chars
int64_t nested_data_address = 0, nested_offset_address = 0;
// an array type needs to pass these addresses: [nullmap_address], offset_address, nested_nullmap_address, nested_data_address/nested_char_address, nested_offset_address
if (data_column->is_column_string()) {
const ColumnString* col = assert_cast<const ColumnString*>(data_column.get());
jni_ctx->input_values_buffer_ptr.get()[arg_idx] =
reinterpret_cast<int64_t>(col->get_chars().data());
jni_ctx->input_array_string_offsets_ptrs.get()[arg_idx] =
reinterpret_cast<int64_t>(col->get_offsets().data());
nested_data_address = reinterpret_cast<int64_t>(col->get_chars().data());
nested_offset_address = reinterpret_cast<int64_t>(col->get_offsets().data());
} else {
jni_ctx->input_values_buffer_ptr.get()[arg_idx] =
reinterpret_cast<int64_t>(data_column->get_raw_data().data);
nested_data_address = reinterpret_cast<int64_t>(data_column->get_raw_data().data);
}
arr_obj = (jobjectArray)env->CallNonvirtualObjectMethod(
jni_ctx->executor, jni_env->executor_cl,
jni_env->executor_convert_array_argument_id, arg_idx, arg_column_nullable,
num_rows, nullmap_address, offset_address, nested_nullmap_address,
nested_data_address, nested_offset_address);
} else {
return Status::InvalidArgument(
strings::Substitute("Java UDF doesn't support type $0 now !",
_argument_types[arg_idx]->get_name()));
}
arg_idx++;
}
*(jni_ctx->batch_size_ptr) = num_rows;
auto return_type = block.get_data_type(result);
if (return_type->is_nullable()) {
auto null_type = std::reinterpret_pointer_cast<const DataTypeNullable>(return_type);
auto data_col = null_type->get_nested_type()->create_column();
auto null_col = ColumnUInt8::create(data_col->size(), 0);
null_col->resize(num_rows);
*(jni_ctx->output_null_value) = reinterpret_cast<int64_t>(null_col->get_data().data());
#ifndef EVALUATE_JAVA_UDF
#define EVALUATE_JAVA_UDF \
if (data_col->is_column_string()) { \
const ColumnString* str_col = assert_cast<const ColumnString*>(data_col.get()); \
ColumnString::Chars& chars = const_cast<ColumnString::Chars&>(str_col->get_chars()); \
ColumnString::Offsets& offsets = \
const_cast<ColumnString::Offsets&>(str_col->get_offsets()); \
int increase_buffer_size = 0; \
int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
chars.resize(buffer_size); \
offsets.resize(num_rows); \
*(jni_ctx->output_value_buffer) = reinterpret_cast<int64_t>(chars.data()); \
*(jni_ctx->output_offsets_ptr) = reinterpret_cast<int64_t>(offsets.data()); \
jni_ctx->output_intermediate_state_ptr->row_idx = 0; \
jni_ctx->output_intermediate_state_ptr->buffer_size = buffer_size; \
env->CallNonvirtualVoidMethodA(jni_ctx->executor, jni_env->executor_cl, \
jni_env->executor_evaluate_id, nullptr); \
while (jni_ctx->output_intermediate_state_ptr->row_idx < num_rows) { \
increase_buffer_size++; \
buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
chars.resize(buffer_size); \
*(jni_ctx->output_value_buffer) = reinterpret_cast<int64_t>(chars.data()); \
jni_ctx->output_intermediate_state_ptr->buffer_size = buffer_size; \
env->CallNonvirtualVoidMethodA(jni_ctx->executor, jni_env->executor_cl, \
jni_env->executor_evaluate_id, nullptr); \
} \
} else if (data_col->is_numeric() || data_col->is_column_decimal()) { \
data_col->resize(num_rows); \
*(jni_ctx->output_value_buffer) = \
reinterpret_cast<int64_t>(data_col->get_raw_data().data); \
env->CallNonvirtualVoidMethodA(jni_ctx->executor, jni_env->executor_cl, \
jni_env->executor_evaluate_id, nullptr); \
} else if (data_col->is_column_array()) { \
ColumnArray* array_col = assert_cast<ColumnArray*>(data_col.get()); \
ColumnNullable& array_nested_nullable = \
assert_cast<ColumnNullable&>(array_col->get_data()); \
auto data_column_null_map = array_nested_nullable.get_null_map_column_ptr(); \
auto data_column = array_nested_nullable.get_nested_column_ptr(); \
auto& offset_column = array_col->get_offsets_column(); \
int increase_buffer_size = 0; \
int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
offset_column.resize(num_rows); \
*(jni_ctx->output_offsets_ptr) = \
reinterpret_cast<int64_t>(offset_column.get_raw_data().data); \
data_column_null_map->resize(buffer_size); \
auto& null_map_data = \
assert_cast<ColumnVector<UInt8>*>(data_column_null_map.get())->get_data(); \
*(jni_ctx->output_array_null_ptr) = reinterpret_cast<int64_t>(null_map_data.data()); \
jni_ctx->output_intermediate_state_ptr->row_idx = 0; \
jni_ctx->output_intermediate_state_ptr->buffer_size = buffer_size; \
if (data_column->is_column_string()) { \
ColumnString* str_col = assert_cast<ColumnString*>(data_column.get()); \
ColumnString::Chars& chars = assert_cast<ColumnString::Chars&>(str_col->get_chars()); \
ColumnString::Offsets& offsets = \
assert_cast<ColumnString::Offsets&>(str_col->get_offsets()); \
chars.resize(buffer_size); \
offsets.resize(buffer_size); \
*(jni_ctx->output_value_buffer) = reinterpret_cast<int64_t>(chars.data()); \
*(jni_ctx->output_array_string_offsets_ptr) = \
reinterpret_cast<int64_t>(offsets.data()); \
env->CallNonvirtualVoidMethodA(jni_ctx->executor, jni_env->executor_cl, \
jni_env->executor_evaluate_id, nullptr); \
while (jni_ctx->output_intermediate_state_ptr->row_idx < num_rows) { \
increase_buffer_size++; \
buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
null_map_data.resize(buffer_size); \
chars.resize(buffer_size); \
offsets.resize(buffer_size); \
*(jni_ctx->output_array_null_ptr) = \
reinterpret_cast<int64_t>(null_map_data.data()); \
*(jni_ctx->output_value_buffer) = reinterpret_cast<int64_t>(chars.data()); \
*(jni_ctx->output_array_string_offsets_ptr) = \
reinterpret_cast<int64_t>(offsets.data()); \
jni_ctx->output_intermediate_state_ptr->buffer_size = buffer_size; \
env->CallNonvirtualVoidMethodA(jni_ctx->executor, jni_env->executor_cl, \
jni_env->executor_evaluate_id, nullptr); \
} \
} else { \
data_column->resize(buffer_size); \
*(jni_ctx->output_value_buffer) = \
reinterpret_cast<int64_t>(data_column->get_raw_data().data); \
env->CallNonvirtualVoidMethodA(jni_ctx->executor, jni_env->executor_cl, \
jni_env->executor_evaluate_id, nullptr); \
while (jni_ctx->output_intermediate_state_ptr->row_idx < num_rows) { \
increase_buffer_size++; \
buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
null_map_data.resize(buffer_size); \
data_column->resize(buffer_size); \
*(jni_ctx->output_array_null_ptr) = \
reinterpret_cast<int64_t>(null_map_data.data()); \
*(jni_ctx->output_value_buffer) = \
reinterpret_cast<int64_t>(data_column->get_raw_data().data); \
jni_ctx->output_intermediate_state_ptr->buffer_size = buffer_size; \
env->CallNonvirtualVoidMethodA(jni_ctx->executor, jni_env->executor_cl, \
jni_env->executor_evaluate_id, nullptr); \
} \
} \
} else { \
return Status::InvalidArgument(strings::Substitute( \
"Java UDF doesn't support return type $0 now !", return_type->get_name())); \
env->SetObjectArrayElement(arg_objects, arg_idx, arr_obj);
env->DeleteLocalRef(arr_obj);
}
#endif
EVALUATE_JAVA_UDF;
block.replace_by_position(result,
ColumnNullable::create(std::move(data_col), std::move(null_col)));
RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env));
// evaluate with argument object
jobjectArray result_obj = (jobjectArray)env->CallNonvirtualObjectMethod(
jni_ctx->executor, jni_env->executor_cl, jni_env->executor_evaluate_id, num_rows,
arg_objects);
env->DeleteLocalRef(arg_objects);
RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env));
auto return_type = block.get_data_type(result);
bool result_nullable = return_type->is_nullable();
ColumnUInt8::MutablePtr null_col = nullptr;
if (result_nullable) {
return_type = remove_nullable(return_type);
null_col = ColumnUInt8::create(num_rows, 0);
memset(null_col->get_data().data(), 0, num_rows);
nullmap_address = reinterpret_cast<int64_t>(null_col->get_data().data());
}
auto res_col = return_type->create_column();
res_col->resize(num_rows);
// the column can be resized first; then copy the batch result into the column
if (res_col->is_numeric() || res_col->is_column_decimal()) {
env->CallNonvirtualVoidMethod(jni_ctx->executor, jni_env->executor_cl,
jni_env->executor_result_basic_batch_id, result_nullable,
num_rows, result_obj, nullmap_address,
reinterpret_cast<int64_t>(res_col->get_raw_data().data), 0);
} else if (res_col->is_column_string()) {
const ColumnString* str_col = assert_cast<const ColumnString*>(res_col.get());
ColumnString::Chars& chars = const_cast<ColumnString::Chars&>(str_col->get_chars());
ColumnString::Offsets& offsets = const_cast<ColumnString::Offsets&>(str_col->get_offsets());
env->CallNonvirtualVoidMethod(
jni_ctx->executor, jni_env->executor_cl, jni_env->executor_result_basic_batch_id,
result_nullable, num_rows, result_obj, nullmap_address,
reinterpret_cast<int64_t>(&chars), reinterpret_cast<int64_t>(offsets.data()));
} else if (res_col->is_column_array()) {
ColumnArray* array_col = assert_cast<ColumnArray*>(res_col.get());
ColumnNullable& array_nested_nullable = assert_cast<ColumnNullable&>(array_col->get_data());
auto data_column_null_map = array_nested_nullable.get_null_map_column_ptr();
auto data_column = array_nested_nullable.get_nested_column_ptr();
auto& offset_column = array_col->get_offsets_column();
auto offset_address = reinterpret_cast<int64_t>(offset_column.get_raw_data().data);
auto& null_map_data =
assert_cast<ColumnVector<UInt8>*>(data_column_null_map.get())->get_data();
auto nested_nullmap_address = reinterpret_cast<int64_t>(null_map_data.data());
jmethodID list_size = env->GetMethodID(arraylist_class, "size", "()I");
int element_size = 0; // get all element size in num_rows of array column
for (int i = 0; i < num_rows; ++i) {
jobject obj = env->GetObjectArrayElement(result_obj, i);
if (obj == nullptr) {
continue;
}
element_size = element_size + env->CallIntMethod(obj, list_size);
env->DeleteLocalRef(obj);
}
array_nested_nullable.resize(element_size);
memset(null_map_data.data(), 0, element_size);
int64_t nested_data_address = 0, nested_offset_address = 0;
// an array type needs to pass these addresses: [nullmap_address], offset_address, nested_nullmap_address, nested_data_address/nested_char_address, nested_offset_address
if (data_column->is_column_string()) {
ColumnString* str_col = assert_cast<ColumnString*>(data_column.get());
ColumnString::Chars& chars = assert_cast<ColumnString::Chars&>(str_col->get_chars());
ColumnString::Offsets& offsets =
assert_cast<ColumnString::Offsets&>(str_col->get_offsets());
nested_data_address = reinterpret_cast<int64_t>(&chars);
nested_offset_address = reinterpret_cast<int64_t>(offsets.data());
} else {
nested_data_address = reinterpret_cast<int64_t>(data_column->get_raw_data().data);
}
env->CallNonvirtualVoidMethod(
jni_ctx->executor, jni_env->executor_cl, jni_env->executor_result_array_batch_id,
result_nullable, num_rows, result_obj, nullmap_address, offset_address,
nested_nullmap_address, nested_data_address, nested_offset_address);
} else {
*(jni_ctx->output_null_value) = -1;
auto data_col = return_type->create_column();
EVALUATE_JAVA_UDF;
block.replace_by_position(result, std::move(data_col));
return Status::InvalidArgument(strings::Substitute(
"Java UDF doesn't support return type $0 now !", return_type->get_name()));
}
env->DeleteLocalRef(result_obj);
env->DeleteLocalRef(obj_class);
env->DeleteLocalRef(arraylist_class);
if (result_nullable) {
block.replace_by_position(result,
ColumnNullable::create(std::move(res_col), std::move(null_col)));
} else {
block.replace_by_position(result, std::move(res_col));
}
return JniUtil::GetJniExceptionMsg(env);
}