[refactor](udaf) refactor UDAF function calls and support map type in return values (#22508)

This commit is contained in:
Mryange
2023-08-09 22:44:07 +08:00
committed by GitHub
parent 0d75a54d6c
commit 768088c95e
12 changed files with 763 additions and 1833 deletions
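
For context on the map-return support this commit adds, here is a minimal sketch (not part of the commit) of a Java UDAF whose getValue result is a Map; the class and member names are illustrative assumptions, and a complete UDAF would also define merge/serialize/destroy methods. The point is only the shape of the result: a Map object that the new copyTupleMapResult path below can copy back into the backend's map column.

import java.util.HashMap;
import java.util.Map;

// Hypothetical UDAF that counts occurrences per key.
public class CountByKeyUdaf {
    public static class State {
        Map<String, Integer> counts = new HashMap<>();
    }

    public State create() {
        return new State();
    }

    public void add(State state, String key) {
        // Ignore null keys for simplicity in this sketch.
        if (key != null) {
            state.counts.merge(key, 1, Integer::sum);
        }
    }

    public Map<String, Integer> getValue(State state) {
        return state.counts; // map-typed result handed back through JNI
    }
}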

View File

@ -25,6 +25,7 @@
#include "common/compiler_util.h"
#include "common/exception.h"
#include "common/logging.h"
#include "common/status.h"
#include "gutil/strings/substitute.h"
#include "runtime/user_function_cache.h"
@ -55,15 +56,7 @@ const char* UDAF_EXECUTOR_RESET_SIGNATURE = "(J)V";
struct AggregateJavaUdafData {
public:
AggregateJavaUdafData() = default;
AggregateJavaUdafData(int64_t num_args) {
argument_size = num_args;
output_value_buffer = std::make_unique<int64_t>(0);
output_null_value = std::make_unique<int64_t>(0);
output_offsets_ptr = std::make_unique<int64_t>(0);
output_intermediate_state_ptr = std::make_unique<int64_t>(0);
output_array_null_ptr = std::make_unique<int64_t>(0);
output_array_string_offsets_ptr = std::make_unique<int64_t>(0);
}
AggregateJavaUdafData(int64_t num_args) { argument_size = num_args; }
~AggregateJavaUdafData() {
JNIEnv* env;
@ -89,16 +82,6 @@ public:
ctor_params.__set_fn(fn);
ctor_params.__set_location(local_location);
ctor_params.__set_output_buffer_ptr((int64_t)output_value_buffer.get());
ctor_params.__set_output_null_ptr((int64_t)output_null_value.get());
ctor_params.__set_output_offsets_ptr((int64_t)output_offsets_ptr.get());
ctor_params.__set_output_intermediate_state_ptr(
(int64_t)output_intermediate_state_ptr.get());
ctor_params.__set_output_array_null_ptr((int64_t)output_array_null_ptr.get());
ctor_params.__set_output_array_string_offsets_ptr(
(int64_t)output_array_string_offsets_ptr.get());
jbyteArray ctor_params_bytes;
// Pushed frame will be popped when jni_frame goes out-of-scope.
@ -295,23 +278,27 @@ public:
to.insert_default();
JNIEnv* env = nullptr;
RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf get value function");
int64_t nullmap_address = 0;
if (result_type->is_nullable()) {
auto& nullable = assert_cast<ColumnNullable&>(to);
*output_null_value =
nullmap_address =
reinterpret_cast<int64_t>(nullable.get_null_map_column().get_raw_data().data);
auto& data_col = nullable.get_nested_column();
RETURN_IF_ERROR(get_result(to, result_type, place, env, data_col));
RETURN_IF_ERROR(get_result(to, result_type, place, env, data_col, nullmap_address));
} else {
*output_null_value = -1;
nullmap_address = -1;
auto& data_col = to;
RETURN_IF_ERROR(get_result(to, result_type, place, env, data_col));
RETURN_IF_ERROR(get_result(to, result_type, place, env, data_col, nullmap_address));
}
return JniUtil::GetJniExceptionMsg(env);
}
private:
Status get_result(IColumn& to, const DataTypePtr& result_type, int64_t place, JNIEnv* env,
IColumn& data_col) const {
Status get_result(IColumn& to, const DataTypePtr& return_type, int64_t place, JNIEnv* env,
IColumn& data_col, int64_t nullmap_address) const {
jobject result_obj = env->CallNonvirtualObjectMethod(executor_obj, executor_cl,
executor_get_value_id, place);
bool result_nullable = return_type->is_nullable();
if (data_col.is_column_string()) {
const ColumnString* str_col = check_and_get_column<ColumnString>(data_col);
ColumnString::Chars& chars = const_cast<ColumnString::Chars&>(str_col->get_chars());
@ -320,109 +307,119 @@ private:
int increase_buffer_size = 0;
int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size);
chars.resize(buffer_size);
*output_value_buffer = reinterpret_cast<int64_t>(chars.data());
*output_offsets_ptr = reinterpret_cast<int64_t>(offsets.data());
*output_intermediate_state_ptr = chars.size();
jboolean res = env->CallNonvirtualBooleanMethod(
executor_obj, executor_cl, executor_result_id, to.size() - 1, place);
while (res != JNI_TRUE) {
RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env));
increase_buffer_size++;
buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size);
try {
chars.resize(buffer_size);
} catch (std::bad_alloc const& e) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR,
"memory allocate failed in column string, "
"buffer:{},size:{},reason:{}",
increase_buffer_size, buffer_size, e.what());
}
*output_value_buffer = reinterpret_cast<int64_t>(chars.data());
*output_intermediate_state_ptr = chars.size();
res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl,
executor_result_id, to.size() - 1, place);
}
env->CallNonvirtualVoidMethod(
executor_obj, executor_cl, executor_copy_basic_result_id, result_obj,
to.size() - 1, nullmap_address, reinterpret_cast<int64_t>(chars.data()),
reinterpret_cast<int64_t>(&chars), reinterpret_cast<int64_t>(offsets.data()));
} else if (data_col.is_numeric() || data_col.is_column_decimal()) {
*output_value_buffer = reinterpret_cast<int64_t>(data_col.get_raw_data().data);
env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id,
to.size() - 1, place);
env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_copy_basic_result_id,
result_obj, to.size() - 1, nullmap_address,
reinterpret_cast<int64_t>(data_col.get_raw_data().data),
0, 0);
} else if (data_col.is_column_array()) {
ColumnArray& array_col = assert_cast<ColumnArray&>(data_col);
jclass arraylist_class = env->FindClass("Ljava/util/ArrayList;");
ColumnArray* array_col = assert_cast<ColumnArray*>(&data_col);
ColumnNullable& array_nested_nullable =
assert_cast<ColumnNullable&>(array_col.get_data());
assert_cast<ColumnNullable&>(array_col->get_data());
auto data_column_null_map = array_nested_nullable.get_null_map_column_ptr();
auto data_column = array_nested_nullable.get_nested_column_ptr();
auto& offset_column = array_col.get_offsets_column();
int increase_buffer_size = 0;
int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size);
*output_offsets_ptr = reinterpret_cast<int64_t>(offset_column.get_raw_data().data);
data_column_null_map->resize(buffer_size);
auto& offset_column = array_col->get_offsets_column();
auto offset_address = reinterpret_cast<int64_t>(offset_column.get_raw_data().data);
auto& null_map_data =
assert_cast<ColumnVector<UInt8>*>(data_column_null_map.get())->get_data();
*output_array_null_ptr = reinterpret_cast<int64_t>(null_map_data.data());
*output_intermediate_state_ptr = buffer_size;
auto nested_nullmap_address = reinterpret_cast<int64_t>(null_map_data.data());
jmethodID list_size = env->GetMethodID(arraylist_class, "size", "()I");
size_t has_put_element_size = array_col->get_offsets().back();
size_t array_list_size = env->CallIntMethod(result_obj, list_size);
size_t element_size = has_put_element_size + array_list_size;
array_nested_nullable.resize(element_size);
memset(null_map_data.data() + has_put_element_size, 0, array_list_size);
int64_t nested_data_address = 0, nested_offset_address = 0;
if (data_column->is_column_string()) {
ColumnString* str_col = assert_cast<ColumnString*>(data_column.get());
ColumnString::Chars& chars =
assert_cast<ColumnString::Chars&>(str_col->get_chars());
ColumnString::Offsets& offsets =
assert_cast<ColumnString::Offsets&>(str_col->get_offsets());
chars.resize(buffer_size);
offsets.resize(buffer_size);
*output_value_buffer = reinterpret_cast<int64_t>(chars.data());
*output_array_string_offsets_ptr = reinterpret_cast<int64_t>(offsets.data());
jboolean res = env->CallNonvirtualBooleanMethod(
executor_obj, executor_cl, executor_result_id, to.size() - 1, place);
while (res != JNI_TRUE) {
RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env));
increase_buffer_size++;
buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size);
try {
null_map_data.resize(buffer_size);
chars.resize(buffer_size);
offsets.resize(buffer_size);
} catch (std::bad_alloc const& e) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR,
"memory allocate failed in array column string, "
"buffer:{},size:{},reason:{}",
increase_buffer_size, buffer_size, e.what());
}
*output_array_null_ptr = reinterpret_cast<int64_t>(null_map_data.data());
*output_value_buffer = reinterpret_cast<int64_t>(chars.data());
*output_array_string_offsets_ptr = reinterpret_cast<int64_t>(offsets.data());
*output_intermediate_state_ptr = buffer_size;
res = env->CallNonvirtualBooleanMethod(
executor_obj, executor_cl, executor_result_id, to.size() - 1, place);
}
nested_data_address = reinterpret_cast<int64_t>(&chars);
nested_offset_address = reinterpret_cast<int64_t>(offsets.data());
} else {
data_column->resize(buffer_size);
*output_value_buffer = reinterpret_cast<int64_t>(data_column->get_raw_data().data);
jboolean res = env->CallNonvirtualBooleanMethod(
executor_obj, executor_cl, executor_result_id, to.size() - 1, place);
while (res != JNI_TRUE) {
RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env));
increase_buffer_size++;
buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size);
try {
null_map_data.resize(buffer_size);
data_column->resize(buffer_size);
} catch (std::bad_alloc const& e) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR,
"memory allocate failed in array number column, "
"buffer:{},size:{},reason:{}",
increase_buffer_size, buffer_size, e.what());
}
*output_array_null_ptr = reinterpret_cast<int64_t>(null_map_data.data());
*output_value_buffer =
reinterpret_cast<int64_t>(data_column->get_raw_data().data);
*output_intermediate_state_ptr = buffer_size;
res = env->CallNonvirtualBooleanMethod(
executor_obj, executor_cl, executor_result_id, to.size() - 1, place);
}
nested_data_address = reinterpret_cast<int64_t>(data_column->get_raw_data().data);
}
int row = to.size() - 1;
env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_copy_array_result_id,
has_put_element_size, result_nullable, row, result_obj,
nullmap_address, offset_address, nested_nullmap_address,
nested_data_address, nested_offset_address);
env->DeleteLocalRef(arraylist_class);
} else if (data_col.is_column_map()) {
jclass hashmap_class = env->FindClass("Ljava/util/HashMap;");
ColumnMap* map_col = assert_cast<ColumnMap*>(&data_col);
auto& offset_column = map_col->get_offsets_column();
auto offset_address = reinterpret_cast<int64_t>(offset_column.get_raw_data().data);
ColumnNullable& map_key_column_nullable =
assert_cast<ColumnNullable&>(map_col->get_keys());
auto key_data_column_null_map = map_key_column_nullable.get_null_map_column_ptr();
auto key_data_column = map_key_column_nullable.get_nested_column_ptr();
auto& key_null_map_data =
assert_cast<ColumnVector<UInt8>*>(key_data_column_null_map.get())->get_data();
auto key_nested_nullmap_address = reinterpret_cast<int64_t>(key_null_map_data.data());
ColumnNullable& map_value_column_nullable =
assert_cast<ColumnNullable&>(map_col->get_values());
auto value_data_column_null_map = map_value_column_nullable.get_null_map_column_ptr();
auto value_data_column = map_value_column_nullable.get_nested_column_ptr();
auto& value_null_map_data =
assert_cast<ColumnVector<UInt8>*>(value_data_column_null_map.get())->get_data();
auto value_nested_nullmap_address =
reinterpret_cast<int64_t>(value_null_map_data.data());
jmethodID map_size = env->GetMethodID(hashmap_class, "size", "()I");
size_t has_put_element_size = map_col->get_offsets().back();
size_t hashmap_size = env->CallIntMethod(result_obj, map_size);
size_t element_size = has_put_element_size + hashmap_size;
map_key_column_nullable.resize(element_size);
memset(key_null_map_data.data() + has_put_element_size, 0, hashmap_size);
map_value_column_nullable.resize(element_size);
memset(value_null_map_data.data() + has_put_element_size, 0, hashmap_size);
int64_t key_nested_data_address = 0, key_nested_offset_address = 0;
if (key_data_column->is_column_string()) {
ColumnString* str_col = assert_cast<ColumnString*>(key_data_column.get());
ColumnString::Chars& chars =
assert_cast<ColumnString::Chars&>(str_col->get_chars());
ColumnString::Offsets& offsets =
assert_cast<ColumnString::Offsets&>(str_col->get_offsets());
key_nested_data_address = reinterpret_cast<int64_t>(&chars);
key_nested_offset_address = reinterpret_cast<int64_t>(offsets.data());
} else {
key_nested_data_address =
reinterpret_cast<int64_t>(key_data_column->get_raw_data().data);
}
int64_t value_nested_data_address = 0, value_nested_offset_address = 0;
if (value_data_column->is_column_string()) {
ColumnString* str_col = assert_cast<ColumnString*>(value_data_column.get());
ColumnString::Chars& chars =
assert_cast<ColumnString::Chars&>(str_col->get_chars());
ColumnString::Offsets& offsets =
assert_cast<ColumnString::Offsets&>(str_col->get_offsets());
value_nested_data_address = reinterpret_cast<int64_t>(&chars);
value_nested_offset_address = reinterpret_cast<int64_t>(offsets.data());
} else {
value_nested_data_address =
reinterpret_cast<int64_t>(value_data_column->get_raw_data().data);
}
int row = to.size() - 1;
env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_copy_map_result_id,
has_put_element_size, result_nullable, row, result_obj,
nullmap_address, offset_address,
key_nested_nullmap_address, key_nested_data_address,
key_nested_offset_address, value_nested_nullmap_address,
value_nested_data_address, value_nested_offset_address);
env->DeleteLocalRef(hashmap_class);
} else {
return Status::InvalidArgument(strings::Substitute(
"Java UDAF doesn't support return type is $0 now !", result_type->get_name()));
"Java UDAF doesn't support return type is $0 now !", return_type->get_name()));
}
return Status::OK();
}
@ -438,14 +435,12 @@ private:
return s;
};
RETURN_IF_ERROR(register_id("<init>", UDAF_EXECUTOR_CTOR_SIGNATURE, executor_ctor_id));
RETURN_IF_ERROR(register_id("add", UDAF_EXECUTOR_ADD_SIGNATURE, executor_add_id));
RETURN_IF_ERROR(register_id("reset", UDAF_EXECUTOR_RESET_SIGNATURE, executor_reset_id));
RETURN_IF_ERROR(register_id("close", UDAF_EXECUTOR_CLOSE_SIGNATURE, executor_close_id));
RETURN_IF_ERROR(register_id("merge", UDAF_EXECUTOR_MERGE_SIGNATURE, executor_merge_id));
RETURN_IF_ERROR(
register_id("serialize", UDAF_EXECUTOR_SERIALIZE_SIGNATURE, executor_serialize_id));
RETURN_IF_ERROR(
register_id("getValue", UDAF_EXECUTOR_RESULT_SIGNATURE, executor_result_id));
RETURN_IF_ERROR(register_id("getValue", "(J)Ljava/lang/Object;", executor_get_value_id));
RETURN_IF_ERROR(
register_id("destroy", UDAF_EXECUTOR_DESTROY_SIGNATURE, executor_destroy_id));
RETURN_IF_ERROR(register_id("convertBasicArguments", "(IZIIJJJ)[Ljava/lang/Object;",
@ -454,6 +449,16 @@ private:
executor_convert_array_argument_id));
RETURN_IF_ERROR(register_id("convertMapArguments", "(IZIIJJJJJJJJ)[Ljava/lang/Object;",
executor_convert_map_argument_id));
RETURN_IF_ERROR(register_id("copyTupleBasicResult", "(Ljava/lang/Object;IJJJJ)V",
executor_copy_basic_result_id));
RETURN_IF_ERROR(register_id("copyTupleArrayResult", "(JZILjava/lang/Object;JJJJJ)V",
executor_copy_array_result_id));
RETURN_IF_ERROR(register_id("copyTupleMapResult", "(JZILjava/lang/Object;JJJJJJJJ)V",
executor_copy_map_result_id));
RETURN_IF_ERROR(
register_id("addBatch", "(ZIIJI[Ljava/lang/Object;)V", executor_add_batch_id));
return Status::OK();
@ -466,24 +471,19 @@ private:
jobject executor_obj;
jmethodID executor_ctor_id;
jmethodID executor_add_id;
jmethodID executor_add_batch_id;
jmethodID executor_merge_id;
jmethodID executor_serialize_id;
jmethodID executor_result_id;
jmethodID executor_get_value_id;
jmethodID executor_reset_id;
jmethodID executor_close_id;
jmethodID executor_destroy_id;
jmethodID executor_convert_basic_argument_id;
jmethodID executor_convert_array_argument_id;
jmethodID executor_convert_map_argument_id;
std::unique_ptr<int64_t> output_value_buffer;
std::unique_ptr<int64_t> output_null_value;
std::unique_ptr<int64_t> output_offsets_ptr;
std::unique_ptr<int64_t> output_intermediate_state_ptr;
std::unique_ptr<int64_t> output_array_null_ptr;
std::unique_ptr<int64_t> output_array_string_offsets_ptr;
jmethodID executor_copy_basic_result_id;
jmethodID executor_copy_array_result_id;
jmethodID executor_copy_map_result_id;
int argument_size = 0;
std::string serialize_data;
};
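
As a non-authoritative reading aid for the registration block above: JNI signature strings encode parameter and return types (Z = boolean, I = int, J = long, L...; = object reference, [ = array, V = void). Decoded that way, the signatures registered here line up with the Java methods added later in this commit:

// "(J)Ljava/lang/Object;"             -> Object getValue(long place)
// "(Ljava/lang/Object;IJJJJ)V"        -> void copyTupleBasicResult(Object, int, long, long, long, long)
// "(JZILjava/lang/Object;JJJJJ)V"     -> void copyTupleArrayResult(long, boolean, int, Object, 5 x long)
// "(JZILjava/lang/Object;JJJJJJJJ)V"  -> void copyTupleMapResult(long, boolean, int, Object, 8 x long)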

View File

@ -346,19 +346,6 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
auto& key_null_map_data =
assert_cast<ColumnVector<UInt8>*>(key_data_column_null_map.get())->get_data();
auto key_nested_nullmap_address = reinterpret_cast<int64_t>(key_null_map_data.data());
int64_t key_nested_data_address = 0, key_nested_offset_address = 0;
if (key_data_column->is_column_string()) {
ColumnString* str_col = assert_cast<ColumnString*>(key_data_column.get());
ColumnString::Chars& chars = assert_cast<ColumnString::Chars&>(str_col->get_chars());
ColumnString::Offsets& offsets =
assert_cast<ColumnString::Offsets&>(str_col->get_offsets());
key_nested_data_address = reinterpret_cast<int64_t>(&chars);
key_nested_offset_address = reinterpret_cast<int64_t>(offsets.data());
} else {
key_nested_data_address =
reinterpret_cast<int64_t>(key_data_column->get_raw_data().data);
}
ColumnNullable& map_value_column_nullable =
assert_cast<ColumnNullable&>(map_col->get_values());
auto value_data_column_null_map = map_value_column_nullable.get_null_map_column_ptr();
@ -366,19 +353,6 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
auto& value_null_map_data =
assert_cast<ColumnVector<UInt8>*>(value_data_column_null_map.get())->get_data();
auto value_nested_nullmap_address = reinterpret_cast<int64_t>(value_null_map_data.data());
int64_t value_nested_data_address = 0, value_nested_offset_address = 0;
// array type need pass address: [nullmap_address], offset_address, nested_nullmap_address, nested_data_address/nested_char_address,nested_offset_address
if (value_data_column->is_column_string()) {
ColumnString* str_col = assert_cast<ColumnString*>(value_data_column.get());
ColumnString::Chars& chars = assert_cast<ColumnString::Chars&>(str_col->get_chars());
ColumnString::Offsets& offsets =
assert_cast<ColumnString::Offsets&>(str_col->get_offsets());
value_nested_data_address = reinterpret_cast<int64_t>(&chars);
value_nested_offset_address = reinterpret_cast<int64_t>(offsets.data());
} else {
value_nested_data_address =
reinterpret_cast<int64_t>(value_data_column->get_raw_data().data);
}
jmethodID map_size = env->GetMethodID(hashmap_class, "size", "()I");
int element_size = 0; // get all element size in num_rows of map column
for (int i = 0; i < num_rows; ++i) {
@ -393,6 +367,30 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
memset(key_null_map_data.data(), 0, element_size);
map_value_column_nullable.resize(element_size);
memset(value_null_map_data.data(), 0, element_size);
int64_t key_nested_data_address = 0, key_nested_offset_address = 0;
if (key_data_column->is_column_string()) {
ColumnString* str_col = assert_cast<ColumnString*>(key_data_column.get());
ColumnString::Chars& chars = assert_cast<ColumnString::Chars&>(str_col->get_chars());
ColumnString::Offsets& offsets =
assert_cast<ColumnString::Offsets&>(str_col->get_offsets());
key_nested_data_address = reinterpret_cast<int64_t>(&chars);
key_nested_offset_address = reinterpret_cast<int64_t>(offsets.data());
} else {
key_nested_data_address =
reinterpret_cast<int64_t>(key_data_column->get_raw_data().data);
}
int64_t value_nested_data_address = 0, value_nested_offset_address = 0;
if (value_data_column->is_column_string()) {
ColumnString* str_col = assert_cast<ColumnString*>(value_data_column.get());
ColumnString::Chars& chars = assert_cast<ColumnString::Chars&>(str_col->get_chars());
ColumnString::Offsets& offsets =
assert_cast<ColumnString::Offsets&>(str_col->get_offsets());
value_nested_data_address = reinterpret_cast<int64_t>(&chars);
value_nested_offset_address = reinterpret_cast<int64_t>(offsets.data());
} else {
value_nested_data_address =
reinterpret_cast<int64_t>(value_data_column->get_raw_data().data);
}
env->CallNonvirtualVoidMethod(jni_ctx->executor, jni_env->executor_cl,
jni_env->executor_result_map_batch_id, result_nullable,
num_rows, result_obj, nullmap_address, offset_address,

View File

@ -163,43 +163,6 @@ public class UdafExecutor extends BaseExecutor {
}
}
/**
* invoke add function, add row in loop [rowStart, rowEnd).
*/
public void add(boolean isSinglePlace, long rowStart, long rowEnd) throws UdfRuntimeException {
try {
long idx = rowStart;
do {
Long curPlace = null;
if (isSinglePlace) {
curPlace = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, inputPlacesPtr));
} else {
curPlace = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, inputPlacesPtr) + 8L * idx);
}
Object[] inputArgs = new Object[argTypes.length + 1];
Object state = stateObjMap.get(curPlace);
if (state != null) {
inputArgs[0] = state;
} else {
Object newState = createAggState();
stateObjMap.put(curPlace, newState);
inputArgs[0] = newState;
}
do {
Object[] inputObjects = allocateInputObjects(idx, 1);
for (int i = 0; i < argTypes.length; ++i) {
inputArgs[i + 1] = inputObjects[i];
}
allMethods.get(UDAF_ADD_FUNCTION).invoke(udf, inputArgs);
idx++;
} while (isSinglePlace && idx < rowEnd);
} while (idx < rowEnd);
} catch (Exception e) {
LOG.warn("invoke add function meet some error: " + e.getCause().toString());
throw new UdfRuntimeException("UDAF failed to add: ", e);
}
}
/**
* invoke the user's create function to get the state obj.
*/
@ -292,40 +255,71 @@ public class UdafExecutor extends BaseExecutor {
/**
* invoke getValue to return the final result.
*/
public boolean getValue(long row, long place) throws UdfRuntimeException {
public Object getValue(long place) throws UdfRuntimeException {
try {
if (stateObjMap.get(place) == null) {
stateObjMap.put(place, createAggState());
}
return storeUdfResult(allMethods.get(UDAF_RESULT_FUNCTION).invoke(udf, stateObjMap.get((Long) place)),
row, retClass);
return allMethods.get(UDAF_RESULT_FUNCTION).invoke(udf, stateObjMap.get((Long) place));
} catch (Exception e) {
LOG.warn("invoke getValue function meet some error: " + e.getCause().toString());
throw new UdfRuntimeException("UDAF failed to result", e);
}
}
@Override
protected boolean storeUdfResult(Object obj, long row, Class retClass) throws UdfRuntimeException {
if (obj == null) {
// If result is null, return true directly when row == 0 as we have already inserted default value.
if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) == -1) {
public void copyTupleBasicResult(Object result, int row, long outputNullMapPtr, long outputBufferBase,
long charsAddress,
long offsetsAddr) throws UdfRuntimeException {
if (result == null) {
// put null obj
if (outputNullMapPtr == -1) {
throw new UdfRuntimeException("UDAF failed to store null data to not null column");
} else {
UdfUtils.UNSAFE.putByte(outputNullMapPtr + row, (byte) 1);
}
return true;
return;
}
try {
if (outputNullMapPtr != -1) {
UdfUtils.UNSAFE.putByte(outputNullMapPtr + row, (byte) 0);
}
copyTupleBasicResult(result, row, retClass, outputBufferBase, charsAddress,
offsetsAddr, retType);
} catch (UdfRuntimeException e) {
LOG.info(e.toString());
}
return super.storeUdfResult(obj, row, retClass);
}
@Override
protected long getCurrentOutputOffset(long row, boolean isArrayType) {
if (isArrayType) {
return Integer.toUnsignedLong(
UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * (row - 1)));
} else {
return Integer.toUnsignedLong(
UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * (row - 1)));
public void copyTupleArrayResult(long hasPutElementNum, boolean isNullable, int row, Object result,
long nullMapAddr,
long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) throws UdfRuntimeException {
if (nullMapAddr > 0) {
UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 0);
}
copyTupleArrayResultImpl(hasPutElementNum, isNullable, row, result, nullMapAddr, offsetsAddr, nestedNullMapAddr,
dataAddr, strOffsetAddr, retType.getItemType().getPrimitiveType());
}
public void copyTupleMapResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr,
long offsetsAddr,
long keyNestedNullMapAddr, long keyDataAddr, long keyStrOffsetAddr,
long valueNestedNullMapAddr, long valueDataAddr, long valueStrOffsetAddr) throws UdfRuntimeException {
if (nullMapAddr > 0) {
UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 0);
}
PrimitiveType keyType = retType.getKeyType().getPrimitiveType();
PrimitiveType valueType = retType.getValueType().getPrimitiveType();
Object[] keyCol = new Object[1];
Object[] valueCol = new Object[1];
Object[] resultArr = new Object[1];
resultArr[0] = result;
buildArrayListFromHashMap(resultArr, keyType, valueType, keyCol, valueCol);
copyTupleArrayResultImpl(hasPutElementNum, isNullable, row,
valueCol[0], nullMapAddr, offsetsAddr,
valueNestedNullMapAddr, valueDataAddr, valueStrOffsetAddr, valueType);
copyTupleArrayResultImpl(hasPutElementNum, isNullable, row, keyCol[0], nullMapAddr, offsetsAddr,
keyNestedNullMapAddr, keyDataAddr, keyStrOffsetAddr, keyType);
}
@Override

View File

@ -707,9 +707,9 @@ public class UdfConvert {
//////////////////////////////////// copyBatchArray//////////////////////////////////////////////////////////
public static long copyBatchArrayBooleanResult(long hasPutElementNum, boolean isNullable, int row, Object[] result,
public static long copyBatchArrayBooleanResult(long hasPutElementNum, boolean isNullable, int row, Object result,
long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<Boolean> data = (ArrayList<Boolean>) result[row];
ArrayList<Boolean> data = (ArrayList<Boolean>) result;
if (isNullable) {
if (data == null) {
UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1);
@ -741,9 +741,9 @@ public class UdfConvert {
return hasPutElementNum;
}
public static long copyBatchArrayTinyIntResult(long hasPutElementNum, boolean isNullable, int row, Object[] result,
public static long copyBatchArrayTinyIntResult(long hasPutElementNum, boolean isNullable, int row, Object result,
long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<Byte> data = (ArrayList<Byte>) result[row];
ArrayList<Byte> data = (ArrayList<Byte>) result;
if (isNullable) {
if (data == null) {
UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1);
@ -775,9 +775,9 @@ public class UdfConvert {
return hasPutElementNum;
}
public static long copyBatchArraySmallIntResult(long hasPutElementNum, boolean isNullable, int row, Object[] result,
public static long copyBatchArraySmallIntResult(long hasPutElementNum, boolean isNullable, int row, Object result,
long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<Short> data = (ArrayList<Short>) result[row];
ArrayList<Short> data = (ArrayList<Short>) result;
if (isNullable) {
if (data == null) {
UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1);
@ -809,9 +809,9 @@ public class UdfConvert {
return hasPutElementNum;
}
public static long copyBatchArrayIntResult(long hasPutElementNum, boolean isNullable, int row, Object[] result,
public static long copyBatchArrayIntResult(long hasPutElementNum, boolean isNullable, int row, Object result,
long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<Integer> data = (ArrayList<Integer>) result[row];
ArrayList<Integer> data = (ArrayList<Integer>) result;
if (isNullable) {
if (data == null) {
UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1);
@ -843,9 +843,9 @@ public class UdfConvert {
return hasPutElementNum;
}
public static long copyBatchArrayBigIntResult(long hasPutElementNum, boolean isNullable, int row, Object[] result,
public static long copyBatchArrayBigIntResult(long hasPutElementNum, boolean isNullable, int row, Object result,
long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<Long> data = (ArrayList<Long>) result[row];
ArrayList<Long> data = (ArrayList<Long>) result;
if (isNullable) {
if (data == null) {
UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1);
@ -877,9 +877,9 @@ public class UdfConvert {
return hasPutElementNum;
}
public static long copyBatchArrayFloatResult(long hasPutElementNum, boolean isNullable, int row, Object[] result,
public static long copyBatchArrayFloatResult(long hasPutElementNum, boolean isNullable, int row, Object result,
long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<Float> data = (ArrayList<Float>) result[row];
ArrayList<Float> data = (ArrayList<Float>) result;
if (isNullable) {
if (data == null) {
UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1);
@ -911,9 +911,9 @@ public class UdfConvert {
return hasPutElementNum;
}
public static long copyBatchArrayDoubleResult(long hasPutElementNum, boolean isNullable, int row, Object[] result,
public static long copyBatchArrayDoubleResult(long hasPutElementNum, boolean isNullable, int row, Object result,
long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<Double> data = (ArrayList<Double>) result[row];
ArrayList<Double> data = (ArrayList<Double>) result;
if (isNullable) {
if (data == null) {
UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1);
@ -945,9 +945,9 @@ public class UdfConvert {
return hasPutElementNum;
}
public static long copyBatchArrayDateResult(long hasPutElementNum, boolean isNullable, int row, Object[] result,
public static long copyBatchArrayDateResult(long hasPutElementNum, boolean isNullable, int row, Object result,
long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<LocalDate> data = (ArrayList<LocalDate>) result[row];
ArrayList<LocalDate> data = (ArrayList<LocalDate>) result;
if (isNullable) {
if (data == null) {
UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1);
@ -981,9 +981,9 @@ public class UdfConvert {
return hasPutElementNum;
}
public static long copyBatchArrayDateTimeResult(long hasPutElementNum, boolean isNullable, int row, Object[] result,
public static long copyBatchArrayDateTimeResult(long hasPutElementNum, boolean isNullable, int row, Object result,
long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<LocalDateTime> data = (ArrayList<LocalDateTime>) result[row];
ArrayList<LocalDateTime> data = (ArrayList<LocalDateTime>) result;
if (isNullable) {
if (data == null) {
UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1);
@ -1017,9 +1017,9 @@ public class UdfConvert {
return hasPutElementNum;
}
public static long copyBatchArrayDateV2Result(long hasPutElementNum, boolean isNullable, int row, Object[] result,
public static long copyBatchArrayDateV2Result(long hasPutElementNum, boolean isNullable, int row, Object result,
long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<LocalDate> data = (ArrayList<LocalDate>) result[row];
ArrayList<LocalDate> data = (ArrayList<LocalDate>) result;
if (isNullable) {
if (data == null) {
UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1);
@ -1054,9 +1054,9 @@ public class UdfConvert {
}
public static long copyBatchArrayDateTimeV2Result(long hasPutElementNum, boolean isNullable, int row,
Object[] result,
Object result,
long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<LocalDateTime> data = (ArrayList<LocalDateTime>) result[row];
ArrayList<LocalDateTime> data = (ArrayList<LocalDateTime>) result;
if (isNullable) {
if (data == null) {
UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1);
@ -1090,9 +1090,9 @@ public class UdfConvert {
return hasPutElementNum;
}
public static long copyBatchArrayLargeIntResult(long hasPutElementNum, boolean isNullable, int row, Object[] result,
public static long copyBatchArrayLargeIntResult(long hasPutElementNum, boolean isNullable, int row, Object result,
long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<BigInteger> data = (ArrayList<BigInteger>) result[row];
ArrayList<BigInteger> data = (ArrayList<BigInteger>) result;
if (isNullable) {
if (data == null) {
UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1);
@ -1140,9 +1140,9 @@ public class UdfConvert {
return hasPutElementNum;
}
public static long copyBatchArrayDecimalResult(long hasPutElementNum, boolean isNullable, int row, Object[] result,
public static long copyBatchArrayDecimalResult(long hasPutElementNum, boolean isNullable, int row, Object result,
long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<BigDecimal> data = (ArrayList<BigDecimal>) result[row];
ArrayList<BigDecimal> data = (ArrayList<BigDecimal>) result;
if (isNullable) {
if (data == null) {
UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1);
@ -1194,9 +1194,9 @@ public class UdfConvert {
public static long copyBatchArrayDecimalV3Result(int scale, long typeLen, long hasPutElementNum, boolean isNullable,
int row,
Object[] result,
Object result,
long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<BigDecimal> data = (ArrayList<BigDecimal>) result[row];
ArrayList<BigDecimal> data = (ArrayList<BigDecimal>) result;
if (isNullable) {
if (data == null) {
UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1);
@ -1247,9 +1247,9 @@ public class UdfConvert {
}
public static long copyBatchArrayStringResult(long hasPutElementNum, boolean isNullable, int row,
Object[] result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr,
Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr,
long strOffsetAddr) {
ArrayList<String> data = (ArrayList<String>) result[row];
ArrayList<String> data = (ArrayList<String>) result;
if (isNullable) {
if (data == null) {
UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1);
@ -1270,8 +1270,12 @@ public class UdfConvert {
offset += byteRes[i].length;
offsets[i] = offset;
}
byte[] bytes = new byte[offsets[num - 1] - oldOffsetNum];
long bytesAddr = JNINativeMethod.resizeStringColumn(dataAddr, offsets[num - 1]);
int oldSize = 0;
if (num > 0) {
oldSize = offsets[num - 1];
}
byte[] bytes = new byte[oldSize - oldOffsetNum];
long bytesAddr = JNINativeMethod.resizeStringColumn(dataAddr, oldSize);
int dst = 0;
for (int i = 0; i < num; i++) {
for (int j = 0; j < byteRes[i].length; j++) {
@ -1281,7 +1285,7 @@ public class UdfConvert {
UdfUtils.copyMemory(offsets, UdfUtils.INT_ARRAY_OFFSET, null, strOffsetAddr + (4L * hasPutElementNum),
num * 4L);
UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null, bytesAddr + oldOffsetNum,
offsets[num - 1] - oldOffsetNum);
oldSize - oldOffsetNum);
hasPutElementNum = hasPutElementNum + num;
}
} else {
@ -1300,9 +1304,13 @@ public class UdfConvert {
offset += byteRes[i].length;
offsets[i] = offset;
}
byte[] bytes = new byte[offsets[num - 1]];
int oldOffsetNum = UdfUtils.UNSAFE.getInt(null, strOffsetAddr + ((hasPutElementNum - 1) * 4L));
long bytesAddr = JNINativeMethod.resizeStringColumn(dataAddr, oldOffsetNum + offsets[num - 1]);
int oldSize = 0;
if (num > 0) {
oldSize = offsets[num - 1];
}
byte[] bytes = new byte[oldSize];
long bytesAddr = JNINativeMethod.resizeStringColumn(dataAddr, oldOffsetNum + oldSize);
int dst = 0;
for (int i = 0; i < num; i++) {
for (int j = 0; j < byteRes[i].length; j++) {
@ -1312,7 +1320,7 @@ public class UdfConvert {
UdfUtils.copyMemory(offsets, UdfUtils.INT_ARRAY_OFFSET, null, strOffsetAddr + (4L * oldOffsetNum),
num * 4L);
UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null, bytesAddr + oldOffsetNum,
offsets[num - 1]);
oldSize);
hasPutElementNum = hasPutElementNum + num;
}
UdfUtils.UNSAFE.putLong(null, offsetsAddr + 8L * row, hasPutElementNum);
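
A minimal sketch of the empty-array case the new num > 0 guards above handle (the scenario is an assumption about intent; variable names mirror the surrounding diff):

// If the UDAF produces an empty array element, data.size() == 0 and offsets has length 0,
// so the old offsets[num - 1] access would throw ArrayIndexOutOfBoundsException.
ArrayList<String> data = new ArrayList<>(); // empty element
int num = data.size();                      // 0
int[] offsets = new int[num];               // zero-length
int oldSize = 0;
if (num > 0) {                              // guard added by this commit
    oldSize = offsets[num - 1];
}
// oldSize stays 0, so nothing is resized or copied for this row.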

View File

@ -74,50 +74,6 @@ public class UdfExecutor extends BaseExecutor {
super.close();
}
/**
* evaluate function called by the backend. The inputs to the UDF have
* been serialized to 'input'
*/
public void evaluate() throws UdfRuntimeException {
int batchSize = UdfUtils.UNSAFE.getInt(null, batchSizePtr);
try {
if (retType.equals(JavaUdfDataType.STRING) || retType.equals(JavaUdfDataType.VARCHAR)
|| retType.equals(JavaUdfDataType.CHAR) || retType.equals(JavaUdfDataType.ARRAY_TYPE)
|| retType.equals(JavaUdfDataType.MAP_TYPE)) {
// If this udf return variable-size type (e.g.) String, we have to allocate output
// buffer multiple times until buffer size is enough to store output column. So we
// always begin with the last evaluated row instead of beginning of this batch.
rowIdx = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr + 8);
if (rowIdx == 0) {
outputOffset = 0L;
}
} else {
rowIdx = 0;
}
for (; rowIdx < batchSize; rowIdx++) {
inputObjects = allocateInputObjects(rowIdx, 0);
// `storeUdfResult` is called to store udf result to output column. If true
// is returned, current value is stored successfully. Otherwise, current result is
// not processed successfully (e.g. current output buffer is not large enough) so
// we break this loop directly.
if (!storeUdfResult(evaluate(inputObjects), rowIdx, method.getReturnType())) {
UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr + 8, rowIdx);
return;
}
}
} catch (Exception e) {
if (retType.equals(JavaUdfDataType.STRING) || retType.equals(JavaUdfDataType.ARRAY_TYPE)
|| retType.equals(JavaUdfDataType.MAP_TYPE)) {
UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr + 8, batchSize);
}
throw new UdfRuntimeException("UDF::evaluate() ran into a problem.", e);
}
if (retType.equals(JavaUdfDataType.STRING) || retType.equals(JavaUdfDataType.ARRAY_TYPE)
|| retType.equals(JavaUdfDataType.MAP_TYPE)) {
UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr + 8, rowIdx);
}
}
public Object[] convertBasicArguments(int argIdx, boolean isNullable, int numRows, long nullMapAddr,
long columnAddr, long strOffsetAddr) {
return convertBasicArg(true, argIdx, isNullable, 0, numRows, nullMapAddr, columnAddr, strOffsetAddr);
@ -211,30 +167,6 @@ public class UdfExecutor extends BaseExecutor {
return method;
}
// Sets the result object 'obj' into the outputBufferPtr and outputNullPtr_
@Override
protected boolean storeUdfResult(Object obj, long row, Class retClass) throws UdfRuntimeException {
if (obj == null) {
if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) == -1) {
throw new UdfRuntimeException("UDF failed to store null data to not null column");
}
UdfUtils.UNSAFE.putByte(null, UdfUtils.UNSAFE.getLong(null, outputNullPtr) + row, (byte) 1);
if (retType.equals(JavaUdfDataType.STRING)) {
UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr)
+ 4L * row, Integer.parseUnsignedInt(String.valueOf(outputOffset)));
} else if (retType.equals(JavaUdfDataType.ARRAY_TYPE)) {
UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row,
Long.parseUnsignedLong(String.valueOf(outputOffset)));
}
return true;
}
return super.storeUdfResult(obj, row, retClass);
}
@Override
protected long getCurrentOutputOffset(long row, boolean isArrayType) {
return outputOffset;
}
@Override
protected void updateOutputOffset(long offset) {

View File

@ -1,600 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.udf;
import org.apache.doris.common.jni.utils.UdfUtils;
import org.apache.doris.thrift.TFunction;
import org.apache.doris.thrift.TFunctionBinaryType;
import org.apache.doris.thrift.TFunctionName;
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
import org.apache.doris.thrift.TPrimitiveType;
import org.apache.doris.thrift.TScalarFunction;
import org.apache.doris.thrift.TScalarType;
import org.apache.doris.thrift.TTypeDesc;
import org.apache.doris.thrift.TTypeNode;
import org.apache.doris.thrift.TTypeNodeType;
import org.apache.thrift.TSerializer;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.junit.Test;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
public class UdfExecutorTest {
@Test
public void testDateTimeUdf() throws Exception {
TScalarFunction scalarFunction = new TScalarFunction();
scalarFunction.symbol = "org.apache.doris.udf.DateTimeUdf";
TFunction fn = new TFunction();
fn.setBinaryType(TFunctionBinaryType.JAVA_UDF);
TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR);
typeNode.setScalarType(new TScalarType(TPrimitiveType.INT));
fn.setRetType(new TTypeDesc(Collections.singletonList(typeNode)));
TTypeNode typeNodeArg = new TTypeNode(TTypeNodeType.SCALAR);
typeNodeArg.setScalarType(new TScalarType(TPrimitiveType.DATETIME));
TTypeDesc typeDescArg = new TTypeDesc(Collections.singletonList(typeNodeArg));
fn.arg_types = Arrays.asList(typeDescArg);
fn.scalar_fn = scalarFunction;
fn.name = new TFunctionName("DateTimeUdf");
long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(4);
int batchSize = 10;
UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize);
TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams();
params.setBatchSizePtr(batchSizePtr);
params.setFn(fn);
long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8);
long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8);
long outputBuffer = UdfUtils.UNSAFE.allocateMemory(4 * batchSize);
long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize);
UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer);
UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull);
params.setOutputBufferPtr(outputBufferPtr);
params.setOutputNullPtr(outputNullPtr);
int numCols = 1;
long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols);
long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols);
long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(8 * batchSize);
long inputNull1 = UdfUtils.UNSAFE.allocateMemory(batchSize);
UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1);
UdfUtils.UNSAFE.putLong(inputNullPtr, inputNull1);
long[] inputLongDateTime =
new long[] {562960991655690406L, 563242466632401062L, 563523941609111718L, 563805416585822374L,
564086891562533030L, 564368366539243686L, 564649841515954342L, 564931316492664998L,
565212791469375654L, 565494266446086310L};
for (int i = 0; i < batchSize; ++i) {
UdfUtils.UNSAFE.putLong(null, inputBuffer1 + i * 8, inputLongDateTime[i]);
UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 0);
}
params.setInputBufferPtrs(inputBufferPtr);
params.setInputNullsPtrs(inputNullPtr);
params.setInputOffsetsPtrs(0);
TBinaryProtocol.Factory factory = new TBinaryProtocol.Factory();
TSerializer serializer = new TSerializer(factory);
UdfExecutor executor = new UdfExecutor(serializer.serialize(params));
executor.evaluate();
for (int i = 0; i < batchSize; ++i) {
assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0);
assert (UdfUtils.UNSAFE.getInt(outputBuffer + 4 * i) == (2000 + i));
}
}
@Test
public void testDecimalUdf() throws Exception {
TScalarFunction scalarFunction = new TScalarFunction();
scalarFunction.symbol = "org.apache.doris.udf.DecimalUdf";
TFunction fn = new TFunction();
fn.binary_type = TFunctionBinaryType.JAVA_UDF;
TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR);
TScalarType scalarType = new TScalarType(TPrimitiveType.DECIMALV2);
scalarType.setScale(9);
scalarType.setPrecision(27);
typeNode.scalar_type = scalarType;
TTypeDesc typeDesc = new TTypeDesc(Collections.singletonList(typeNode));
fn.ret_type = typeDesc;
fn.arg_types = Arrays.asList(typeDesc, typeDesc);
fn.scalar_fn = scalarFunction;
fn.name = new TFunctionName("DecimalUdf");
long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(8);
int batchSize = 10;
UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize);
TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams();
params.setBatchSizePtr(batchSizePtr);
params.setFn(fn);
long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8);
long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8);
long outputBuffer = UdfUtils.UNSAFE.allocateMemory(16 * batchSize);
long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize);
UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer);
UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull);
params.setOutputBufferPtr(outputBufferPtr);
params.setOutputNullPtr(outputNullPtr);
int numCols = 2;
long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols);
long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols);
long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(16 * batchSize);
long inputNull1 = UdfUtils.UNSAFE.allocateMemory(batchSize);
long inputBuffer2 = UdfUtils.UNSAFE.allocateMemory(16 * batchSize);
long inputNull2 = UdfUtils.UNSAFE.allocateMemory(batchSize);
UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1);
UdfUtils.UNSAFE.putLong(inputBufferPtr + 8, inputBuffer2);
UdfUtils.UNSAFE.putLong(inputNullPtr, inputNull1);
UdfUtils.UNSAFE.putLong(inputNullPtr + 8, inputNull2);
long[] inputLong =
new long[] {562960991655690406L, 563242466632401062L, 563523941609111718L, 563805416585822374L,
564086891562533030L, 564368366539243686L, 564649841515954342L, 564931316492664998L,
565212791469375654L, 565494266446086310L};
BigDecimal[] decimalArray = new BigDecimal[10];
for (int i = 0; i < batchSize; ++i) {
BigInteger temp = BigInteger.valueOf(inputLong[i]);
decimalArray[i] = new BigDecimal(temp, 9);
}
BigDecimal decimal2 = new BigDecimal(BigInteger.valueOf(0L), 9);
byte[] intput2 = convertByteOrder(decimal2.unscaledValue().toByteArray());
byte[] value2 = new byte[16];
if (decimal2.signum() == -1) {
Arrays.fill(value2, (byte) -1);
}
for (int index = 0; index < Math.min(intput2.length, value2.length); ++index) {
value2[index] = intput2[index];
}
for (int i = 0; i < batchSize; ++i) {
byte[] intput1 = convertByteOrder(decimalArray[i].unscaledValue().toByteArray());
byte[] value1 = new byte[16];
if (decimalArray[i].signum() == -1) {
Arrays.fill(value1, (byte) -1);
}
for (int index = 0; index < Math.min(intput1.length, value1.length); ++index) {
value1[index] = intput1[index];
}
UdfUtils.copyMemory(value1, UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer1 + i * 16, value1.length);
UdfUtils.copyMemory(value2, UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer2 + i * 16, value2.length);
UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 0);
UdfUtils.UNSAFE.putByte(null, inputNull2 + i, (byte) 0);
}
params.setInputBufferPtrs(inputBufferPtr);
params.setInputNullsPtrs(inputNullPtr);
params.setInputOffsetsPtrs(0);
TBinaryProtocol.Factory factory = new TBinaryProtocol.Factory();
TSerializer serializer = new TSerializer(factory);
UdfExecutor udfExecutor = new UdfExecutor(serializer.serialize(params));
udfExecutor.evaluate();
for (int i = 0; i < batchSize; ++i) {
byte[] bytes = new byte[16];
assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0);
UdfUtils.copyMemory(null, outputBuffer + 16 * i, bytes, UdfUtils.BYTE_ARRAY_OFFSET, bytes.length);
BigInteger integer = new BigInteger(convertByteOrder(bytes));
BigDecimal result = new BigDecimal(integer, 9);
assert (result.equals(decimalArray[i]));
}
}
@Test
public void testConstantOneUdf() throws Exception {
TScalarFunction scalarFunction = new TScalarFunction();
scalarFunction.symbol = "org.apache.doris.udf.ConstantOneUdf";
TFunction fn = new TFunction();
fn.binary_type = TFunctionBinaryType.JAVA_UDF;
TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR);
typeNode.scalar_type = new TScalarType(TPrimitiveType.INT);
fn.ret_type = new TTypeDesc(Collections.singletonList(typeNode));
fn.arg_types = new ArrayList<>();
fn.scalar_fn = scalarFunction;
fn.name = new TFunctionName("ConstantOne");
long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(4);
int batchSize = 10;
UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize);
TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams();
params.setBatchSizePtr(batchSizePtr);
params.setFn(fn);
long outputBuffer = UdfUtils.UNSAFE.allocateMemory(4 * batchSize);
long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize);
long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8);
UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer);
long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8);
UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull);
params.setOutputBufferPtr(outputBufferPtr);
params.setOutputNullPtr(outputNullPtr);
params.setInputBufferPtrs(0);
params.setInputNullsPtrs(0);
params.setInputOffsetsPtrs(0);
TBinaryProtocol.Factory factory =
new TBinaryProtocol.Factory();
TSerializer serializer = new TSerializer(factory);
UdfExecutor executor;
executor = new UdfExecutor(serializer.serialize(params));
executor.evaluate();
for (int i = 0; i < 10; i++) {
assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0);
assert (UdfUtils.UNSAFE.getInt(outputBuffer + 4 * i) == 1);
}
}
@Test
public void testSimpleAddUdf() throws Exception {
TScalarFunction scalarFunction = new TScalarFunction();
scalarFunction.symbol = "org.apache.doris.udf.SimpleAddUdf";
TFunction fn = new TFunction();
fn.binary_type = TFunctionBinaryType.JAVA_UDF;
TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR);
typeNode.scalar_type = new TScalarType(TPrimitiveType.INT);
TTypeDesc typeDesc = new TTypeDesc(Collections.singletonList(typeNode));
fn.ret_type = typeDesc;
fn.arg_types = Arrays.asList(typeDesc, typeDesc);
fn.scalar_fn = scalarFunction;
fn.name = new TFunctionName("SimpleAdd");
long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(4);
int batchSize = 10;
UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize);
TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams();
params.setBatchSizePtr(batchSizePtr);
params.setFn(fn);
long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8);
long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8);
long outputBuffer = UdfUtils.UNSAFE.allocateMemory(4 * batchSize);
long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize);
UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer);
UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull);
params.setOutputBufferPtr(outputBufferPtr);
params.setOutputNullPtr(outputNullPtr);
int numCols = 2;
long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols);
long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols);
long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(4 * batchSize);
long inputNull1 = UdfUtils.UNSAFE.allocateMemory(batchSize);
long inputBuffer2 = UdfUtils.UNSAFE.allocateMemory(4 * batchSize);
long inputNull2 = UdfUtils.UNSAFE.allocateMemory(batchSize);
UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1);
UdfUtils.UNSAFE.putLong(inputBufferPtr + 8, inputBuffer2);
UdfUtils.UNSAFE.putLong(inputNullPtr, inputNull1);
UdfUtils.UNSAFE.putLong(inputNullPtr + 8, inputNull2);
for (int i = 0; i < batchSize; i++) {
UdfUtils.UNSAFE.putInt(null, inputBuffer1 + i * 4, i);
UdfUtils.UNSAFE.putInt(null, inputBuffer2 + i * 4, i);
if (i % 2 == 0) {
UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 1);
} else {
UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 0);
}
UdfUtils.UNSAFE.putByte(null, inputNull2 + i, (byte) 0);
}
params.setInputBufferPtrs(inputBufferPtr);
params.setInputNullsPtrs(inputNullPtr);
params.setInputOffsetsPtrs(0);
TBinaryProtocol.Factory factory =
new TBinaryProtocol.Factory();
TSerializer serializer = new TSerializer(factory);
UdfExecutor executor;
executor = new UdfExecutor(serializer.serialize(params));
executor.evaluate();
for (int i = 0; i < batchSize; i++) {
if (i % 2 == 0) {
assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 1);
} else {
assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0);
assert (UdfUtils.UNSAFE.getInt(outputBuffer + 4 * i) == i * 2);
}
}
}
@Test
public void testStringConcatUdf() throws Exception {
TScalarFunction scalarFunction = new TScalarFunction();
scalarFunction.symbol = "org.apache.doris.udf.StringConcatUdf";
TFunction fn = new TFunction();
fn.binary_type = TFunctionBinaryType.JAVA_UDF;
TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR);
typeNode.scalar_type = new TScalarType(TPrimitiveType.STRING);
TTypeDesc typeDesc = new TTypeDesc(Collections.singletonList(typeNode));
fn.ret_type = typeDesc;
fn.arg_types = Arrays.asList(typeDesc, typeDesc);
fn.scalar_fn = scalarFunction;
fn.name = new TFunctionName("StringConcat");
long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(32);
int batchSize = 10;
UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize);
TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams();
params.setBatchSizePtr(batchSizePtr);
params.setFn(fn);
long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8);
long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8);
long outputOffsetsPtr = UdfUtils.UNSAFE.allocateMemory(8);
long outputIntermediateStatePtr = UdfUtils.UNSAFE.allocateMemory(8 * 2);
String[] input1 = new String[batchSize];
String[] input2 = new String[batchSize];
long[] inputOffsets1 = new long[batchSize];
long[] inputOffsets2 = new long[batchSize];
long inputBufferSize1 = 0;
long inputBufferSize2 = 0;
for (int i = 0; i < batchSize; i++) {
input1[i] = "Input1_" + i;
input2[i] = "Input2_" + i;
inputOffsets1[i] = i == 0 ? input1[i].getBytes(StandardCharsets.UTF_8).length
: inputOffsets1[i - 1] + input1[i].getBytes(StandardCharsets.UTF_8).length;
inputOffsets2[i] = i == 0 ? input2[i].getBytes(StandardCharsets.UTF_8).length
: inputOffsets2[i - 1] + input2[i].getBytes(StandardCharsets.UTF_8).length;
inputBufferSize1 += input1[i].getBytes(StandardCharsets.UTF_8).length;
inputBufferSize2 += input2[i].getBytes(StandardCharsets.UTF_8).length;
}
// In our test case, output buffer is (8 + 1) bytes * batchSize
long outputBuffer = UdfUtils.UNSAFE.allocateMemory(inputBufferSize1 + inputBufferSize2 + batchSize);
long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize);
long outputOffset = UdfUtils.UNSAFE.allocateMemory(4 * batchSize);
UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer);
UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull);
UdfUtils.UNSAFE.putLong(outputOffsetsPtr, outputOffset);
// reserved buffer size
UdfUtils.UNSAFE.putLong(outputIntermediateStatePtr, inputBufferSize1 + inputBufferSize2 + batchSize);
// current row id
UdfUtils.UNSAFE.putLong(outputIntermediateStatePtr + 8, 0);
params.setOutputBufferPtr(outputBufferPtr);
params.setOutputNullPtr(outputNullPtr);
params.setOutputOffsetsPtr(outputOffsetsPtr);
params.setOutputIntermediateStatePtr(outputIntermediateStatePtr);
int numCols = 2;
long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols);
long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols);
long inputOffsetsPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols);
long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(inputBufferSize1 + batchSize);
long inputOffset1 = UdfUtils.UNSAFE.allocateMemory(4 * batchSize);
long inputBuffer2 = UdfUtils.UNSAFE.allocateMemory(inputBufferSize2 + batchSize);
long inputOffset2 = UdfUtils.UNSAFE.allocateMemory(4 * batchSize);
UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1);
UdfUtils.UNSAFE.putLong(inputBufferPtr + 8, inputBuffer2);
UdfUtils.UNSAFE.putLong(inputNullPtr, -1);
UdfUtils.UNSAFE.putLong(inputNullPtr + 8, -1);
UdfUtils.UNSAFE.putLong(inputOffsetsPtr, inputOffset1);
UdfUtils.UNSAFE.putLong(inputOffsetsPtr + 8, inputOffset2);
for (int i = 0; i < batchSize; i++) {
if (i == 0) {
UdfUtils.copyMemory(input1[i].getBytes(StandardCharsets.UTF_8),
UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer1,
input1[i].getBytes(StandardCharsets.UTF_8).length);
UdfUtils.copyMemory(input2[i].getBytes(StandardCharsets.UTF_8),
UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer2,
input2[i].getBytes(StandardCharsets.UTF_8).length);
} else {
UdfUtils.copyMemory(input1[i].getBytes(StandardCharsets.UTF_8),
UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer1 + inputOffsets1[i - 1],
input1[i].getBytes(StandardCharsets.UTF_8).length);
UdfUtils.copyMemory(input2[i].getBytes(StandardCharsets.UTF_8),
UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer2 + inputOffsets2[i - 1],
input2[i].getBytes(StandardCharsets.UTF_8).length);
}
UdfUtils.UNSAFE.putInt(null, inputOffset1 + 4L * i,
Integer.parseUnsignedInt(String.valueOf(inputOffsets1[i])));
UdfUtils.UNSAFE.putInt(null, inputOffset2 + 4L * i,
Integer.parseUnsignedInt(String.valueOf(inputOffsets2[i])));
}
params.setInputBufferPtrs(inputBufferPtr);
params.setInputNullsPtrs(inputNullPtr);
params.setInputOffsetsPtrs(inputOffsetsPtr);
TBinaryProtocol.Factory factory =
new TBinaryProtocol.Factory();
TSerializer serializer = new TSerializer(factory);
UdfExecutor executor;
executor = new UdfExecutor(serializer.serialize(params));
executor.evaluate();
for (int i = 0; i < batchSize; i++) {
byte[] bytes = new byte[input1[i].getBytes(StandardCharsets.UTF_8).length
+ input2[i].getBytes(StandardCharsets.UTF_8).length];
assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0);
if (i == 0) {
UdfUtils.copyMemory(null, outputBuffer, bytes, UdfUtils.BYTE_ARRAY_OFFSET,
bytes.length);
} else {
long lastOffset = UdfUtils.UNSAFE.getInt(null, outputOffset + 4 * (i - 1));
UdfUtils.copyMemory(null, outputBuffer + lastOffset, bytes, UdfUtils.BYTE_ARRAY_OFFSET,
bytes.length);
}
assert (new String(bytes, StandardCharsets.UTF_8).equals(input1[i] + input2[i]));
}
}
@Test
public void testLargeIntUdf() throws Exception {
TScalarFunction scalarFunction = new TScalarFunction();
scalarFunction.symbol = "org.apache.doris.udf.LargeIntUdf";
TFunction fn = new TFunction();
fn.binary_type = TFunctionBinaryType.JAVA_UDF;
TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR);
typeNode.scalar_type = new TScalarType(TPrimitiveType.LARGEINT);
TTypeDesc typeDesc = new TTypeDesc(Collections.singletonList(typeNode));
fn.ret_type = typeDesc;
fn.arg_types = Arrays.asList(typeDesc, typeDesc);
fn.scalar_fn = scalarFunction;
fn.name = new TFunctionName("LargeIntUdf");
long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(8);
int batchSize = 10;
UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize);
TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams();
params.setBatchSizePtr(batchSizePtr);
params.setFn(fn);
long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8);
long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8);
long outputBuffer = UdfUtils.UNSAFE.allocateMemory(16 * batchSize);
long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize);
UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer);
UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull);
params.setOutputBufferPtr(outputBufferPtr);
params.setOutputNullPtr(outputNullPtr);
int numCols = 2;
long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols);
long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols);
long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(16 * batchSize);
long inputNull1 = UdfUtils.UNSAFE.allocateMemory(batchSize);
long inputBuffer2 = UdfUtils.UNSAFE.allocateMemory(16 * batchSize);
long inputNull2 = UdfUtils.UNSAFE.allocateMemory(batchSize);
UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1);
UdfUtils.UNSAFE.putLong(inputBufferPtr + 8, inputBuffer2);
UdfUtils.UNSAFE.putLong(inputNullPtr, inputNull1);
UdfUtils.UNSAFE.putLong(inputNullPtr + 8, inputNull2);
long[] inputLong =
new long[] {562960991655690406L, 563242466632401062L, 563523941609111718L, 563805416585822374L,
564086891562533030L, 564368366539243686L, 564649841515954342L, 564931316492664998L,
565212791469375654L, 565494266446086310L};
BigInteger[] integerArray = new BigInteger[10];
for (int i = 0; i < batchSize; ++i) {
integerArray[i] = BigInteger.valueOf(inputLong[i]);
}
BigInteger integer2 = BigInteger.valueOf(1L);
byte[] input2Bytes = convertByteOrder(integer2.toByteArray());
byte[] value2 = new byte[16];
if (integer2.signum() == -1) {
Arrays.fill(value2, (byte) -1);
}
for (int index = 0; index < Math.min(input2Bytes.length, value2.length); ++index) {
value2[index] = input2Bytes[index];
}
for (int i = 0; i < batchSize; ++i) {
byte[] input1Bytes = convertByteOrder(integerArray[i].toByteArray());
byte[] value1 = new byte[16];
if (integerArray[i].signum() == -1) {
Arrays.fill(value1, (byte) -1);
}
for (int index = 0; index < Math.min(input1Bytes.length, value1.length); ++index) {
value1[index] = input1Bytes[index];
}
UdfUtils.copyMemory(value1, UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer1 + i * 16, value1.length);
UdfUtils.copyMemory(value2, UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer2 + i * 16, value2.length);
UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 0);
UdfUtils.UNSAFE.putByte(null, inputNull2 + i, (byte) 0);
}
params.setInputBufferPtrs(inputBufferPtr);
params.setInputNullsPtrs(inputNullPtr);
params.setInputOffsetsPtrs(0);
TBinaryProtocol.Factory factory = new TBinaryProtocol.Factory();
TSerializer serializer = new TSerializer(factory);
UdfExecutor udfExecutor = new UdfExecutor(serializer.serialize(params));
udfExecutor.evaluate();
for (int i = 0; i < batchSize; ++i) {
byte[] bytes = new byte[16];
assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0);
UdfUtils.copyMemory(null, outputBuffer + 16 * i, bytes, UdfUtils.BYTE_ARRAY_OFFSET, bytes.length);
BigInteger result = new BigInteger(convertByteOrder(bytes));
assert (result.equals(integerArray[i].add(BigInteger.valueOf(1))));
}
}
// Reverses the byte array in place: BigInteger.toByteArray() is big-endian, while the native buffer stores LARGEINT little-endian.
public byte[] convertByteOrder(byte[] bytes) {
int length = bytes.length;
for (int i = 0; i < length / 2; ++i) {
byte temp = bytes[i];
bytes[i] = bytes[length - 1 - i];
bytes[length - 1 - i] = temp;
}
return bytes;
}
}
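For context on what testLargeIntUdf exercises: the test writes two 16-byte little-endian LARGEINT columns (hence the sign extension and the convertByteOrder calls above) and invokes a UDF referenced only by the symbol "org.apache.doris.udf.LargeIntUdf", expecting row i to come back as integerArray[i] + 1. A minimal sketch of such a UDF, assuming the usual evaluate-method convention for Doris Java UDFs (this is an illustration, not the class actually shipped in the test jar):

package org.apache.doris.udf;

import java.math.BigInteger;

// Hypothetical sketch of the UDF the test resolves by symbol name.
// LARGEINT arguments arrive as BigInteger; the test passes (value, 1)
// and asserts that the result equals value + 1.
public class LargeIntUdf {
    public BigInteger evaluate(BigInteger a, BigInteger b) {
        if (a == null || b == null) {
            return null;
        }
        return a.add(b);
    }
}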

View File

@ -0,0 +1,31 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !select_1 --
{1:10, 2:20, 3:30, 4:40, 5:50}
-- !select_2 --
{1:0.01, 2:0.02, 3:0.03, 4:0.04, 5:0.05}
-- !select_3 --
{1:10}
{2:20}
{3:30}
{4:40}
{5:50}
-- !select_4 --
{1:0.01}
{2:0.02}
{3:0.03}
{4:0.04}
{5:0.05}
-- !select_5 --
{"2 114":"0.02 514", "3 114":"0.03 514", "1 114":"0.01 514", "5 114":"0.05 514", "4 114":"0.04 514"}
-- !select_6 --
{"1 114":"0.01 514"}
{"2 114":"0.02 514"}
{"3 114":"0.03 514"}
{"4 114":"0.04 514"}
{"5 114":"0.05 514"}

View File

@ -0,0 +1,75 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.udf;
import org.apache.log4j.Logger;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
public class MyReturnMapString {
private static final Logger LOG = Logger.getLogger(MyReturnMapString.class);
public static class State {
public HashMap<Integer,Double> counter = new HashMap<>();
}
public State create() {
return new State();
}
public void destroy(State state) {
}
public void add(State state, Integer k, Double v) {
LOG.info("udaf nest k v " + k + " " + v);
state.counter.put(k, v);
}
public void serialize(State state, DataOutputStream out) throws IOException {
int size = state.counter.size();
out.writeInt(size);
for (Map.Entry<Integer, Double> it : state.counter.entrySet()) {
out.writeInt(it.getKey());
out.writeDouble(it.getValue());
}
}
public void deserialize(State state, DataInputStream in) throws IOException {
int size = in.readInt();
for (int i = 0; i < size; ++i) {
Integer key = in.readInt();
Double value = in.readDouble();
state.counter.put(key, value);
}
}
public void merge(State state, State rhs) {
for (Map.Entry<Integer, Double> it : rhs.counter.entrySet()) {
state.counter.put(it.getKey(), it.getValue());
}
}
public HashMap<String, String> getValue(State state) {
// Note: no sorting is applied here; HashMap iteration order determines the order of the regression output.
HashMap<String, String> map = new HashMap<>();
for (Map.Entry<Integer, Double> it : state.counter.entrySet()) {
map.put(it.getKey() + " 114", it.getValue() + " 514");
}
return map;
}
}
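The serialize/deserialize pair above defines the wire format of the intermediate state: an int count followed by (int key, double value) pairs. A minimal round-trip check of that format, written as a standalone sketch (assuming log4j is on the classpath for the Logger inside MyReturnMapString; this is not one of the shipped tests):

package org.apache.doris.udf;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

public class MyReturnMapStringRoundTrip {
    public static void main(String[] args) throws IOException {
        MyReturnMapString udaf = new MyReturnMapString();
        MyReturnMapString.State src = udaf.create();
        udaf.add(src, 1, 0.01);
        udaf.add(src, 2, 0.02);

        // Write the state out in the same form the backend ships between instances.
        ByteArrayOutputStream buf = new ByteArrayOutputStream();
        udaf.serialize(src, new DataOutputStream(buf));

        // Read it back into a fresh state.
        MyReturnMapString.State dst = udaf.create();
        udaf.deserialize(dst, new DataInputStream(new ByteArrayInputStream(buf.toByteArray())));

        // Both states now hold the same entries; getValue reformats them as strings.
        System.out.println(udaf.getValue(dst)); // e.g. {1 114=0.01 514, 2 114=0.02 514}
    }
}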

View File

@ -0,0 +1,73 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.udf;
import org.apache.log4j.Logger;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
public class MySumReturnMapInt {
private static final Logger LOG = Logger.getLogger(MySumReturnMapInt.class);
public static class State {
public HashMap<Integer,Integer> counter = new HashMap<>();
}
public State create() {
return new State();
}
public void destroy(State state) {
}
public void add(State state, Integer val) {
if (val == null) {
return;
}
state.counter.put(val, 10 * val);
}
public void serialize(State state, DataOutputStream out) throws IOException {
int size = state.counter.size();
out.writeInt(size);
for (Map.Entry<Integer, Integer> it : state.counter.entrySet()) {
out.writeInt(it.getKey());
out.writeInt(it.getValue());
}
}
public void deserialize(State state, DataInputStream in) throws IOException {
int size = in.readInt();
for (int i = 0; i < size; ++i) {
Integer key = in.readInt();
Integer value = in.readInt();
state.counter.put(key, value);
}
}
public void merge(State state, State rhs) {
for (Map.Entry<Integer, Integer> it : rhs.counter.entrySet()) {
state.counter.put(it.getKey(), it.getValue());
}
}
public HashMap<Integer, Integer> getValue(State state) {
// Note: no sorting is applied here; HashMap iteration order determines the order of the regression output.
return state.counter;
}
}
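Tying MySumReturnMapInt back to the expected results above: add() stores each id as id -> 10 * id and getValue() returns the map unchanged, which is exactly the {1:10, ..., 5:50} shape in the !select_1 block. A small standalone sketch that reproduces the aggregated map locally (again assuming log4j is available for the class's Logger; not part of the shipped tests):

package org.apache.doris.udf;

public class MySumReturnMapIntDemo {
    public static void main(String[] args) {
        MySumReturnMapInt udaf = new MySumReturnMapInt();
        MySumReturnMapInt.State state = udaf.create();
        for (int id = 1; id <= 5; ++id) {
            udaf.add(state, id);
        }
        // Same entries as the aggregated regression result {1:10, 2:20, 3:30, 4:40, 5:50}.
        System.out.println(udaf.getValue(state));
    }
}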

View File

@ -0,0 +1,74 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.udf;
import org.apache.log4j.Logger;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
public class MySumReturnMapIntDou {
private static final Logger LOG = Logger.getLogger(MySumReturnMapIntDou.class);
public static class State {
public HashMap<Integer,Double> counter = new HashMap<>();
}
public State create() {
return new State();
}
public void destroy(State state) {
}
public void add(State state, Integer k, Double v) {
LOG.info("udaf nest k v " + k + " " + v);
state.counter.put(k, v);
}
public void serialize(State state, DataOutputStream out) throws IOException {
int size = state.counter.size();
out.writeInt(size);
for (Map.Entry<Integer, Double> it : state.counter.entrySet()) {
out.writeInt(it.getKey());
out.writeDouble(it.getValue());
}
}
public void deserialize(State state, DataInputStream in) throws IOException {
int size = in.readInt();
for (int i = 0; i < size; ++i) {
Integer key = in.readInt();
Double value = in.readDouble();
state.counter.put(key, value);
}
}
public void merge(State state, State rhs) {
for (Map.Entry<Integer, Double> it : rhs.counter.entrySet()) {
state.counter.put(it.getKey(), it.getValue());
}
}
public HashMap<Integer, Double> getValue(State state) {
// Note: no sorting is applied here; HashMap iteration order determines the order of the regression output.
return state.counter;
}
}

View File

@ -0,0 +1,104 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
import org.codehaus.groovy.runtime.IOGroovyMethods
import java.nio.charset.StandardCharsets
import java.nio.file.Files
import java.nio.file.Paths
suite("test_javaudaf_return_map") {
def jarPath = """${context.file.parent}/jars/java-udf-case-jar-with-dependencies.jar"""
log.info("Jar path: ${jarPath}".toString())
try {
try_sql("DROP FUNCTION IF EXISTS aggmap(int);")
try_sql("DROP FUNCTION IF EXISTS aggmap2(int,double);")
try_sql("DROP FUNCTION IF EXISTS aggmap3(int,double);")
try_sql("DROP TABLE IF EXISTS aggdb")
sql """
CREATE TABLE IF NOT EXISTS aggdb(
`id` INT NULL COMMENT "" ,
`d` Double NULL COMMENT ""
) ENGINE=OLAP
DUPLICATE KEY(`id`)
DISTRIBUTED BY HASH(`id`) BUCKETS 1
PROPERTIES (
"replication_allocation" = "tag.location.default: 1",
"storage_format" = "V2"
);
"""
sql """ INSERT INTO aggdb VALUES(1,0.01); """
sql """ INSERT INTO aggdb VALUES(2,0.02); """
sql """ INSERT INTO aggdb VALUES(3,0.03); """
sql """ INSERT INTO aggdb VALUES(4,0.04); """
sql """ INSERT INTO aggdb VALUES(5,0.05); """
sql """
CREATE AGGREGATE FUNCTION aggmap(int) RETURNS Map<int,int> PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.MySumReturnMapInt",
"type"="JAVA_UDF"
);
"""
sql """
CREATE AGGREGATE FUNCTION aggmap2(int,double) RETURNS Map<int,double> PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.MySumReturnMapIntDou",
"type"="JAVA_UDF"
);
"""
sql """
CREATE AGGREGATE FUNCTION aggmap3(int,double) RETURNS Map<String,String> PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.MyReturnMapString",
"type"="JAVA_UDF"
);
"""
qt_select_1 """ select aggmap(id) from aggdb; """
qt_select_2 """ select aggmap2(id,d) from aggdb; """
qt_select_3 """ select aggmap(id) from aggdb group by id; """
qt_select_4 """ select aggmap2(id,d) from aggdb group by id; """
qt_select_5 """ select aggmap3(id,d) from aggdb; """
qt_select_6 """ select aggmap3(id,d) from aggdb group by id; """
} finally {
try_sql("DROP FUNCTION IF EXISTS aggmap(int);")
try_sql("DROP FUNCTION IF EXISTS aggmap2(int,double);")
try_sql("DROP FUNCTION IF EXISTS aggmap3(int,double);")
try_sql("DROP TABLE IF EXISTS aggdb")
}
}