[refactor](jni) unified jni framework for java udaf (#25591)

Follows up on https://github.com/apache/doris/pull/25302 and uses the unified JNI framework to refactor the Java UDAF.
This PR removes the old interfaces for running Java UDF/UDAF. Thanks to the ease of use of the new framework, the core code changes for UDAF do not exceed 100 lines, and the logic is similar to that of UDF.
This commit is contained in:
Ashin Gau
2023-10-20 16:13:40 +08:00
committed by GitHub
parent a11cde7bee
commit a2ceea5951
13 changed files with 329 additions and 3543 deletions

View File

@ -37,6 +37,7 @@
#include "vec/common/string_ref.h"
#include "vec/core/field.h"
#include "vec/core/types.h"
#include "vec/exec/jni_connector.h"
#include "vec/io/io_helper.h"
namespace doris::vectorized {
@ -45,10 +46,10 @@ const char* UDAF_EXECUTOR_CLASS = "org/apache/doris/udf/UdafExecutor";
const char* UDAF_EXECUTOR_CTOR_SIGNATURE = "([B)V";
const char* UDAF_EXECUTOR_CLOSE_SIGNATURE = "()V";
const char* UDAF_EXECUTOR_DESTROY_SIGNATURE = "()V";
const char* UDAF_EXECUTOR_ADD_SIGNATURE = "(ZJJ)V";
const char* UDAF_EXECUTOR_ADD_SIGNATURE = "(ZIIJILjava/util/Map;)V";
const char* UDAF_EXECUTOR_SERIALIZE_SIGNATURE = "(J)[B";
const char* UDAF_EXECUTOR_MERGE_SIGNATURE = "(J[B)V";
const char* UDAF_EXECUTOR_RESULT_SIGNATURE = "(JJ)Z";
const char* UDAF_EXECUTOR_GET_SIGNATURE = "(JLjava/util/Map;)J";
const char* UDAF_EXECUTOR_RESET_SIGNATURE = "(J)V";
// Calling Java method about those signature means: "(argument-types)return-type"
// https://www.iitk.ac.in/esc101/05Aug/tutorial/native1.1/implementing/method.html
@ -60,10 +61,14 @@ public:
// Closes the Java-side executor and releases the JNI global references held by
// this object. A destructor cannot hand a Status back to its caller, so every
// failure is reported through the log instead.
~AggregateJavaUdafData() {
    JNIEnv* env = nullptr;
    Status env_status = JniUtil::GetJNIEnv(&env);
    if (!env_status.ok() || env == nullptr) {
        // No JNI call is possible without an environment; going on would
        // dereference an invalid pointer.
        LOG(WARNING) << "Failed to get JNIEnv: " << env_status.to_string();
        return;
    }
    env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_close_id);
    Status st = JniUtil::GetJniExceptionMsg(env);
    if (!st.ok()) {
        LOG(WARNING) << "Failed to close JAVA UDAF: " << st.to_string();
    }
    // Release the global refs even when close() failed so that they are not leaked.
    env->DeleteGlobalRef(executor_cl);
    env->DeleteGlobalRef(executor_obj);
}
@ -103,126 +108,24 @@ public:
int place_offset) {
JNIEnv* env = nullptr;
RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf add function");
jclass obj_class = env->FindClass("[Ljava/lang/Object;");
jobjectArray arg_objects = env->NewObjectArray(argument_size, obj_class, nullptr);
int64_t nullmap_address = 0;
for (int arg_idx = 0; arg_idx < argument_size; ++arg_idx) {
bool arg_column_nullable = false;
auto data_col = columns[arg_idx];
if (auto* nullable = check_and_get_column<const ColumnNullable>(*columns[arg_idx])) {
arg_column_nullable = true;
auto null_col = nullable->get_null_map_column_ptr();
data_col = nullable->get_nested_column_ptr();
nullmap_address = reinterpret_cast<int64_t>(
check_and_get_column<ColumnVector<UInt8>>(null_col)->get_data().data());
}
// convert argument column data into java type
jobjectArray arr_obj = nullptr;
if (data_col->is_numeric() || data_col->is_column_decimal()) {
arr_obj = (jobjectArray)env->CallObjectMethod(
executor_obj, executor_convert_basic_argument_id, arg_idx,
arg_column_nullable, row_num_start, row_num_end, nullmap_address,
reinterpret_cast<int64_t>(data_col->get_raw_data().data), 0);
} else if (data_col->is_column_string()) {
const ColumnString* str_col = assert_cast<const ColumnString*>(data_col);
arr_obj = (jobjectArray)env->CallObjectMethod(
executor_obj, executor_convert_basic_argument_id, arg_idx,
arg_column_nullable, row_num_start, row_num_end, nullmap_address,
reinterpret_cast<int64_t>(str_col->get_chars().data()),
reinterpret_cast<int64_t>(str_col->get_offsets().data()));
} else if (data_col->is_column_array()) {
const ColumnArray* array_col = assert_cast<const ColumnArray*>(data_col);
const ColumnNullable& array_nested_nullable =
assert_cast<const ColumnNullable&>(array_col->get_data());
auto data_column_null_map = array_nested_nullable.get_null_map_column_ptr();
auto data_column = array_nested_nullable.get_nested_column_ptr();
auto offset_address = reinterpret_cast<int64_t>(
array_col->get_offsets_column().get_raw_data().data);
auto nested_nullmap_address = reinterpret_cast<int64_t>(
check_and_get_column<ColumnVector<UInt8>>(data_column_null_map)
->get_data()
.data());
int64_t nested_data_address = 0, nested_offset_address = 0;
// array type need pass address: [nullmap_address], offset_address, nested_nullmap_address, nested_data_address/nested_char_address,nested_offset_address
if (data_column->is_column_string()) {
const ColumnString* col = assert_cast<const ColumnString*>(data_column.get());
nested_data_address = reinterpret_cast<int64_t>(col->get_chars().data());
nested_offset_address = reinterpret_cast<int64_t>(col->get_offsets().data());
} else {
nested_data_address =
reinterpret_cast<int64_t>(data_column->get_raw_data().data);
}
arr_obj = (jobjectArray)env->CallObjectMethod(
executor_obj, executor_convert_array_argument_id, arg_idx,
arg_column_nullable, row_num_start, row_num_end, nullmap_address,
offset_address, nested_nullmap_address, nested_data_address,
nested_offset_address);
} else if (data_col->is_column_map()) {
const ColumnMap* map_col = assert_cast<const ColumnMap*>(data_col);
auto offset_address = reinterpret_cast<int64_t>(
map_col->get_offsets_column().get_raw_data().data);
const ColumnNullable& map_key_column_nullable =
assert_cast<const ColumnNullable&>(map_col->get_keys());
auto key_data_column_null_map = map_key_column_nullable.get_null_map_column_ptr();
auto key_data_column = map_key_column_nullable.get_nested_column_ptr();
auto key_nested_nullmap_address = reinterpret_cast<int64_t>(
check_and_get_column<ColumnVector<UInt8>>(key_data_column_null_map)
->get_data()
.data());
int64_t key_nested_data_address = 0, key_nested_offset_address = 0;
if (key_data_column->is_column_string()) {
const ColumnString* col =
assert_cast<const ColumnString*>(key_data_column.get());
key_nested_data_address = reinterpret_cast<int64_t>(col->get_chars().data());
key_nested_offset_address =
reinterpret_cast<int64_t>(col->get_offsets().data());
} else {
key_nested_data_address =
reinterpret_cast<int64_t>(key_data_column->get_raw_data().data);
}
const ColumnNullable& map_value_column_nullable =
assert_cast<const ColumnNullable&>(map_col->get_values());
auto value_data_column_null_map =
map_value_column_nullable.get_null_map_column_ptr();
auto value_data_column = map_value_column_nullable.get_nested_column_ptr();
auto value_nested_nullmap_address = reinterpret_cast<int64_t>(
check_and_get_column<ColumnVector<UInt8>>(value_data_column_null_map)
->get_data()
.data());
int64_t value_nested_data_address = 0, value_nested_offset_address = 0;
if (value_data_column->is_column_string()) {
const ColumnString* col =
assert_cast<const ColumnString*>(value_data_column.get());
value_nested_data_address = reinterpret_cast<int64_t>(col->get_chars().data());
value_nested_offset_address =
reinterpret_cast<int64_t>(col->get_offsets().data());
} else {
value_nested_data_address =
reinterpret_cast<int64_t>(value_data_column->get_raw_data().data);
}
arr_obj = (jobjectArray)env->CallObjectMethod(
executor_obj, executor_convert_map_argument_id, arg_idx,
arg_column_nullable, row_num_start, row_num_end, nullmap_address,
offset_address, key_nested_nullmap_address, key_nested_data_address,
key_nested_offset_address, value_nested_nullmap_address,
value_nested_data_address, value_nested_offset_address);
} else {
return Status::InvalidArgument(
strings::Substitute("Java UDAF doesn't support type is $0 now !",
argument_types[arg_idx]->get_name()));
}
env->SetObjectArrayElement(arg_objects, arg_idx, arr_obj);
env->DeleteLocalRef(arr_obj);
Block input_block;
for (size_t i = 0; i < argument_size; ++i) {
input_block.insert(ColumnWithTypeAndName(columns[i]->get_ptr(), argument_types[i],
std::to_string(i)));
}
RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env));
std::unique_ptr<long[]> input_table;
RETURN_IF_ERROR(JniConnector::to_java_table(&input_block, input_table));
auto input_table_schema = JniConnector::parse_table_schema(&input_block);
std::map<String, String> input_params = {
{"meta_address", std::to_string((long)input_table.get())},
{"required_fields", input_table_schema.first},
{"columns_types", input_table_schema.second}};
jobject input_map = JniUtil::convert_to_java_map(env, input_params);
// invoke add batch
env->CallObjectMethod(executor_obj, executor_add_batch_id, is_single_place, row_num_start,
row_num_end, places_address, place_offset, arg_objects);
env->DeleteLocalRef(arg_objects);
env->DeleteLocalRef(obj_class);
row_num_end, places_address, place_offset, input_map);
env->DeleteLocalRef(input_map);
return JniUtil::GetJniExceptionMsg(env);
}
@ -275,160 +178,33 @@ public:
}
// Writes the aggregate result stored at `place` into column `to`.
//
// `to` is wrapped in a one-column Block whose schema is handed to the Java
// executor's getValue(place, params); the address that call returns is then
// passed to JniConnector::fill_block to fill the column.
//
// @param to          destination column for the result
// @param result_type data type of `to`; its nullability is forwarded as "is_nullable"
// @param place       identifier of the aggregate state on the Java side
// @return the JNI exception converted to a Status if getValue raised,
//         otherwise the Status of fill_block
// NOTE(review): no default row is inserted here on the assumption that
// fill_block appends the result row itself - confirm against JniConnector.
Status get(IColumn& to, const DataTypePtr& result_type, int64_t place) const {
    JNIEnv* env = nullptr;
    RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf get value function");

    Block output_block;
    output_block.insert(ColumnWithTypeAndName(to.get_ptr(), result_type, "_result_"));
    auto output_table_schema = JniConnector::parse_table_schema(&output_block);
    std::string output_nullable = result_type->is_nullable() ? "true" : "false";
    std::map<String, String> output_params = {{"is_nullable", output_nullable},
                                              {"required_fields", output_table_schema.first},
                                              {"columns_types", output_table_schema.second}};
    jobject output_map = JniUtil::convert_to_java_map(env, output_params);
    long output_address =
            env->CallLongMethod(executor_obj, executor_get_value_id, place, output_map);
    // Drop the local ref before inspecting the exception state so that a failed
    // call does not leak it (same ordering as JavaFunctionCall::execute_impl).
    env->DeleteLocalRef(output_map);
    RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env));
    return JniConnector::fill_block(&output_block, {0}, output_address);
}
private:
// Fetches the aggregate result for `place` from the Java executor (getValue)
// and asks the executor to copy it into `data_col`, dispatching on the kind of
// `data_col`: string, numeric/decimal, array or map. Raw addresses of the
// column buffers are passed to Java as int64 values.
//
// `to` is only used for its size (the target row is to.size() - 1);
// `nullmap_address` is forwarded unchanged to the copy methods.
// Returns InvalidArgument for any other column kind, Status::OK() otherwise.
// NOTE(review): no JNI exception check is performed inside this function, and
// the local ref `result_obj` is not released here - confirm the caller covers both.
Status get_result(IColumn& to, const DataTypePtr& return_type, int64_t place, JNIEnv* env,
IColumn& data_col, int64_t nullmap_address) const {
jobject result_obj = env->CallNonvirtualObjectMethod(executor_obj, executor_cl,
executor_get_value_id, place);
bool result_nullable = return_type->is_nullable();
if (data_col.is_column_string()) {
// String result: the chars buffer is resized to the size reported by
// JniUtil::IncreaseReservedBufferSize(0) before its address is handed to Java.
const ColumnString* str_col = check_and_get_column<ColumnString>(data_col);
ColumnString::Chars& chars = const_cast<ColumnString::Chars&>(str_col->get_chars());
ColumnString::Offsets& offsets =
const_cast<ColumnString::Offsets&>(str_col->get_offsets());
int increase_buffer_size = 0;
int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size);
chars.resize(buffer_size);
// Both the data pointer and the address of the Chars object itself are passed.
env->CallNonvirtualVoidMethod(
executor_obj, executor_cl, executor_copy_basic_result_id, result_obj,
to.size() - 1, nullmap_address, reinterpret_cast<int64_t>(chars.data()),
reinterpret_cast<int64_t>(&chars), reinterpret_cast<int64_t>(offsets.data()));
} else if (data_col.is_numeric() || data_col.is_column_decimal()) {
// Fixed-width result: only the raw data address is needed; the two trailing
// address arguments are passed as 0.
env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_copy_basic_result_id,
result_obj, to.size() - 1, nullmap_address,
reinterpret_cast<int64_t>(data_col.get_raw_data().data),
0, 0);
} else if (data_col.is_column_array()) {
// NOTE(review): FindClass is given "Ljava/util/ArrayList;"; the JNI spec describes
// the argument as a class name such as "java/util/ArrayList" - confirm this form is accepted.
jclass arraylist_class = env->FindClass("Ljava/util/ArrayList;");
ColumnArray* array_col = assert_cast<ColumnArray*>(&data_col);
ColumnNullable& array_nested_nullable =
assert_cast<ColumnNullable&>(array_col->get_data());
auto data_column_null_map = array_nested_nullable.get_null_map_column_ptr();
auto data_column = array_nested_nullable.get_nested_column_ptr();
auto& offset_column = array_col->get_offsets_column();
auto offset_address = reinterpret_cast<int64_t>(offset_column.get_raw_data().data);
auto& null_map_data =
assert_cast<ColumnVector<UInt8>*>(data_column_null_map.get())->get_data();
auto nested_nullmap_address = reinterpret_cast<int64_t>(null_map_data.data());
jmethodID list_size = env->GetMethodID(arraylist_class, "size", "()I");
// Grow the nested column by the size of the returned list and clear the
// null-map entries of the newly added elements.
size_t has_put_element_size = array_col->get_offsets().back();
size_t arrar_list_size = env->CallIntMethod(result_obj, list_size);
size_t element_size = has_put_element_size + arrar_list_size;
array_nested_nullable.resize(element_size);
memset(null_map_data.data() + has_put_element_size, 0, arrar_list_size);
int64_t nested_data_address = 0, nested_offset_address = 0;
if (data_column->is_column_string()) {
// For nested strings the address of the Chars object (not its data pointer) is passed.
ColumnString* str_col = assert_cast<ColumnString*>(data_column.get());
ColumnString::Chars& chars =
assert_cast<ColumnString::Chars&>(str_col->get_chars());
ColumnString::Offsets& offsets =
assert_cast<ColumnString::Offsets&>(str_col->get_offsets());
nested_data_address = reinterpret_cast<int64_t>(&chars);
nested_offset_address = reinterpret_cast<int64_t>(offsets.data());
} else {
nested_data_address = reinterpret_cast<int64_t>(data_column->get_raw_data().data);
}
int row = to.size() - 1;
env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_copy_array_result_id,
has_put_element_size, result_nullable, row, result_obj,
nullmap_address, offset_address, nested_nullmap_address,
nested_data_address, nested_offset_address);
env->DeleteLocalRef(arraylist_class);
} else if (data_col.is_column_map()) {
// NOTE(review): same class-name form question as in the array branch above.
jclass hashmap_class = env->FindClass("Ljava/util/HashMap;");
ColumnMap* map_col = assert_cast<ColumnMap*>(&data_col);
auto& offset_column = map_col->get_offsets_column();
auto offset_address = reinterpret_cast<int64_t>(offset_column.get_raw_data().data);
ColumnNullable& map_key_column_nullable =
assert_cast<ColumnNullable&>(map_col->get_keys());
auto key_data_column_null_map = map_key_column_nullable.get_null_map_column_ptr();
auto key_data_column = map_key_column_nullable.get_nested_column_ptr();
auto& key_null_map_data =
assert_cast<ColumnVector<UInt8>*>(key_data_column_null_map.get())->get_data();
auto key_nested_nullmap_address = reinterpret_cast<int64_t>(key_null_map_data.data());
ColumnNullable& map_value_column_nullable =
assert_cast<ColumnNullable&>(map_col->get_values());
auto value_data_column_null_map = map_value_column_nullable.get_null_map_column_ptr();
auto value_data_column = map_value_column_nullable.get_nested_column_ptr();
auto& value_null_map_data =
assert_cast<ColumnVector<UInt8>*>(value_data_column_null_map.get())->get_data();
auto value_nested_nullmap_address =
reinterpret_cast<int64_t>(value_null_map_data.data());
jmethodID map_size = env->GetMethodID(hashmap_class, "size", "()I");
// Grow the key and value columns by the size of the returned map and clear
// the null-map entries of the newly added elements.
size_t has_put_element_size = map_col->get_offsets().back();
size_t hashmap_size = env->CallIntMethod(result_obj, map_size);
size_t element_size = has_put_element_size + hashmap_size;
map_key_column_nullable.resize(element_size);
memset(key_null_map_data.data() + has_put_element_size, 0, hashmap_size);
map_value_column_nullable.resize(element_size);
memset(value_null_map_data.data() + has_put_element_size, 0, hashmap_size);
int64_t key_nested_data_address = 0, key_nested_offset_address = 0;
if (key_data_column->is_column_string()) {
ColumnString* str_col = assert_cast<ColumnString*>(key_data_column.get());
ColumnString::Chars& chars =
assert_cast<ColumnString::Chars&>(str_col->get_chars());
ColumnString::Offsets& offsets =
assert_cast<ColumnString::Offsets&>(str_col->get_offsets());
key_nested_data_address = reinterpret_cast<int64_t>(&chars);
key_nested_offset_address = reinterpret_cast<int64_t>(offsets.data());
} else {
key_nested_data_address =
reinterpret_cast<int64_t>(key_data_column->get_raw_data().data);
}
int64_t value_nested_data_address = 0, value_nested_offset_address = 0;
if (value_data_column->is_column_string()) {
ColumnString* str_col = assert_cast<ColumnString*>(value_data_column.get());
ColumnString::Chars& chars =
assert_cast<ColumnString::Chars&>(str_col->get_chars());
ColumnString::Offsets& offsets =
assert_cast<ColumnString::Offsets&>(str_col->get_offsets());
value_nested_data_address = reinterpret_cast<int64_t>(&chars);
value_nested_offset_address = reinterpret_cast<int64_t>(offsets.data());
} else {
value_nested_data_address =
reinterpret_cast<int64_t>(value_data_column->get_raw_data().data);
}
int row = to.size() - 1;
env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_copy_map_result_id,
has_put_element_size, result_nullable, row, result_obj,
nullmap_address, offset_address,
key_nested_nullmap_address, key_nested_data_address,
key_nested_offset_address, value_nested_nullmap_address,
value_nested_data_address, value_nested_offset_address);
env->DeleteLocalRef(hashmap_class);
} else {
return Status::InvalidArgument(strings::Substitute(
"Java UDAF doesn't support return type is $0 now !", return_type->get_name()));
}
return Status::OK();
}
Status register_func_id(JNIEnv* env) {
auto register_id = [&](const char* func_name, const char* func_sign, jmethodID& func_id) {
func_id = env->GetMethodID(executor_cl, func_name, func_sign);
Status s = JniUtil::GetJniExceptionMsg(env);
if (!s.ok()) {
LOG(WARNING) << "Failed to register function " << func_name << ": "
<< s.to_string();
return Status::InternalError(strings::Substitute(
"Java-Udaf register_func_id meet error and error is $0", s.to_string()));
}
@ -440,27 +216,12 @@ private:
RETURN_IF_ERROR(register_id("merge", UDAF_EXECUTOR_MERGE_SIGNATURE, executor_merge_id));
RETURN_IF_ERROR(
register_id("serialize", UDAF_EXECUTOR_SERIALIZE_SIGNATURE, executor_serialize_id));
RETURN_IF_ERROR(register_id("getValue", "(J)Ljava/lang/Object;", executor_get_value_id));
RETURN_IF_ERROR(
register_id("getValue", UDAF_EXECUTOR_GET_SIGNATURE, executor_get_value_id));
RETURN_IF_ERROR(
register_id("destroy", UDAF_EXECUTOR_DESTROY_SIGNATURE, executor_destroy_id));
RETURN_IF_ERROR(register_id("convertBasicArguments", "(IZIIJJJ)[Ljava/lang/Object;",
executor_convert_basic_argument_id));
RETURN_IF_ERROR(register_id("convertArrayArguments", "(IZIIJJJJJ)[Ljava/lang/Object;",
executor_convert_array_argument_id));
RETURN_IF_ERROR(register_id("convertMapArguments", "(IZIIJJJJJJJJ)[Ljava/lang/Object;",
executor_convert_map_argument_id));
RETURN_IF_ERROR(register_id("copyTupleBasicResult", "(Ljava/lang/Object;IJJJJ)V",
executor_copy_basic_result_id));
RETURN_IF_ERROR(register_id("copyTupleArrayResult", "(JZILjava/lang/Object;JJJJJ)V",
executor_copy_array_result_id));
RETURN_IF_ERROR(register_id("copyTupleMapResult", "(JZILjava/lang/Object;JJJJJJJJ)V",
executor_copy_map_result_id));
RETURN_IF_ERROR(
register_id("addBatch", "(ZIIJI[Ljava/lang/Object;)V", executor_add_batch_id));
register_id("addBatch", UDAF_EXECUTOR_ADD_SIGNATURE, executor_add_batch_id));
return Status::OK();
}
@ -478,12 +239,6 @@ private:
jmethodID executor_reset_id;
jmethodID executor_close_id;
jmethodID executor_destroy_id;
jmethodID executor_convert_basic_argument_id;
jmethodID executor_convert_array_argument_id;
jmethodID executor_convert_map_argument_id;
jmethodID executor_copy_basic_result_id;
jmethodID executor_copy_array_result_id;
jmethodID executor_copy_map_result_id;
int argument_size = 0;
std::string serialize_data;
};

View File

@ -311,8 +311,8 @@ Status JniConnector::_fill_column(TableMetaAddress& address, ColumnPtr& doris_co
}
MutableColumnPtr data_column;
if (doris_column->is_nullable()) {
auto* nullable_column = reinterpret_cast<vectorized::ColumnNullable*>(
(*std::move(doris_column)).mutate().get());
auto* nullable_column =
reinterpret_cast<vectorized::ColumnNullable*>(doris_column->assume_mutable().get());
data_column = nullable_column->get_nested_column_ptr();
NullMap& null_map = nullable_column->get_null_map_data();
size_t origin_size = null_map.size();

View File

@ -138,9 +138,9 @@ Status JavaFunctionCall::execute_impl(FunctionContext* context, Block& block,
jobject output_map = JniUtil::convert_to_java_map(env, output_params);
long output_address = env->CallLongMethod(jni_ctx->executor, jni_env->executor_evaluate_id,
input_map, output_map);
RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env));
env->DeleteLocalRef(input_map);
env->DeleteLocalRef(output_map);
RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env));
return JniConnector::fill_block(&block, {result}, output_address);
}

View File

@ -147,9 +147,10 @@ private:
return;
}
env->CallNonvirtualVoidMethodA(executor, executor_cl_, executor_close_id_, NULL);
env->DeleteGlobalRef(executor);
env->DeleteGlobalRef(executor_cl_);
Status s = JniUtil::GetJniExceptionMsg(env);
if (!s.ok()) LOG(WARNING) << s;
env->DeleteGlobalRef(executor);
is_closed = true;
}
};

View File

@ -38,7 +38,6 @@ import java.net.URL;
import java.net.URLClassLoader;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.DateTimeException;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.Set;
@ -150,6 +149,7 @@ public class UdfUtils {
* Sets the argument types of a Java UDF or UDAF. Returns true if the argument types specified
* in the UDF are compatible with the argument types of the evaluate() function loaded
* from the associated JAR file.
*
* @throws InternalException
*/
public static Pair<Boolean, JavaUdfDataType[]> setArgTypes(Type[] parameterTypes, Class<?>[] udfArgTypes,
@ -193,191 +193,6 @@ public class UdfUtils {
return Pair.of(true, inputArgTypes);
}
/**
 * Unpacks a 64-bit packed date-time value and builds an object of the requested class.
 * The fields are read from these bit positions: year from bit 46 upwards, month 4 bits
 * at 42, day 5 bits at 37, hour 5 bits at 32, minute 6 bits at 26, second 6 bits at 20.
 *
 * @param date packed value
 * @param clz  target class: java.time.LocalDateTime, joda DateTime or joda LocalDateTime
 * @return the converted object, or null when {@code clz} is none of the supported classes
 */
public static Object convertDateTimeV2ToJavaDateTime(long date, Class clz) {
    final int year = (int) (date >> 46);
    final int month = ((int) (date >> 42)) & 0XF;
    final int day = ((int) (date >> 37)) & 0X1F;
    final int hour = (int) ((date >> 32) & 0X1F);
    final int minute = (int) ((date >> 26) & 0X3F);
    final int second = (int) ((date >> 20) & 0X3F);
    // The bits below position 20 are not read here.
    if (LocalDateTime.class.equals(clz)) {
        return convertToLocalDateTime(year, month, day, hour, minute, second);
    }
    if (org.joda.time.DateTime.class.equals(clz)) {
        return convertToJodaDateTime(year, month, day, hour, minute, second);
    }
    if (org.joda.time.LocalDateTime.class.equals(clz)) {
        return convertToJodaLocalDateTime(year, month, day, hour, minute, second);
    }
    return null;
}
/**
 * Unpacks a 64-bit date-time value received from the backend into year, month, day,
 * hour, minute and second, and builds an object of the requested class.
 *
 * @param date packed value
 * @param clz  target class: java.time.LocalDateTime, joda DateTime or joda LocalDateTime
 * @return the converted object, or null when {@code clz} is none of the supported classes
 */
public static Object convertDateTimeToJavaDateTime(long date, Class clz) {
    final int year = (int) (date >> 48);
    final int month = ((int) (date >> 40)) & 0XFF;
    final int day = ((int) (date >> 32)) & 0XFF;
    // Time-of-day fields are taken from the remainder of the value.
    final int lowBits = (int) (date % (1 << 31));
    final int hour = lowBits >> 24;
    final int minute = (lowBits >> 16) & 0XFF;
    final int second = (lowBits % (1 << 16)) >> 4;
    // The remaining low bits are not read here.
    if (LocalDateTime.class.equals(clz)) {
        return convertToLocalDateTime(year, month, day, hour, minute, second);
    }
    if (org.joda.time.DateTime.class.equals(clz)) {
        return convertToJodaDateTime(year, month, day, hour, minute, second);
    }
    if (org.joda.time.LocalDateTime.class.equals(clz)) {
        return convertToJodaLocalDateTime(year, month, day, hour, minute, second);
    }
    return null;
}
/**
 * Unpacks a 32-bit packed date (day in the low 5 bits, month in the next 4 bits,
 * year above bit 9) and builds an object of the requested class.
 *
 * @param date packed value
 * @param clz  target class: java.time.LocalDate, java.util.Date or joda LocalDate
 * @return the converted object, or null when {@code clz} is none of the supported classes
 */
public static Object convertDateV2ToJavaDate(int date, Class clz) {
    final int day = date & 0X1F;
    final int month = (date >> 5) & 0XF;
    final int year = date >> 9;
    if (LocalDate.class.equals(clz)) {
        return convertToLocalDate(year, month, day);
    }
    if (java.util.Date.class.equals(clz)) {
        return convertToJavaDate(year, month, day);
    }
    if (org.joda.time.LocalDate.class.equals(clz)) {
        return convertToJodaDate(year, month, day);
    }
    return null;
}
/**
 * Builds a java.time.LocalDateTime from its components.
 *
 * @return the date-time, or null (after logging a warning) when the components are rejected
 *         with a DateTimeException
 */
public static LocalDateTime convertToLocalDateTime(int year, int month, int day,
        int hour, int minute, int second) {
    try {
        return LocalDateTime.of(year, month, day, hour, minute, second);
    } catch (DateTimeException e) {
        LOG.warn("Error occurs when parsing date time value: {}", e);
        return null;
    }
}
/**
 * Builds a joda DateTime from its components.
 *
 * @return the date-time, or null when construction throws
 */
public static org.joda.time.DateTime convertToJodaDateTime(int year, int month, int day,
        int hour, int minute, int second) {
    org.joda.time.DateTime result;
    try {
        result = new org.joda.time.DateTime(year, month, day, hour, minute, second);
    } catch (Exception e) {
        result = null;
    }
    return result;
}
/**
 * Builds a joda LocalDateTime from its components.
 *
 * @return the date-time, or null when construction throws
 */
public static org.joda.time.LocalDateTime convertToJodaLocalDateTime(int year, int month, int day,
        int hour, int minute, int second) {
    org.joda.time.LocalDateTime result;
    try {
        result = new org.joda.time.LocalDateTime(year, month, day, hour, minute, second);
    } catch (Exception e) {
        result = null;
    }
    return result;
}
/**
 * Unpacks a 64-bit packed date (year above bit 48, month 8 bits at 40, day 8 bits at 32)
 * and builds an object of the requested class.
 *
 * @param date packed value
 * @param clz  target class: java.time.LocalDate, java.util.Date or joda LocalDate
 * @return the converted object, or null when {@code clz} is none of the supported classes
 */
public static Object convertDateToJavaDate(long date, Class clz) {
    final int year = (int) (date >> 48);
    final int month = ((int) (date >> 40)) & 0XFF;
    final int day = ((int) (date >> 32)) & 0XFF;
    if (LocalDate.class.equals(clz)) {
        return convertToLocalDate(year, month, day);
    }
    if (java.util.Date.class.equals(clz)) {
        return convertToJavaDate(year, month, day);
    }
    if (org.joda.time.LocalDate.class.equals(clz)) {
        return convertToJodaDate(year, month, day);
    }
    return null;
}
/**
 * Builds a java.time.LocalDate from year, month and day.
 *
 * @return the date, or null (after logging a warning) when the components are rejected
 *         with a DateTimeException
 */
public static LocalDate convertToLocalDate(int year, int month, int day) {
    try {
        return LocalDate.of(year, month, day);
    } catch (DateTimeException e) {
        LOG.warn("Error occurs when parsing date value: {}", e);
        return null;
    }
}
/**
 * Builds a joda LocalDate from year, month and day.
 *
 * @return the date, or null when construction throws
 */
public static org.joda.time.LocalDate convertToJodaDate(int year, int month, int day) {
    org.joda.time.LocalDate result;
    try {
        result = new org.joda.time.LocalDate(year, month, day);
    } catch (Exception e) {
        result = null;
    }
    return result;
}
/**
 * Builds a java.util.Date from a calendar year, a 1-based month and a day of month.
 * The year and month are shifted to the offsets the Date constructor expects.
 *
 * @return the date, or null when construction throws
 */
public static java.util.Date convertToJavaDate(int year, int month, int day) {
    java.util.Date result;
    try {
        result = new java.util.Date(year - 1900, month - 1, day);
    } catch (Exception e) {
        result = null;
    }
    return result;
}
/**
 * Reads year, month, day, hour, minute and second from a date-time object and packs
 * them into the 64-bit value returned to the backend.
 *
 * @param obj the date-time object
 * @param clz class of {@code obj}: java.time.LocalDateTime, joda DateTime or joda LocalDateTime
 * @return the packed value, or 0 when {@code clz} is none of the supported classes
 */
public static long convertToDateTime(Object obj, Class clz) {
    if (LocalDateTime.class.equals(clz)) {
        LocalDateTime javaTime = (LocalDateTime) obj;
        return convertToDateTime(javaTime.getYear(), javaTime.getMonthValue(), javaTime.getDayOfMonth(),
                javaTime.getHour(), javaTime.getMinute(), javaTime.getSecond(), false);
    }
    if (org.joda.time.DateTime.class.equals(clz)) {
        org.joda.time.DateTime jodaTime = (org.joda.time.DateTime) obj;
        return convertToDateTime(jodaTime.getYear(), jodaTime.getMonthOfYear(), jodaTime.getDayOfMonth(),
                jodaTime.getHourOfDay(), jodaTime.getMinuteOfHour(), jodaTime.getSecondOfMinute(), false);
    }
    if (org.joda.time.LocalDateTime.class.equals(clz)) {
        org.joda.time.LocalDateTime jodaLocal = (org.joda.time.LocalDateTime) obj;
        return convertToDateTime(jodaLocal.getYear(), jodaLocal.getMonthOfYear(), jodaLocal.getDayOfMonth(),
                jodaLocal.getHourOfDay(), jodaLocal.getMinuteOfHour(), jodaLocal.getSecondOfMinute(), false);
    }
    return 0;
}
/**
 * Reads year, month and day from a date object and packs them into the 64-bit value
 * returned to the backend (time-of-day fields are passed as 0, isDate as true).
 *
 * @param obj the date object
 * @param clz class of {@code obj}: java.time.LocalDate, java.util.Date or joda LocalDate
 * @return the packed value, or 0 when {@code clz} is none of the supported classes
 */
public static long convertToDate(Object obj, Class clz) {
    if (LocalDate.class.equals(clz)) {
        LocalDate date = (LocalDate) obj;
        return convertToDateTime(date.getYear(), date.getMonthValue(), date.getDayOfMonth(), 0,
                0, 0, true);
    } else if (java.util.Date.class.equals(clz)) {
        java.util.Date date = (java.util.Date) obj;
        // getDate() is the day of month; getDay() would return the day of week (0-6).
        return convertToDateTime(date.getYear() + 1900, date.getMonth() + 1, date.getDate(), 0,
                0, 0, true);
    } else if (org.joda.time.LocalDate.class.equals(clz)) {
        org.joda.time.LocalDate date = (org.joda.time.LocalDate) obj;
        return convertToDateTime(date.getYear(), date.getMonthOfYear(), date.getDayOfMonth(), 0,
                0, 0, true);
    } else {
        return 0;
    }
}
public static long convertToDateTime(int year, int month, int day, int hour, int minute, int second,
boolean isDate) {
long time = 0;
@ -394,54 +209,16 @@ public class UdfUtils {
return time;
}
/**
 * Packs date-time components into a 64-bit value: second at bit 20, minute at 26,
 * hour at 32, day at 37, month at 42 and year at 46.
 */
public static long convertToDateTimeV2(int year, int month, int day, int hour, int minute, int second) {
    long packed = ((long) year) << 46;
    packed |= ((long) month) << 42;
    packed |= ((long) day) << 37;
    packed |= ((long) hour) << 32;
    packed |= ((long) minute) << 26;
    packed |= ((long) second) << 20;
    return packed;
}
/**
 * Packs date-time components into a 64-bit value: microsecond in the low bits,
 * second at bit 20, minute at 26, hour at 32, day at 37, month at 42 and year at 46.
 */
public static long convertToDateTimeV2(
        int year, int month, int day, int hour, int minute, int second, int microsecond) {
    long packed = (long) microsecond;
    packed |= ((long) second) << 20;
    packed |= ((long) minute) << 26;
    packed |= ((long) hour) << 32;
    packed |= ((long) day) << 37;
    packed |= ((long) month) << 42;
    packed |= ((long) year) << 46;
    return packed;
}
/**
 * Reads the components of a date-time object and packs them, including the
 * sub-second part as microseconds, into the 64-bit DateTimeV2 value.
 *
 * @param obj the date-time object
 * @param clz class of {@code obj}: java.time.LocalDateTime, joda DateTime or joda LocalDateTime
 * @return the packed value, or 0 when {@code clz} is none of the supported classes
 */
public static long convertToDateTimeV2(Object obj, Class clz) {
    if (LocalDateTime.class.equals(clz)) {
        LocalDateTime date = (LocalDateTime) obj;
        // Keep the sub-second part, as the joda branches below do; nanoseconds -> microseconds.
        return convertToDateTimeV2(date.getYear(), date.getMonthValue(), date.getDayOfMonth(), date.getHour(),
                date.getMinute(), date.getSecond(), date.getNano() / 1000);
    } else if (org.joda.time.DateTime.class.equals(clz)) {
        org.joda.time.DateTime date = (org.joda.time.DateTime) obj;
        return convertToDateTimeV2(date.getYear(), date.getMonthOfYear(), date.getDayOfMonth(), date.getHourOfDay(),
                date.getMinuteOfHour(), date.getSecondOfMinute(), date.getMillisOfSecond() * 1000);
    } else if (org.joda.time.LocalDateTime.class.equals(clz)) {
        org.joda.time.LocalDateTime date = (org.joda.time.LocalDateTime) obj;
        return convertToDateTimeV2(date.getYear(), date.getMonthOfYear(), date.getDayOfMonth(), date.getHourOfDay(),
                date.getMinuteOfHour(), date.getSecondOfMinute(), date.getMillisOfSecond() * 1000);
    } else {
        return 0;
    }
}
/**
 * Packs a date into a 32-bit value: day in the low bits, month at bit 5, year at bit 9.
 */
public static int convertToDateV2(int year, int month, int day) {
    long packed = day;
    packed |= ((long) month) << 5;
    packed |= ((long) year) << 9;
    return (int) packed;
}
/**
 * Reads year, month and day from a date object and packs them into the 32-bit DateV2 value.
 *
 * @param obj the date object
 * @param clz class of {@code obj}: java.time.LocalDate, java.util.Date or joda LocalDate
 * @return the packed value, or 0 when {@code clz} is none of the supported classes
 */
public static int convertToDateV2(Object obj, Class clz) {
    if (LocalDate.class.equals(clz)) {
        LocalDate date = (LocalDate) obj;
        return convertToDateV2(date.getYear(), date.getMonthValue(), date.getDayOfMonth());
    } else if (java.util.Date.class.equals(clz)) {
        java.util.Date date = (java.util.Date) obj;
        // java.util.Date reports the year minus 1900 and a 0-based month; getDay() is the
        // day of week, so the day of month must come from getDate().
        return convertToDateV2(date.getYear() + 1900, date.getMonth() + 1, date.getDate());
    } else if (org.joda.time.LocalDate.class.equals(clz)) {
        org.joda.time.LocalDate date = (org.joda.time.LocalDate) obj;
        // The second argument is the month, not the day of month.
        return convertToDateV2(date.getYear(), date.getMonthOfYear(), date.getDayOfMonth());
    } else {
        return 0;
    }
}
/**
* Change the order of the bytes, Because JVM is Big-Endian , x86 is Little-Endian.
*/

View File

@ -171,6 +171,11 @@ public class ColumnType {
return type == Type.DATETIMEV2;
}
/**
 * @return true when this column type is BOOLEAN, BYTE, TINYINT, SMALLINT, INT, BIGINT,
 *         FLOAT or DOUBLE
 */
public boolean isPrimitive() {
    final boolean integral = type == Type.BYTE || type == Type.TINYINT || type == Type.SMALLINT
            || type == Type.INT || type == Type.BIGINT;
    final boolean floating = type == Type.FLOAT || type == Type.DOUBLE;
    return type == Type.BOOLEAN || integral || floating;
}
/**
 * @return the {@code type} field of this column type
 */
public Type getType() {
    return this.type;
}

View File

@ -1301,7 +1301,11 @@ public class VectorColumn {
}
}
private Object[] newObjectContainerArray(ColumnType.Type type, int size) {
/**
 * Creates an object array of the given size for this column's own type by delegating
 * to {@code newObjectContainerArray(ColumnType.Type, int)}.
 */
public Object[] newObjectContainerArray(int size) {
    final ColumnType.Type ownType = columnType.getType();
    return newObjectContainerArray(ownType, size);
}
public Object[] newObjectContainerArray(ColumnType.Type type, int size) {
switch (type) {
case BOOLEAN:
return new Boolean[size];

View File

@ -138,13 +138,13 @@ public class VectorTable {
* @param converters A map of converters. Convert the column values if the type is not defined in ColumnType.
* The map key is the field ID in VectorTable.
*/
public Object[][] getMaterializedData(Map<Integer, ColumnValueConverter> converters) {
public Object[][] getMaterializedData(int start, int end, Map<Integer, ColumnValueConverter> converters) {
if (columns.length == 0) {
return new Object[0][0];
}
Object[][] data = new Object[columns.length][];
for (int j = 0; j < columns.length; ++j) {
Object[] columnData = columns[j].getObjectColumn(0, columns[j].numRows());
Object[] columnData = columns[j].getObjectColumn(start, end);
if (converters.containsKey(j)) {
data[j] = converters.get(j).convert(columnData);
} else {
@ -154,6 +154,10 @@ public class VectorTable {
return data;
}
/**
 * Materializes all rows, i.e. the range from 0 to {@code getNumRows()}, applying the
 * given converters.
 */
public Object[][] getMaterializedData(Map<Integer, ColumnValueConverter> converters) {
    final int rowCount = getNumRows();
    return getMaterializedData(0, rowCount, converters);
}
/**
 * Materializes the table with an empty converter map.
 */
public Object[][] getMaterializedData() {
    final Map<Integer, ColumnValueConverter> noConverters = Collections.emptyMap();
    return getMaterializedData(noConverters);
}
@ -166,6 +170,10 @@ public class VectorTable {
return columns[fieldId];
}
/**
 * @param fieldId index into the column-type array
 * @return the column type stored at {@code fieldId}
 */
public ColumnType getColumnType(int fieldId) {
    return this.columnTypes[fieldId];
}
/**
 * @return the table's column-type array itself (not a copy)
 */
public ColumnType[] getColumnTypes() {
    return this.columnTypes;
}

View File

@ -17,12 +17,14 @@
package org.apache.doris.udf;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Pair;
import org.apache.doris.common.exception.UdfRuntimeException;
import org.apache.doris.common.jni.utils.JavaUdfDataType;
import org.apache.doris.common.jni.utils.OffHeap;
import org.apache.doris.common.jni.utils.UdfUtils;
import org.apache.doris.common.jni.vec.ColumnValueConverter;
import org.apache.doris.common.jni.vec.VectorTable;
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
import com.esotericsoftware.reflectasm.MethodAccess;
@ -34,12 +36,13 @@ import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.lang.reflect.Array;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
/**
* udaf executor.
@ -52,6 +55,7 @@ public class UdafExecutor extends BaseExecutor {
private HashMap<Long, Object> stateObjMap;
private Class retClass;
private int addIndex;
private VectorTable outputTable = null;
/**
* Constructor to create an object.
@ -65,102 +69,90 @@ public class UdafExecutor extends BaseExecutor {
*/
@Override
public void close() {
    allMethods = null;
    if (outputTable != null) {
        outputTable.close();
        // Forget the table once it has been released, so that calling close()
        // again cannot close the same table a second time.
        outputTable = null;
    }
    super.close();
}
public Object[] convertBasicArguments(int argIdx, boolean isNullable, int rowStart, int rowEnd, long nullMapAddr,
long columnAddr, long strOffsetAddr) {
return convertBasicArg(false, argIdx, isNullable, rowStart, rowEnd, nullMapAddr, columnAddr, strOffsetAddr);
/**
 * Builds the converter map for the first numColumns input columns.
 * A column gets an entry only when getInputConverter returns a non-null converter.
 *
 * @param numColumns number of input columns to inspect
 * @return converters keyed by column index
 */
private Map<Integer, ColumnValueConverter> getInputConverters(int numColumns) {
    final Map<Integer, ColumnValueConverter> byColumn = new HashMap<>();
    for (int col = 0; col < numColumns; ++col) {
        // argClass is read one slot ahead of argTypes; presumably slot 0 belongs to
        // the state argument of the add method -- confirm against init().
        final ColumnValueConverter candidate =
                getInputConverter(argTypes[col].getPrimitiveType(), argClass[col + 1]);
        if (candidate == null) {
            continue;
        }
        byColumn.put(col, candidate);
    }
    return byColumn;
}
public Object[] convertArrayArguments(int argIdx, boolean isNullable, int rowStart, int rowEnd, long nullMapAddr,
long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) {
return convertArrayArg(argIdx, isNullable, rowStart, rowEnd, nullMapAddr, offsetsAddr, nestedNullMapAddr,
dataAddr, strOffsetAddr);
/**
 * Resolves the converter for the UDAF result from the declared return type
 * and the Java class stored in retClass.
 *
 * @return whatever the two-argument getOutputConverter overload returns
 */
private ColumnValueConverter getOutputConverter() {
    final Class javaResultClass = retClass;
    return getOutputConverter(retType.getPrimitiveType(), javaResultClass);
}
/**
 * Converts one map-typed argument column into Java objects.
 * The key side and the value side are each converted with convertMapArg,
 * sharing the outer null map and offsets, and are then combined by buildHashMap.
 *
 * @param argIdx index of the argument in argTypes
 * @return the array produced by buildHashMap for rows [rowStart, rowEnd)
 */
public Object[] convertMapArguments(int argIdx, boolean isNullable, int rowStart, int rowEnd, long nullMapAddr,
        long offsetsAddr, long keyNestedNullMapAddr, long keyDataAddr, long keyStrOffsetAddr,
        long valueNestedNullMapAddr, long valueDataAddr, long valueStrOffsetAddr) {
    final PrimitiveType keyKind = argTypes[argIdx].getKeyType().getPrimitiveType();
    final PrimitiveType valueKind = argTypes[argIdx].getValueType().getPrimitiveType();

    // Key side: key addresses together with the key scale.
    final Object[] keys = convertMapArg(keyKind, argIdx, isNullable, rowStart, rowEnd, nullMapAddr,
            offsetsAddr, keyNestedNullMapAddr, keyDataAddr, keyStrOffsetAddr,
            argTypes[argIdx].getKeyScale());

    // Value side: value addresses together with the value scale.
    final Object[] values = convertMapArg(valueKind, argIdx, isNullable, rowStart, rowEnd, nullMapAddr,
            offsetsAddr, valueNestedNullMapAddr, valueDataAddr, valueStrOffsetAddr,
            argTypes[argIdx].getValueScale());

    return buildHashMap(keyKind, valueKind, keys, values);
}
public void addBatch(boolean isSinglePlace, int rowStart, int rowEnd, long placeAddr, int offset, Object[] column)
throws UdfRuntimeException {
if (isSinglePlace) {
addBatchSingle(rowStart, rowEnd, placeAddr, column);
} else {
addBatchPlaces(rowStart, rowEnd, placeAddr, offset, column);
/**
 * Feeds rows [rowStart, rowEnd) of the input batch into the UDAF's add method.
 * The batch is read through a VectorTable built from inputParams, then routed
 * to addBatchSingle or addBatchPlaces depending on isSinglePlace.
 *
 * @param isSinglePlace true to aggregate all rows into the state at placeAddr
 * @param placeAddr     state key (single place) or base address handed to addBatchPlaces
 * @param offset        forwarded to addBatchPlaces only
 * @param inputParams   parameters describing the readable input table
 * @throws UdfRuntimeException wrapping any exception raised while reading or aggregating
 */
public void addBatch(boolean isSinglePlace, int rowStart, int rowEnd, long placeAddr, int offset,
        Map<String, String> inputParams) throws UdfRuntimeException {
    try {
        final VectorTable batch = VectorTable.createReadableTable(inputParams);
        final Map<Integer, ColumnValueConverter> converters = getInputConverters(batch.getNumColumns());
        final Object[][] columns = batch.getMaterializedData(rowStart, rowEnd, converters);
        if (!isSinglePlace) {
            addBatchPlaces(rowStart, rowEnd, placeAddr, offset, columns);
            return;
        }
        addBatchSingle(rowStart, rowEnd, placeAddr, columns);
    } catch (Exception e) {
        LOG.warn("evaluate exception: " + debugString(), e);
        throw new UdfRuntimeException("UDAF failed to evaluate", e);
    }
}
public void addBatchSingle(int rowStart, int rowEnd, long placeAddr, Object[] column) throws UdfRuntimeException {
try {
Long curPlace = placeAddr;
Object[] inputArgs = new Object[argTypes.length + 1];
/**
 * Adds every row of the materialized batch to the single aggregate state keyed by placeAddr.
 * The state is created and registered in stateObjMap when no state exists for that key yet.
 *
 * @param rowStart  first row of the batch (only rowEnd - rowStart is used here)
 * @param rowEnd    end row of the batch
 * @param placeAddr key of the aggregate state
 * @param inputs    column-major values; inputs[c][r] is row r of column c, with r starting at 0
 */
public void addBatchSingle(int rowStart, int rowEnd, long placeAddr, Object[][] inputs) throws UdfRuntimeException {
    final Long placeKey = placeAddr;
    // Slot 0 carries the state, slots 1..n the row's column values.
    final Object[] callArgs = new Object[argTypes.length + 1];

    Object aggState = stateObjMap.get(placeKey);
    if (aggState == null) {
        aggState = createAggState();
        stateObjMap.put(placeKey, aggState);
    }
    callArgs[0] = aggState;

    final int columnCount = inputs.length;
    final int rowCount = rowEnd - rowStart;
    for (int row = 0; row < rowCount; ++row) {
        for (int col = 0; col < columnCount; ++col) {
            callArgs[col + 1] = inputs[col][row];
        }
        methodAccess.invoke(udf, addIndex, callArgs);
    }
}
public void addBatchPlaces(int rowStart, int rowEnd, long placeAddr, int offset, Object[][] inputs)
throws UdfRuntimeException {
int numColumns = inputs.length;
int numRows = rowEnd - rowStart;
Object[] placeState = new Object[numRows];
for (int row = rowStart; row < rowEnd; ++row) {
Long curPlace = OffHeap.UNSAFE.getLong(null, placeAddr + (8L * row)) + offset;
Object state = stateObjMap.get(curPlace);
if (state != null) {
inputArgs[0] = state;
placeState[row - rowStart] = state;
} else {
Object newState = createAggState();
stateObjMap.put(curPlace, newState);
inputArgs[0] = newState;
placeState[row - rowStart] = newState;
}
Object[][] inputs = (Object[][]) column;
for (int i = 0; i < (rowEnd - rowStart); ++i) {
for (int j = 0; j < column.length; ++j) {
inputArgs[j + 1] = inputs[j][i];
}
methodAccess.invoke(udf, addIndex, inputArgs);
}
} catch (Exception e) {
LOG.info("evaluate exception debug: " + debugString());
LOG.info("invoke add function meet some error: " + e.getCause().toString());
throw new UdfRuntimeException("UDAF failed to addBatchSingle: ", e);
}
}
// split into two for loops
public void addBatchPlaces(int rowStart, int rowEnd, long placeAddr, int offset, Object[] column)
throws UdfRuntimeException {
try {
Object[][] inputs = (Object[][]) column;
ArrayList<Object> placeState = new ArrayList<>(rowEnd - rowStart);
for (int row = rowStart; row < rowEnd; ++row) {
Long curPlace = UdfUtils.UNSAFE.getLong(null, placeAddr + (8L * row)) + offset;
Object state = stateObjMap.get(curPlace);
if (state != null) {
placeState.add(state);
} else {
Object newState = createAggState();
stateObjMap.put(curPlace, newState);
placeState.add(newState);
}
Object[] inputArgs = new Object[argTypes.length + 1];
for (int row = 0; row < numRows; ++row) {
inputArgs[0] = placeState[row];
for (int j = 0; j < numColumns; ++j) {
inputArgs[j + 1] = inputs[j][row];
}
// split into two for loops
Object[] inputArgs = new Object[argTypes.length + 1];
for (int row = 0; row < (rowEnd - rowStart); ++row) {
inputArgs[0] = placeState.get(row);
for (int j = 0; j < column.length; ++j) {
inputArgs[j + 1] = inputs[j][row];
}
methodAccess.invoke(udf, addIndex, inputArgs);
}
} catch (Exception e) {
LOG.info("evaluate exception debug: " + debugString());
LOG.info("invoke add function meet some error: " + Arrays.toString(e.getStackTrace()));
throw new UdfRuntimeException("UDAF failed to addBatchPlaces: ", e);
methodAccess.invoke(udf, addIndex, inputArgs);
}
}
@ -171,7 +163,7 @@ public class UdafExecutor extends BaseExecutor {
try {
return allMethods.get(UDAF_CREATE_FUNCTION).invoke(udf, null);
} catch (Exception e) {
LOG.warn("invoke createAggState function meet some error: " + e.getCause().toString());
LOG.warn("invoke createAggState function meet some error: ", e);
throw new UdfRuntimeException("UDAF failed to create: ", e);
}
}
@ -186,7 +178,7 @@ public class UdafExecutor extends BaseExecutor {
}
stateObjMap.clear();
} catch (Exception e) {
LOG.warn("invoke destroy function meet some error: " + e.getCause().toString());
LOG.warn("invoke destroy function meet some error: ", e);
throw new UdfRuntimeException("UDAF failed to destroy: ", e);
}
}
@ -198,31 +190,31 @@ public class UdafExecutor extends BaseExecutor {
try {
Object[] args = new Object[2];
ByteArrayOutputStream baos = new ByteArrayOutputStream();
args[0] = stateObjMap.get((Long) place);
args[0] = stateObjMap.get(place);
args[1] = new DataOutputStream(baos);
allMethods.get(UDAF_SERIALIZE_FUNCTION).invoke(udf, args);
return baos.toByteArray();
} catch (Exception e) {
LOG.info("evaluate exception debug: " + debugString());
LOG.warn("invoke serialize function meet some error: " + e.getCause().toString());
LOG.warn("invoke serialize function meet some error: ", e);
throw new UdfRuntimeException("UDAF failed to serialize: ", e);
}
}
/*
/**
* invoke reset function and reset the state to init.
*/
public void reset(long place) throws UdfRuntimeException {
try {
Object[] args = new Object[1];
args[0] = stateObjMap.get((Long) place);
args[0] = stateObjMap.get(place);
if (args[0] == null) {
return;
}
allMethods.get(UDAF_RESET_FUNCTION).invoke(udf, args);
} catch (Exception e) {
LOG.info("evaluate exception debug: " + debugString());
LOG.warn("invoke reset function meet some error: " + e.getCause().toString());
LOG.warn("invoke reset function meet some error: ", e);
throw new UdfRuntimeException("UDAF failed to reset: ", e);
}
}
@ -251,7 +243,7 @@ public class UdafExecutor extends BaseExecutor {
allMethods.get(UDAF_MERGE_FUNCTION).invoke(udf, args);
} catch (Exception e) {
LOG.info("evaluate exception debug: " + debugString());
LOG.warn("invoke merge function meet some error: " + e.getCause().toString());
LOG.warn("invoke merge function meet some error: ", e);
throw new UdfRuntimeException("UDAF failed to merge: ", e);
}
}
@ -259,75 +251,32 @@ public class UdafExecutor extends BaseExecutor {
/**
* Invoke getValue to return the final result.
*/
public Object getValue(long place) throws UdfRuntimeException {
public long getValue(long place, Map<String, String> outputParams) throws UdfRuntimeException {
try {
if (outputTable != null) {
outputTable.close();
}
outputTable = VectorTable.createWritableTable(outputParams, 1);
if (stateObjMap.get(place) == null) {
stateObjMap.put(place, createAggState());
}
return allMethods.get(UDAF_RESULT_FUNCTION).invoke(udf, stateObjMap.get((Long) place));
Object value = allMethods.get(UDAF_RESULT_FUNCTION).invoke(udf, stateObjMap.get(place));
// If the return type is primitive, we can't cast the array of primitive type as array of Object,
// so we have to new its wrapped Object.
Object[] result = outputTable.getColumnType(0).isPrimitive()
? outputTable.getColumn(0).newObjectContainerArray(1)
: (Object[]) Array.newInstance(retClass, 1);
result[0] = value;
boolean isNullable = Boolean.parseBoolean(outputParams.getOrDefault("is_nullable", "true"));
outputTable.appendData(0, result, getOutputConverter(), isNullable);
return outputTable.getMetaAddress();
} catch (Exception e) {
LOG.info("evaluate exception debug: " + debugString());
LOG.warn("invoke getValue function meet some error: " + e.getCause().toString());
LOG.warn("invoke getValue function meet some error: ", e);
throw new UdfRuntimeException("UDAF failed to result", e);
}
}
/**
 * Writes one scalar result into the output buffers for the given row.
 * An outputNullMapPtr of -1 means the output has no null map.
 *
 * @param result           value to store; null marks the row as null
 * @param row              row index used as the offset into the null map
 * @param outputNullMapPtr address of the null map, or -1 when there is none
 * @throws UdfRuntimeException when result is null and there is no null map
 */
public void copyTupleBasicResult(Object result, int row, long outputNullMapPtr, long outputBufferBase,
        long charsAddress,
        long offsetsAddr) throws UdfRuntimeException {
    final boolean hasNullMap = outputNullMapPtr != -1;
    if (result != null) {
        try {
            if (hasNullMap) {
                UdfUtils.UNSAFE.putByte(outputNullMapPtr + row, (byte) 0);
            }
            copyTupleBasicResult(result, row, retClass, outputBufferBase, charsAddress,
                    offsetsAddr, retType);
        } catch (UdfRuntimeException e) {
            // A failed copy is only logged here; it is not propagated to the caller.
            LOG.info(e.toString());
        }
        return;
    }
    // Null result: it can only be recorded when a null map exists.
    if (!hasNullMap) {
        throw new UdfRuntimeException("UDAF failed to store null data to not null column");
    }
    UdfUtils.UNSAFE.putByte(outputNullMapPtr + row, (byte) 1);
}
/**
 * Writes one array result for the given row.
 * When nullMapAddr is positive the row is first marked not-null, then the
 * elements are copied by copyTupleArrayResultImpl using the return type's
 * item type and scale.
 */
public void copyTupleArrayResult(long hasPutElementNum, boolean isNullable, int row, Object result,
        long nullMapAddr,
        long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) throws UdfRuntimeException {
    final boolean hasNullMap = nullMapAddr > 0;
    if (hasNullMap) {
        UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 0);
    }
    copyTupleArrayResultImpl(
            hasPutElementNum, isNullable, row, result,
            nullMapAddr, offsetsAddr, nestedNullMapAddr, dataAddr, strOffsetAddr,
            retType.getItemType().getPrimitiveType(), retType.getScale());
}
/**
 * Writes one map result for the given row.
 * The map is split by buildArrayListFromHashMap into a key list and a value
 * list, and each list is copied with copyTupleArrayResultImpl into its own
 * nested buffers while sharing the outer null map and offsets.
 *
 * Each side is copied with its own scale: values with the value scale and
 * keys with the key scale, the same pairing convertMapArguments uses.
 *
 * @param result      the map returned by the UDAF for this row
 * @param nullMapAddr outer null map address; the row is marked not-null when positive
 */
public void copyTupleMapResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr,
        long offsetsAddr,
        long keyNsestedNullMapAddr, long keyDataAddr,
        long keyStrOffsetAddr,
        long valueNsestedNullMapAddr, long valueDataAddr, long valueStrOffsetAddr) throws UdfRuntimeException {
    if (nullMapAddr > 0) {
        UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 0);
    }
    PrimitiveType keyType = retType.getKeyType().getPrimitiveType();
    PrimitiveType valueType = retType.getValueType().getPrimitiveType();
    // Single-slot holders: buildArrayListFromHashMap fills slot 0 of each.
    Object[] keyCol = new Object[1];
    Object[] valueCol = new Object[1];
    Object[] resultArr = new Object[1];
    resultArr[0] = result;
    buildArrayListFromHashMap(resultArr, keyType, valueType, keyCol, valueCol);

    // Values are written before keys; both calls share the same offsets address.
    copyTupleArrayResultImpl(hasPutElementNum, isNullable, row,
            valueCol[0], nullMapAddr, offsetsAddr,
            valueNsestedNullMapAddr, valueDataAddr, valueStrOffsetAddr, valueType, retType.getValueScale());
    copyTupleArrayResultImpl(hasPutElementNum, isNullable, row, keyCol[0], nullMapAddr, offsetsAddr,
            keyNsestedNullMapAddr, keyDataAddr, keyStrOffsetAddr, keyType, retType.getKeyScale());
}
@Override
protected void init(TJavaUdfExecutorCtorParams request, String jarPath, Type funcRetType,
Type... parameterTypes) throws UdfRuntimeException {
@ -406,7 +355,9 @@ public class UdafExecutor extends BaseExecutor {
return;
}
StringBuilder sb = new StringBuilder();
sb.append("Unable to find evaluate function with the correct signature: ").append(className + ".evaluate(")
sb.append("Unable to find evaluate function with the correct signature: ")
.append(className)
.append(".evaluate(")
.append(Joiner.on(", ").join(parameterTypes)).append(")\n").append("UDF contains: \n ")
.append(Joiner.on("\n ").join(signatures));
throw new UdfRuntimeException(sb.toString());

View File

@ -25,7 +25,6 @@ import org.apache.doris.common.jni.utils.UdfUtils;
import org.apache.doris.common.jni.vec.ColumnValueConverter;
import org.apache.doris.common.jni.vec.VectorTable;
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
import org.apache.doris.thrift.TPrimitiveType;
import com.esotericsoftware.reflectasm.MethodAccess;
import com.google.common.base.Joiner;
@ -36,15 +35,11 @@ import java.lang.reflect.Array;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
public class UdfExecutor extends BaseExecutor {
// private static final java.util.logging.Logger LOG =
// Logger.getLogger(UdfExecutor.class);
public static final Logger LOG = Logger.getLogger(UdfExecutor.class);
// setup by init() and cleared by close()
private Method method;
@ -76,78 +71,6 @@ public class UdfExecutor extends BaseExecutor {
super.close();
}
/**
 * Picks the converter that turns materialized date/datetime column values into
 * the Java class the UDF parameter is declared with.
 *
 * The lambdas cast DATE/DATEV2 values to java.time.LocalDate and
 * DATETIME/DATETIMEV2 values to java.time.LocalDateTime. Null elements are left
 * as null in the converted array.
 *
 * @param primitiveType the column's primitive type
 * @param clz the parameter class declared by the UDF
 * @return a converter, or null when no conversion is performed (non-date types,
 *         or the parameter class already is LocalDate / LocalDateTime)
 * @throws RuntimeException if a date/datetime column is paired with an unsupported class
 */
private ColumnValueConverter getInputConverter(TPrimitiveType primitiveType, Class clz) {
switch (primitiveType) {
case DATE:
case DATEV2: {
if (java.util.Date.class.equals(clz)) {
return (Object[] columnData) -> {
// The array is created with the target component type, not as a plain Object[].
Object[] result = new java.util.Date[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
LocalDate v = (LocalDate) columnData[i];
// java.util.Date(year, month, date) expects the year minus 1900 and a zero-based month.
result[i] = new java.util.Date(v.getYear() - 1900, v.getMonthValue() - 1,
v.getDayOfMonth());
}
}
return result;
};
} else if (org.joda.time.LocalDate.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new org.joda.time.LocalDate[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
LocalDate v = (LocalDate) columnData[i];
result[i] = new org.joda.time.LocalDate(v.getYear(), v.getMonthValue(),
v.getDayOfMonth());
}
}
return result;
};
} else if (!LocalDate.class.equals(clz)) {
throw new RuntimeException("Unsupported date type: " + clz.getCanonicalName());
}
// clz is java.time.LocalDate: fall out of the switch and return null.
break;
}
case DATETIME:
case DATETIMEV2: {
if (org.joda.time.DateTime.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new org.joda.time.DateTime[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
LocalDateTime v = (LocalDateTime) columnData[i];
// getNano() / 1000000 turns nanoseconds into whole milliseconds; the remainder is dropped.
result[i] = new org.joda.time.DateTime(v.getYear(), v.getMonthValue(),
v.getDayOfMonth(), v.getHour(),
v.getMinute(), v.getSecond(), v.getNano() / 1000000);
}
}
return result;
};
} else if (org.joda.time.LocalDateTime.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new org.joda.time.LocalDateTime[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
LocalDateTime v = (LocalDateTime) columnData[i];
// Same nanosecond-to-millisecond truncation as the DateTime branch.
result[i] = new org.joda.time.LocalDateTime(v.getYear(), v.getMonthValue(),
v.getDayOfMonth(), v.getHour(),
v.getMinute(), v.getSecond(), v.getNano() / 1000000);
}
}
return result;
};
} else if (!LocalDateTime.class.equals(clz)) {
throw new RuntimeException("Unsupported date type: " + clz.getCanonicalName());
}
// clz is java.time.LocalDateTime: fall out of the switch and return null.
break;
}
default:
break;
}
// null tells the caller that this column has no converter.
return null;
}
private Map<Integer, ColumnValueConverter> getInputConverters(int numColumns) {
Map<Integer, ColumnValueConverter> converters = new HashMap<>();
for (int j = 0; j < numColumns; ++j) {
@ -160,74 +83,7 @@ public class UdfExecutor extends BaseExecutor {
}
private ColumnValueConverter getOutputConverter() {
Class clz = method.getReturnType();
switch (retType.getPrimitiveType()) {
case DATE:
case DATEV2: {
if (java.util.Date.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new LocalDate[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
java.util.Date v = (java.util.Date) columnData[i];
result[i] = LocalDate.of(v.getYear() + 1900, v.getMonth() + 1, v.getDate());
}
}
return result;
};
} else if (org.joda.time.LocalDate.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new LocalDate[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
org.joda.time.LocalDate v = (org.joda.time.LocalDate) columnData[i];
result[i] = LocalDate.of(v.getYear(), v.getMonthOfYear(), v.getDayOfMonth());
}
}
return result;
};
} else if (!LocalDate.class.equals(clz)) {
throw new RuntimeException("Unsupported date type: " + clz.getCanonicalName());
}
break;
}
case DATETIME:
case DATETIMEV2: {
if (org.joda.time.DateTime.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new LocalDateTime[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
org.joda.time.DateTime v = (org.joda.time.DateTime) columnData[i];
result[i] = LocalDateTime.of(v.getYear(), v.getMonthOfYear(), v.getDayOfMonth(),
v.getHourOfDay(),
v.getMinuteOfHour(), v.getSecondOfMinute(), v.getMillisOfSecond() * 1000000);
}
}
return result;
};
} else if (org.joda.time.LocalDateTime.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new LocalDateTime[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
org.joda.time.LocalDateTime v = (org.joda.time.LocalDateTime) columnData[i];
result[i] = LocalDateTime.of(v.getYear(), v.getMonthOfYear(), v.getDayOfMonth(),
v.getHourOfDay(),
v.getMinuteOfHour(), v.getSecondOfMinute(), v.getMillisOfSecond() * 1000000);
}
}
return result;
};
} else if (!LocalDateTime.class.equals(clz)) {
throw new RuntimeException("Unsupported date type: " + clz.getCanonicalName());
}
break;
}
default:
break;
}
return null;
return getOutputConverter(retType.getPrimitiveType(), method.getReturnType());
}
public long evaluate(Map<String, String> inputParams, Map<String, String> outputParams) throws UdfRuntimeException {
@ -235,7 +91,16 @@ public class UdfExecutor extends BaseExecutor {
VectorTable inputTable = VectorTable.createReadableTable(inputParams);
int numRows = inputTable.getNumRows();
int numColumns = inputTable.getNumColumns();
Object[] result = (Object[]) Array.newInstance(method.getReturnType(), numRows);
if (outputTable != null) {
outputTable.close();
}
outputTable = VectorTable.createWritableTable(outputParams, numRows);
// If the return type is primitive, we can't cast the array of primitive type as array of Object,
// so we have to new its wrapped Object.
Object[] result = outputTable.getColumnType(0).isPrimitive()
? outputTable.getColumn(0).newObjectContainerArray(numRows)
: (Object[]) Array.newInstance(method.getReturnType(), numRows);
Object[][] inputs = inputTable.getMaterializedData(getInputConverters(numColumns));
Object[] parameters = new Object[numColumns];
for (int i = 0; i < numRows; ++i) {
@ -244,13 +109,7 @@ public class UdfExecutor extends BaseExecutor {
}
result[i] = methodAccess.invoke(udf, evaluateIndex, parameters);
}
if (outputTable != null) {
outputTable.close();
}
boolean isNullable = Boolean.parseBoolean(outputParams.getOrDefault("is_nullable", "true"));
outputTable = VectorTable.createWritableTable(outputParams, numRows);
outputTable.appendData(0, result, getOutputConverter(), isNullable);
return outputTable.getMetaAddress();
} catch (Exception e) {
@ -336,7 +195,8 @@ public class UdfExecutor extends BaseExecutor {
StringBuilder sb = new StringBuilder();
sb.append("Unable to find evaluate function with the correct signature: ")
.append(className + ".evaluate(")
.append(className)
.append(".evaluate(")
.append(Joiner.on(", ").join(parameterTypes))
.append(")\n")
.append("UDF contains: \n ")

View File

@ -32,9 +32,9 @@ suite("test_javaudf_agg_map") {
CREATE TABLE IF NOT EXISTS db_agg_map(
`id` INT NULL COMMENT "",
`i` INT NULL COMMENT "",
`d` Double NULL COMMENT "",
`mii` Map<INT, INT> NULL COMMENT "",
`mid` Map<INT, Double> NULL COMMENT ""
`d` Double NULL COMMENT "",
`mii` Map<INT, INT> NULL COMMENT "",
`mid` Map<INT, Double> NULL COMMENT ""
) ENGINE=OLAP
DUPLICATE KEY(`id`)
DISTRIBUTED BY HASH(`id`) BUCKETS 1