[feature-wip][UDF][DIP-1] Support variable-size input and output for Java UDF (#8678)

This feature is proposed in DSIP-1. This PR support variable-length input and output Java UDF.
This commit is contained in:
Gabriel
2022-04-11 09:36:16 +08:00
committed by GitHub
parent 174e22b9f0
commit 0d761f9909
10 changed files with 545 additions and 155 deletions

View File

@ -77,12 +77,15 @@ Status JavaFunctionCall::prepare(FunctionContext* context, FunctionContext::Func
TJavaUdfExecutorCtorParams ctor_params;
ctor_params.__set_fn(fn_);
ctor_params.__set_location(local_location);
ctor_params.__set_input_byte_offsets(jni_ctx->input_byte_offsets_ptr);
ctor_params.__set_input_buffer_ptrs(jni_ctx->input_values_buffer_ptr);
ctor_params.__set_input_nulls_ptrs(jni_ctx->input_nulls_buffer_ptr);
ctor_params.__set_output_buffer_ptr(jni_ctx->output_value_buffer);
ctor_params.__set_output_null_ptr(jni_ctx->output_null_value);
ctor_params.__set_batch_size_ptr(jni_ctx->batch_size_ptr);
ctor_params.__set_input_offsets_ptrs((int64_t) jni_ctx->input_offsets_ptrs.get());
ctor_params.__set_input_buffer_ptrs((int64_t) jni_ctx->input_values_buffer_ptr.get());
ctor_params.__set_input_nulls_ptrs((int64_t) jni_ctx->input_nulls_buffer_ptr.get());
ctor_params.__set_output_buffer_ptr((int64_t) jni_ctx->output_value_buffer.get());
ctor_params.__set_output_null_ptr((int64_t) jni_ctx->output_null_value.get());
ctor_params.__set_output_offsets_ptr((int64_t) jni_ctx->output_offsets_ptr.get());
ctor_params.__set_output_intermediate_state_ptr(
(int64_t) jni_ctx->output_intermediate_state_ptr.get());
ctor_params.__set_batch_size_ptr((int64_t) jni_ctx->batch_size_ptr.get());
jbyteArray ctor_params_bytes;
@ -100,11 +103,6 @@ Status JavaFunctionCall::prepare(FunctionContext* context, FunctionContext::Func
Status JavaFunctionCall::execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t num_rows, bool dry_run) {
auto return_type = block.get_data_type(result);
if (!return_type->have_maximum_size_of_value()) {
return Status::InvalidArgument(strings::Substitute(
"Java UDF doesn't support return type $0 now !", return_type->get_name()));
}
JNIEnv* env;
RETURN_IF_ERROR(JniUtil::GetJNIEnv(&env));
JniContext* jni_ctx = reinterpret_cast<JniContext*>(
@ -119,50 +117,94 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, const C
arg_idx, column.type->get_name(),
_argument_types[arg_idx]->get_name()));
}
if (!column.type->have_maximum_size_of_value()) {
return Status::InvalidArgument(strings::Substitute(
"Java UDF doesn't support input type $0 now !", return_type->get_name()));
}
auto data_col = col;
if (auto* nullable = check_and_get_column<const ColumnNullable>(*col)) {
data_col = nullable->get_nested_column_ptr();
auto null_col =
check_and_get_column<ColumnVector<UInt8>>(nullable->get_null_map_column_ptr());
((int64_t*) jni_ctx->input_nulls_buffer_ptr)[arg_idx] =
jni_ctx->input_nulls_buffer_ptr.get()[arg_idx] =
reinterpret_cast<int64_t>(null_col->get_data().data());
} else {
jni_ctx->input_nulls_buffer_ptr.get()[arg_idx] = -1;
}
if (const ColumnString* str_col = check_and_get_column<ColumnString>(data_col.get())) {
jni_ctx->input_values_buffer_ptr.get()[arg_idx] =
reinterpret_cast<int64_t>(str_col->get_chars().data());
jni_ctx->input_offsets_ptrs.get()[arg_idx] =
reinterpret_cast<int64_t>(str_col->get_offsets().data());
} else if (data_col->is_numeric()) {
jni_ctx->input_values_buffer_ptr.get()[arg_idx] =
reinterpret_cast<int64_t>(data_col->get_raw_data().data);
} else {
return Status::InvalidArgument(strings::Substitute(
"Java UDF doesn't support type $0 now !", _argument_types[arg_idx]->get_name()));
}
((int64_t*) jni_ctx->input_values_buffer_ptr)[arg_idx] =
reinterpret_cast<int64_t>(data_col->get_raw_data().data);
arg_idx++;
}
*(jni_ctx->batch_size_ptr) = num_rows;
auto return_type = block.get_data_type(result);
if (return_type->is_nullable()) {
auto null_type = std::reinterpret_pointer_cast<const DataTypeNullable>(return_type);
auto data_col = null_type->get_nested_type()->create_column();
auto null_col = ColumnUInt8::create(data_col->size(), 0);
null_col->reserve(num_rows);
null_col->resize(num_rows);
data_col->reserve(num_rows);
data_col->resize(num_rows);
*((int64_t*) jni_ctx->output_null_value) =
*(jni_ctx->output_null_value) =
reinterpret_cast<int64_t>(null_col->get_data().data());
*((int64_t*) jni_ctx->output_value_buffer) = reinterpret_cast<int64_t>(data_col->get_raw_data().data);
#ifndef EVALUATE_JAVA_UDF
#define EVALUATE_JAVA_UDF \
if (const ColumnString* str_col = check_and_get_column<ColumnString>(data_col.get())) { \
ColumnString::Chars& chars = const_cast<ColumnString::Chars&>(str_col->get_chars()); \
ColumnString::Offsets& offsets = \
const_cast<ColumnString::Offsets&>(str_col->get_offsets()); \
int increase_buffer_size = 0; \
int32_t buffer_size = \
JavaFunctionCall::IncreaseReservedBufferSize(increase_buffer_size); \
chars.reserve(buffer_size); \
chars.resize(buffer_size); \
offsets.reserve(num_rows); \
offsets.resize(num_rows); \
*(jni_ctx->output_value_buffer) = \
reinterpret_cast<int64_t>(chars.data()); \
*(jni_ctx->output_offsets_ptr) = \
reinterpret_cast<int64_t>(offsets.data()); \
jni_ctx->output_intermediate_state_ptr->row_idx = 0; \
jni_ctx->output_intermediate_state_ptr->buffer_size = buffer_size; \
env->CallNonvirtualVoidMethodA( \
jni_ctx->executor, executor_cl_, executor_evaluate_id_, nullptr); \
while (jni_ctx->output_intermediate_state_ptr->row_idx < num_rows) { \
increase_buffer_size++; \
int32_t buffer_size = \
JavaFunctionCall::IncreaseReservedBufferSize(increase_buffer_size); \
chars.resize(buffer_size); \
*(jni_ctx->output_value_buffer) = \
reinterpret_cast<int64_t>(chars.data()); \
jni_ctx->output_intermediate_state_ptr->buffer_size = buffer_size; \
env->CallNonvirtualVoidMethodA( \
jni_ctx->executor, executor_cl_, executor_evaluate_id_, nullptr); \
} \
} else if (data_col->is_numeric()) { \
data_col->reserve(num_rows); \
data_col->resize(num_rows); \
*(jni_ctx->output_value_buffer) = \
reinterpret_cast<int64_t>(data_col->get_raw_data().data); \
env->CallNonvirtualVoidMethodA( \
jni_ctx->executor, executor_cl_, executor_evaluate_id_, nullptr); \
} else { \
return Status::InvalidArgument(strings::Substitute( \
"Java UDF doesn't support return type $0 now !", return_type->get_name())); \
}
#endif
EVALUATE_JAVA_UDF;
block.replace_by_position(result,
ColumnNullable::create(std::move(data_col), std::move(null_col)));
} else {
*(jni_ctx->output_null_value) = -1;
auto data_col = return_type->create_column();
data_col->reserve(num_rows);
data_col->resize(num_rows);
*((int64_t*) jni_ctx->output_value_buffer) = reinterpret_cast<int64_t>(data_col->get_raw_data().data);
EVALUATE_JAVA_UDF;
block.replace_by_position(result, std::move(data_col));
}
*((int32_t*) jni_ctx->batch_size_ptr) = num_rows;
// Using this version of Call has the lowest overhead. This eliminates the
// vtable lookup and setting up return stacks.
env->CallNonvirtualVoidMethodA(
jni_ctx->executor, executor_cl_, executor_evaluate_id_, nullptr);
return JniUtil::GetJniExceptionMsg(env);
}