[vectorized](udf) java udf support array type (#16841)
This commit is contained in:
@ -27,6 +27,8 @@
|
||||
#include "runtime/exec_env.h"
|
||||
#include "runtime/user_function_cache.h"
|
||||
#include "util/jni-util.h"
|
||||
#include "vec/columns/column_array.h"
|
||||
#include "vec/columns/column_nullable.h"
|
||||
#include "vec/columns/column_vector.h"
|
||||
#include "vec/core/block.h"
|
||||
#include "vec/data_types/data_type_bitmap.h"
|
||||
@ -78,9 +80,16 @@ Status JavaFunctionCall::prepare(FunctionContext* context,
|
||||
ctor_params.__set_input_offsets_ptrs((int64_t)jni_ctx->input_offsets_ptrs.get());
|
||||
ctor_params.__set_input_buffer_ptrs((int64_t)jni_ctx->input_values_buffer_ptr.get());
|
||||
ctor_params.__set_input_nulls_ptrs((int64_t)jni_ctx->input_nulls_buffer_ptr.get());
|
||||
ctor_params.__set_input_array_nulls_buffer_ptr(
|
||||
(int64_t)jni_ctx->input_array_nulls_buffer_ptr.get());
|
||||
ctor_params.__set_input_array_string_offsets_ptrs(
|
||||
(int64_t)jni_ctx->input_array_string_offsets_ptrs.get());
|
||||
ctor_params.__set_output_buffer_ptr((int64_t)jni_ctx->output_value_buffer.get());
|
||||
ctor_params.__set_output_null_ptr((int64_t)jni_ctx->output_null_value.get());
|
||||
ctor_params.__set_output_offsets_ptr((int64_t)jni_ctx->output_offsets_ptr.get());
|
||||
ctor_params.__set_output_array_null_ptr((int64_t)jni_ctx->output_array_null_ptr.get());
|
||||
ctor_params.__set_output_array_string_offsets_ptr(
|
||||
(int64_t)jni_ctx->output_array_string_offsets_ptr.get());
|
||||
ctor_params.__set_output_intermediate_state_ptr(
|
||||
(int64_t)jni_ctx->output_intermediate_state_ptr.get());
|
||||
ctor_params.__set_batch_size_ptr((int64_t)jni_ctx->batch_size_ptr.get());
|
||||
@ -142,6 +151,31 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
|
||||
} else if (data_cols[arg_idx]->is_numeric() || data_cols[arg_idx]->is_column_decimal()) {
|
||||
jni_ctx->input_values_buffer_ptr.get()[arg_idx] =
|
||||
reinterpret_cast<int64_t>(data_cols[arg_idx]->get_raw_data().data);
|
||||
} else if (data_cols[arg_idx]->is_column_array()) {
|
||||
const ColumnArray* array_col =
|
||||
assert_cast<const ColumnArray*>(data_cols[arg_idx].get());
|
||||
jni_ctx->input_offsets_ptrs.get()[arg_idx] =
|
||||
reinterpret_cast<int64_t>(array_col->get_offsets_column().get_raw_data().data);
|
||||
const ColumnNullable& array_nested_nullable =
|
||||
assert_cast<const ColumnNullable&>(array_col->get_data());
|
||||
auto data_column_null_map = array_nested_nullable.get_null_map_column_ptr();
|
||||
auto data_column = array_nested_nullable.get_nested_column_ptr();
|
||||
jni_ctx->input_array_nulls_buffer_ptr.get()[arg_idx] = reinterpret_cast<int64_t>(
|
||||
check_and_get_column<ColumnVector<UInt8>>(data_column_null_map)
|
||||
->get_data()
|
||||
.data());
|
||||
|
||||
//need pass FE, nullamp and offset, chars
|
||||
if (data_column->is_column_string()) {
|
||||
const ColumnString* col = assert_cast<const ColumnString*>(data_column.get());
|
||||
jni_ctx->input_values_buffer_ptr.get()[arg_idx] =
|
||||
reinterpret_cast<int64_t>(col->get_chars().data());
|
||||
jni_ctx->input_array_string_offsets_ptrs.get()[arg_idx] =
|
||||
reinterpret_cast<int64_t>(col->get_offsets().data());
|
||||
} else {
|
||||
jni_ctx->input_values_buffer_ptr.get()[arg_idx] =
|
||||
reinterpret_cast<int64_t>(data_column->get_raw_data().data);
|
||||
}
|
||||
} else {
|
||||
return Status::InvalidArgument(
|
||||
strings::Substitute("Java UDF doesn't support type $0 now !",
|
||||
@ -155,7 +189,6 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
|
||||
auto null_type = std::reinterpret_pointer_cast<const DataTypeNullable>(return_type);
|
||||
auto data_col = null_type->get_nested_type()->create_column();
|
||||
auto null_col = ColumnUInt8::create(data_col->size(), 0);
|
||||
null_col->reserve(num_rows);
|
||||
null_col->resize(num_rows);
|
||||
|
||||
*(jni_ctx->output_null_value) = reinterpret_cast<int64_t>(null_col->get_data().data());
|
||||
@ -168,9 +201,7 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
|
||||
const_cast<ColumnString::Offsets&>(str_col->get_offsets()); \
|
||||
int increase_buffer_size = 0; \
|
||||
int32_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
|
||||
chars.reserve(buffer_size); \
|
||||
chars.resize(buffer_size); \
|
||||
offsets.reserve(num_rows); \
|
||||
offsets.resize(num_rows); \
|
||||
*(jni_ctx->output_value_buffer) = reinterpret_cast<int64_t>(chars.data()); \
|
||||
*(jni_ctx->output_offsets_ptr) = reinterpret_cast<int64_t>(offsets.data()); \
|
||||
@ -188,12 +219,74 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
|
||||
nullptr); \
|
||||
} \
|
||||
} else if (data_col->is_numeric() || data_col->is_column_decimal()) { \
|
||||
data_col->reserve(num_rows); \
|
||||
data_col->resize(num_rows); \
|
||||
*(jni_ctx->output_value_buffer) = \
|
||||
reinterpret_cast<int64_t>(data_col->get_raw_data().data); \
|
||||
env->CallNonvirtualVoidMethodA(jni_ctx->executor, executor_cl_, executor_evaluate_id_, \
|
||||
nullptr); \
|
||||
} else if (data_col->is_column_array()) { \
|
||||
ColumnArray* array_col = assert_cast<ColumnArray*>(data_col.get()); \
|
||||
ColumnNullable& array_nested_nullable = \
|
||||
assert_cast<ColumnNullable&>(array_col->get_data()); \
|
||||
auto data_column_null_map = array_nested_nullable.get_null_map_column_ptr(); \
|
||||
auto data_column = array_nested_nullable.get_nested_column_ptr(); \
|
||||
auto& offset_column = array_col->get_offsets_column(); \
|
||||
int increase_buffer_size = 0; \
|
||||
int32_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
|
||||
offset_column.resize(num_rows); \
|
||||
*(jni_ctx->output_offsets_ptr) = \
|
||||
reinterpret_cast<int64_t>(offset_column.get_raw_data().data); \
|
||||
data_column_null_map->resize(buffer_size); \
|
||||
auto& null_map_data = \
|
||||
assert_cast<ColumnVector<UInt8>*>(data_column_null_map.get())->get_data(); \
|
||||
*(jni_ctx->output_array_null_ptr) = reinterpret_cast<int64_t>(null_map_data.data()); \
|
||||
jni_ctx->output_intermediate_state_ptr->row_idx = 0; \
|
||||
jni_ctx->output_intermediate_state_ptr->buffer_size = buffer_size; \
|
||||
if (data_column->is_column_string()) { \
|
||||
ColumnString* str_col = assert_cast<ColumnString*>(data_column.get()); \
|
||||
ColumnString::Chars& chars = assert_cast<ColumnString::Chars&>(str_col->get_chars()); \
|
||||
ColumnString::Offsets& offsets = \
|
||||
assert_cast<ColumnString::Offsets&>(str_col->get_offsets()); \
|
||||
chars.resize(buffer_size); \
|
||||
offsets.resize(buffer_size); \
|
||||
*(jni_ctx->output_value_buffer) = reinterpret_cast<int64_t>(chars.data()); \
|
||||
*(jni_ctx->output_array_string_offsets_ptr) = \
|
||||
reinterpret_cast<int64_t>(offsets.data()); \
|
||||
env->CallNonvirtualVoidMethodA(jni_ctx->executor, executor_cl_, executor_evaluate_id_, \
|
||||
nullptr); \
|
||||
while (jni_ctx->output_intermediate_state_ptr->row_idx < num_rows) { \
|
||||
increase_buffer_size++; \
|
||||
buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
|
||||
null_map_data.resize(buffer_size); \
|
||||
chars.resize(buffer_size); \
|
||||
offsets.resize(buffer_size); \
|
||||
*(jni_ctx->output_array_null_ptr) = \
|
||||
reinterpret_cast<int64_t>(null_map_data.data()); \
|
||||
*(jni_ctx->output_value_buffer) = reinterpret_cast<int64_t>(chars.data()); \
|
||||
jni_ctx->output_intermediate_state_ptr->buffer_size = buffer_size; \
|
||||
env->CallNonvirtualVoidMethodA(jni_ctx->executor, executor_cl_, \
|
||||
executor_evaluate_id_, nullptr); \
|
||||
} \
|
||||
} else { \
|
||||
data_column->resize(buffer_size); \
|
||||
*(jni_ctx->output_value_buffer) = \
|
||||
reinterpret_cast<int64_t>(data_column->get_raw_data().data); \
|
||||
env->CallNonvirtualVoidMethodA(jni_ctx->executor, executor_cl_, executor_evaluate_id_, \
|
||||
nullptr); \
|
||||
while (jni_ctx->output_intermediate_state_ptr->row_idx < num_rows) { \
|
||||
increase_buffer_size++; \
|
||||
buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
|
||||
null_map_data.resize(buffer_size); \
|
||||
data_column->resize(buffer_size); \
|
||||
*(jni_ctx->output_array_null_ptr) = \
|
||||
reinterpret_cast<int64_t>(null_map_data.data()); \
|
||||
*(jni_ctx->output_value_buffer) = \
|
||||
reinterpret_cast<int64_t>(data_column->get_raw_data().data); \
|
||||
jni_ctx->output_intermediate_state_ptr->buffer_size = buffer_size; \
|
||||
env->CallNonvirtualVoidMethodA(jni_ctx->executor, executor_cl_, \
|
||||
executor_evaluate_id_, nullptr); \
|
||||
} \
|
||||
} \
|
||||
} else { \
|
||||
return Status::InvalidArgument(strings::Substitute( \
|
||||
"Java UDF doesn't support return type $0 now !", return_type->get_name())); \
|
||||
|
||||
Reference in New Issue
Block a user