diff --git a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h index 6fe4742064..d51c219f3f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h +++ b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h @@ -31,6 +31,7 @@ #include "util/jni-util.h" #include "vec/aggregate_functions/aggregate_function.h" #include "vec/columns/column_array.h" +#include "vec/columns/column_map.h" #include "vec/columns/column_string.h" #include "vec/common/string_ref.h" #include "vec/core/field.h" @@ -56,12 +57,6 @@ public: AggregateJavaUdafData() = default; AggregateJavaUdafData(int64_t num_args) { argument_size = num_args; - input_values_buffer_ptr = std::make_unique(num_args); - input_nulls_buffer_ptr = std::make_unique(num_args); - input_offsets_ptrs = std::make_unique(num_args); - input_array_nulls_buffer_ptr = std::make_unique(num_args); - input_array_string_offsets_ptrs = std::make_unique(num_args); - input_place_ptrs = std::make_unique(0); output_value_buffer = std::make_unique(0); output_null_value = std::make_unique(0); output_offsets_ptr = std::make_unique(0); @@ -93,16 +88,8 @@ public: TJavaUdfExecutorCtorParams ctor_params; ctor_params.__set_fn(fn); ctor_params.__set_location(local_location); - ctor_params.__set_input_offsets_ptrs((int64_t)input_offsets_ptrs.get()); - ctor_params.__set_input_buffer_ptrs((int64_t)input_values_buffer_ptr.get()); - ctor_params.__set_input_nulls_ptrs((int64_t)input_nulls_buffer_ptr.get()); - ctor_params.__set_input_array_nulls_buffer_ptr( - (int64_t)input_array_nulls_buffer_ptr.get()); - ctor_params.__set_input_array_string_offsets_ptrs( - (int64_t)input_array_string_offsets_ptrs.get()); ctor_params.__set_output_buffer_ptr((int64_t)output_value_buffer.get()); - ctor_params.__set_input_places_ptr((int64_t)input_place_ptrs.get()); ctor_params.__set_output_null_ptr((int64_t)output_null_value.get()); ctor_params.__set_output_offsets_ptr((int64_t)output_offsets_ptr.get()); @@ -188,6 +175,57 @@ public: arg_column_nullable, row_num_start, row_num_end, nullmap_address, offset_address, nested_nullmap_address, nested_data_address, nested_offset_address); + } else if (data_col->is_column_map()) { + const ColumnMap* map_col = assert_cast(data_col); + auto offset_address = reinterpret_cast( + map_col->get_offsets_column().get_raw_data().data); + const ColumnNullable& map_key_column_nullable = + assert_cast(map_col->get_keys()); + auto key_data_column_null_map = map_key_column_nullable.get_null_map_column_ptr(); + auto key_data_column = map_key_column_nullable.get_nested_column_ptr(); + + auto key_nested_nullmap_address = reinterpret_cast( + check_and_get_column>(key_data_column_null_map) + ->get_data() + .data()); + int64_t key_nested_data_address = 0, key_nested_offset_address = 0; + if (key_data_column->is_column_string()) { + const ColumnString* col = + assert_cast(key_data_column.get()); + key_nested_data_address = reinterpret_cast(col->get_chars().data()); + key_nested_offset_address = + reinterpret_cast(col->get_offsets().data()); + } else { + key_nested_data_address = + reinterpret_cast(key_data_column->get_raw_data().data); + } + + const ColumnNullable& map_value_column_nullable = + assert_cast(map_col->get_values()); + auto value_data_column_null_map = + map_value_column_nullable.get_null_map_column_ptr(); + auto value_data_column = map_value_column_nullable.get_nested_column_ptr(); + auto value_nested_nullmap_address = reinterpret_cast( + check_and_get_column>(value_data_column_null_map) + ->get_data() + .data()); + int64_t value_nested_data_address = 0, value_nested_offset_address = 0; + if (value_data_column->is_column_string()) { + const ColumnString* col = + assert_cast(value_data_column.get()); + value_nested_data_address = reinterpret_cast(col->get_chars().data()); + value_nested_offset_address = + reinterpret_cast(col->get_offsets().data()); + } else { + value_nested_data_address = + reinterpret_cast(value_data_column->get_raw_data().data); + } + arr_obj = (jobjectArray)env->CallObjectMethod( + executor_obj, executor_convert_map_argument_id, arg_idx, + arg_column_nullable, row_num_start, row_num_end, nullmap_address, + offset_address, key_nested_nullmap_address, key_nested_data_address, + key_nested_offset_address, value_nested_nullmap_address, + value_nested_data_address, value_nested_offset_address); } else { return Status::InvalidArgument( strings::Substitute("Java UDAF doesn't support type is $0 now !", @@ -262,131 +300,133 @@ public: *output_null_value = reinterpret_cast(nullable.get_null_map_column().get_raw_data().data); auto& data_col = nullable.get_nested_column(); - -#ifndef EVALUATE_JAVA_UDAF -#define EVALUATE_JAVA_UDAF \ - if (data_col.is_column_string()) { \ - const ColumnString* str_col = check_and_get_column(data_col); \ - ColumnString::Chars& chars = const_cast(str_col->get_chars()); \ - ColumnString::Offsets& offsets = \ - const_cast(str_col->get_offsets()); \ - int increase_buffer_size = 0; \ - int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \ - chars.resize(buffer_size); \ - *output_value_buffer = reinterpret_cast(chars.data()); \ - *output_offsets_ptr = reinterpret_cast(offsets.data()); \ - *output_intermediate_state_ptr = chars.size(); \ - jboolean res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, \ - executor_result_id, to.size() - 1, place); \ - while (res != JNI_TRUE) { \ - RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); \ - increase_buffer_size++; \ - buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \ - try { \ - chars.resize(buffer_size); \ - } catch (std::bad_alloc const& e) { \ - throw doris::Exception( \ - ErrorCode::INTERNAL_ERROR, \ - "memory allocate failed in column string, buffer:{},size:{},reason:{}", \ - increase_buffer_size, buffer_size, e.what()); \ - } \ - *output_value_buffer = reinterpret_cast(chars.data()); \ - *output_intermediate_state_ptr = chars.size(); \ - res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id, \ - to.size() - 1, place); \ - } \ - } else if (data_col.is_numeric() || data_col.is_column_decimal()) { \ - *output_value_buffer = reinterpret_cast(data_col.get_raw_data().data); \ - env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id, \ - to.size() - 1, place); \ - } else if (data_col.is_column_array()) { \ - ColumnArray& array_col = assert_cast(data_col); \ - ColumnNullable& array_nested_nullable = \ - assert_cast(array_col.get_data()); \ - auto data_column_null_map = array_nested_nullable.get_null_map_column_ptr(); \ - auto data_column = array_nested_nullable.get_nested_column_ptr(); \ - auto& offset_column = array_col.get_offsets_column(); \ - int increase_buffer_size = 0; \ - int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \ - *output_offsets_ptr = reinterpret_cast(offset_column.get_raw_data().data); \ - data_column_null_map->resize(buffer_size); \ - auto& null_map_data = \ - assert_cast*>(data_column_null_map.get())->get_data(); \ - *output_array_null_ptr = reinterpret_cast(null_map_data.data()); \ - *output_intermediate_state_ptr = buffer_size; \ - if (data_column->is_column_string()) { \ - ColumnString* str_col = assert_cast(data_column.get()); \ - ColumnString::Chars& chars = assert_cast(str_col->get_chars()); \ - ColumnString::Offsets& offsets = \ - assert_cast(str_col->get_offsets()); \ - chars.resize(buffer_size); \ - offsets.resize(buffer_size); \ - *output_value_buffer = reinterpret_cast(chars.data()); \ - *output_array_string_offsets_ptr = reinterpret_cast(offsets.data()); \ - jboolean res = env->CallNonvirtualBooleanMethod( \ - executor_obj, executor_cl, executor_result_id, to.size() - 1, place); \ - while (res != JNI_TRUE) { \ - RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); \ - increase_buffer_size++; \ - buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \ - try { \ - null_map_data.resize(buffer_size); \ - chars.resize(buffer_size); \ - offsets.resize(buffer_size); \ - } catch (std::bad_alloc const& e) { \ - throw doris::Exception(ErrorCode::INTERNAL_ERROR, \ - "memory allocate failed in array column string, " \ - "buffer:{},size:{},reason:{}", \ - increase_buffer_size, buffer_size, e.what()); \ - } \ - *output_array_null_ptr = reinterpret_cast(null_map_data.data()); \ - *output_value_buffer = reinterpret_cast(chars.data()); \ - *output_array_string_offsets_ptr = reinterpret_cast(offsets.data()); \ - *output_intermediate_state_ptr = buffer_size; \ - res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, \ - executor_result_id, to.size() - 1, place); \ - } \ - } else { \ - data_column->resize(buffer_size); \ - *output_value_buffer = reinterpret_cast(data_column->get_raw_data().data); \ - jboolean res = env->CallNonvirtualBooleanMethod( \ - executor_obj, executor_cl, executor_result_id, to.size() - 1, place); \ - while (res != JNI_TRUE) { \ - RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); \ - increase_buffer_size++; \ - buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \ - try { \ - null_map_data.resize(buffer_size); \ - data_column->resize(buffer_size); \ - } catch (std::bad_alloc const& e) { \ - throw doris::Exception(ErrorCode::INTERNAL_ERROR, \ - "memory allocate failed in array number column, " \ - "buffer:{},size:{},reason:{}", \ - increase_buffer_size, buffer_size, e.what()); \ - } \ - *output_array_null_ptr = reinterpret_cast(null_map_data.data()); \ - *output_value_buffer = \ - reinterpret_cast(data_column->get_raw_data().data); \ - *output_intermediate_state_ptr = buffer_size; \ - res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, \ - executor_result_id, to.size() - 1, place); \ - } \ - } \ - } else { \ - return Status::InvalidArgument(strings::Substitute( \ - "Java UDAF doesn't support return type is $0 now !", result_type->get_name())); \ - } -#endif - EVALUATE_JAVA_UDAF; + RETURN_IF_ERROR(get_result(to, result_type, place, env, data_col)); } else { *output_null_value = -1; auto& data_col = to; - EVALUATE_JAVA_UDAF; + RETURN_IF_ERROR(get_result(to, result_type, place, env, data_col)); } return JniUtil::GetJniExceptionMsg(env); } private: + Status get_result(IColumn& to, const DataTypePtr& result_type, int64_t place, JNIEnv* env, + IColumn& data_col) const { + if (data_col.is_column_string()) { + const ColumnString* str_col = check_and_get_column(data_col); + ColumnString::Chars& chars = const_cast(str_col->get_chars()); + ColumnString::Offsets& offsets = + const_cast(str_col->get_offsets()); + int increase_buffer_size = 0; + int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); + chars.resize(buffer_size); + *output_value_buffer = reinterpret_cast(chars.data()); + *output_offsets_ptr = reinterpret_cast(offsets.data()); + *output_intermediate_state_ptr = chars.size(); + jboolean res = env->CallNonvirtualBooleanMethod( + executor_obj, executor_cl, executor_result_id, to.size() - 1, place); + while (res != JNI_TRUE) { + RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); + increase_buffer_size++; + buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); + try { + chars.resize(buffer_size); + } catch (std::bad_alloc const& e) { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "memory allocate failed in column string, " + "buffer:{},size:{},reason:{}", + increase_buffer_size, buffer_size, e.what()); + } + *output_value_buffer = reinterpret_cast(chars.data()); + *output_intermediate_state_ptr = chars.size(); + res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, + executor_result_id, to.size() - 1, place); + } + } else if (data_col.is_numeric() || data_col.is_column_decimal()) { + *output_value_buffer = reinterpret_cast(data_col.get_raw_data().data); + env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id, + to.size() - 1, place); + } else if (data_col.is_column_array()) { + ColumnArray& array_col = assert_cast(data_col); + ColumnNullable& array_nested_nullable = + assert_cast(array_col.get_data()); + auto data_column_null_map = array_nested_nullable.get_null_map_column_ptr(); + auto data_column = array_nested_nullable.get_nested_column_ptr(); + auto& offset_column = array_col.get_offsets_column(); + int increase_buffer_size = 0; + int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); + *output_offsets_ptr = reinterpret_cast(offset_column.get_raw_data().data); + data_column_null_map->resize(buffer_size); + auto& null_map_data = + assert_cast*>(data_column_null_map.get())->get_data(); + *output_array_null_ptr = reinterpret_cast(null_map_data.data()); + *output_intermediate_state_ptr = buffer_size; + if (data_column->is_column_string()) { + ColumnString* str_col = assert_cast(data_column.get()); + ColumnString::Chars& chars = + assert_cast(str_col->get_chars()); + ColumnString::Offsets& offsets = + assert_cast(str_col->get_offsets()); + chars.resize(buffer_size); + offsets.resize(buffer_size); + *output_value_buffer = reinterpret_cast(chars.data()); + *output_array_string_offsets_ptr = reinterpret_cast(offsets.data()); + jboolean res = env->CallNonvirtualBooleanMethod( + executor_obj, executor_cl, executor_result_id, to.size() - 1, place); + while (res != JNI_TRUE) { + RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); + increase_buffer_size++; + buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); + try { + null_map_data.resize(buffer_size); + chars.resize(buffer_size); + offsets.resize(buffer_size); + } catch (std::bad_alloc const& e) { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "memory allocate failed in array column string, " + "buffer:{},size:{},reason:{}", + increase_buffer_size, buffer_size, e.what()); + } + *output_array_null_ptr = reinterpret_cast(null_map_data.data()); + *output_value_buffer = reinterpret_cast(chars.data()); + *output_array_string_offsets_ptr = reinterpret_cast(offsets.data()); + *output_intermediate_state_ptr = buffer_size; + res = env->CallNonvirtualBooleanMethod( + executor_obj, executor_cl, executor_result_id, to.size() - 1, place); + } + } else { + data_column->resize(buffer_size); + *output_value_buffer = reinterpret_cast(data_column->get_raw_data().data); + jboolean res = env->CallNonvirtualBooleanMethod( + executor_obj, executor_cl, executor_result_id, to.size() - 1, place); + while (res != JNI_TRUE) { + RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); + increase_buffer_size++; + buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); + try { + null_map_data.resize(buffer_size); + data_column->resize(buffer_size); + } catch (std::bad_alloc const& e) { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "memory allocate failed in array number column, " + "buffer:{},size:{},reason:{}", + increase_buffer_size, buffer_size, e.what()); + } + *output_array_null_ptr = reinterpret_cast(null_map_data.data()); + *output_value_buffer = + reinterpret_cast(data_column->get_raw_data().data); + *output_intermediate_state_ptr = buffer_size; + res = env->CallNonvirtualBooleanMethod( + executor_obj, executor_cl, executor_result_id, to.size() - 1, place); + } + } + } else { + return Status::InvalidArgument(strings::Substitute( + "Java UDAF doesn't support return type is $0 now !", result_type->get_name())); + } + return Status::OK(); + } + Status register_func_id(JNIEnv* env) { auto register_id = [&](const char* func_name, const char* func_sign, jmethodID& func_id) { func_id = env->GetMethodID(executor_cl, func_name, func_sign); @@ -397,7 +437,6 @@ private: } return s; }; - RETURN_IF_ERROR(register_id("", UDAF_EXECUTOR_CTOR_SIGNATURE, executor_ctor_id)); RETURN_IF_ERROR(register_id("add", UDAF_EXECUTOR_ADD_SIGNATURE, executor_add_id)); RETURN_IF_ERROR(register_id("reset", UDAF_EXECUTOR_RESET_SIGNATURE, executor_reset_id)); @@ -413,6 +452,8 @@ private: executor_convert_basic_argument_id)); RETURN_IF_ERROR(register_id("convertArrayArguments", "(IZIIJJJJJ)[Ljava/lang/Object;", executor_convert_array_argument_id)); + RETURN_IF_ERROR(register_id("convertMapArguments", "(IZIIJJJJJJJJ)[Ljava/lang/Object;", + executor_convert_map_argument_id)); RETURN_IF_ERROR( register_id("addBatch", "(ZIIJI[Ljava/lang/Object;)V", executor_add_batch_id)); return Status::OK(); @@ -435,13 +476,7 @@ private: jmethodID executor_destroy_id; jmethodID executor_convert_basic_argument_id; jmethodID executor_convert_array_argument_id; - - std::unique_ptr input_values_buffer_ptr; - std::unique_ptr input_nulls_buffer_ptr; - std::unique_ptr input_offsets_ptrs; - std::unique_ptr input_array_nulls_buffer_ptr; - std::unique_ptr input_array_string_offsets_ptrs; - std::unique_ptr input_place_ptrs; + jmethodID executor_convert_map_argument_id; std::unique_ptr output_value_buffer; std::unique_ptr output_null_value; std::unique_ptr output_offsets_ptr; diff --git a/be/src/vec/functions/function_java_udf.cpp b/be/src/vec/functions/function_java_udf.cpp index 46bf887515..a2d4124551 100644 --- a/be/src/vec/functions/function_java_udf.cpp +++ b/be/src/vec/functions/function_java_udf.cpp @@ -101,23 +101,6 @@ Status JavaFunctionCall::open(FunctionContext* context, FunctionContext::Functio TJavaUdfExecutorCtorParams ctor_params; ctor_params.__set_fn(fn_); ctor_params.__set_location(local_location); - ctor_params.__set_input_offsets_ptrs((int64_t)jni_ctx->input_offsets_ptrs.get()); - ctor_params.__set_input_buffer_ptrs((int64_t)jni_ctx->input_values_buffer_ptr.get()); - ctor_params.__set_input_nulls_ptrs((int64_t)jni_ctx->input_nulls_buffer_ptr.get()); - ctor_params.__set_input_array_nulls_buffer_ptr( - (int64_t)jni_ctx->input_array_nulls_buffer_ptr.get()); - ctor_params.__set_input_array_string_offsets_ptrs( - (int64_t)jni_ctx->input_array_string_offsets_ptrs.get()); - ctor_params.__set_output_buffer_ptr((int64_t)jni_ctx->output_value_buffer.get()); - ctor_params.__set_output_null_ptr((int64_t)jni_ctx->output_null_value.get()); - ctor_params.__set_output_offsets_ptr((int64_t)jni_ctx->output_offsets_ptr.get()); - ctor_params.__set_output_array_null_ptr((int64_t)jni_ctx->output_array_null_ptr.get()); - ctor_params.__set_output_array_string_offsets_ptr( - (int64_t)jni_ctx->output_array_string_offsets_ptr.get()); - ctor_params.__set_output_intermediate_state_ptr( - (int64_t)jni_ctx->output_intermediate_state_ptr.get()); - ctor_params.__set_batch_size_ptr((int64_t)jni_ctx->batch_size_ptr.get()); - jbyteArray ctor_params_bytes; // Pushed frame will be popped when jni_frame goes out-of-scope. @@ -255,7 +238,6 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, ->get_data() .data()); int64_t value_nested_data_address = 0, value_nested_offset_address = 0; - // array type need pass address: [nullmap_address], offset_address, nested_nullmap_address, nested_data_address/nested_char_address,nested_offset_address if (value_data_column->is_column_string()) { const ColumnString* col = assert_cast(value_data_column.get()); value_nested_data_address = reinterpret_cast(col->get_chars().data()); diff --git a/be/src/vec/functions/function_java_udf.h b/be/src/vec/functions/function_java_udf.h index c0828a2a3f..ddbe300e89 100644 --- a/be/src/vec/functions/function_java_udf.h +++ b/be/src/vec/functions/function_java_udf.h @@ -115,39 +115,8 @@ private: jobject executor = nullptr; bool is_closed = false; - std::unique_ptr input_values_buffer_ptr; - std::unique_ptr input_nulls_buffer_ptr; - std::unique_ptr input_offsets_ptrs; - //used for array type nested column null map, because array nested column must be nullable - std::unique_ptr input_array_nulls_buffer_ptr; - //used for array type of nested string column offset, not the array column offset - std::unique_ptr input_array_string_offsets_ptrs; - std::unique_ptr output_value_buffer; - std::unique_ptr output_null_value; - std::unique_ptr output_offsets_ptr; - //used for array type nested column null map - std::unique_ptr output_array_null_ptr; - //used for array type of nested string column offset - std::unique_ptr output_array_string_offsets_ptr; - std::unique_ptr batch_size_ptr; - // intermediate_state includes two parts: reserved / used buffer size and rows - std::unique_ptr output_intermediate_state_ptr; - JniContext(int64_t num_args, jclass executor_cl, jmethodID executor_close_id) - : executor_cl_(executor_cl), - executor_close_id_(executor_close_id), - input_values_buffer_ptr(new int64_t[num_args]), - input_nulls_buffer_ptr(new int64_t[num_args]), - input_offsets_ptrs(new int64_t[num_args]), - input_array_nulls_buffer_ptr(new int64_t[num_args]), - input_array_string_offsets_ptrs(new int64_t[num_args]), - output_value_buffer(new int64_t()), - output_null_value(new int64_t()), - output_offsets_ptr(new int64_t()), - output_array_null_ptr(new int64_t()), - output_array_string_offsets_ptr(new int64_t()), - batch_size_ptr(new int32_t()), - output_intermediate_state_ptr(new IntermediateState()) {} + : executor_cl_(executor_cl), executor_close_id_(executor_close_id) {} void close() { if (is_closed) { diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java index df5026742d..eae5270872 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java @@ -25,6 +25,7 @@ import org.apache.doris.common.jni.utils.UdfUtils; import org.apache.doris.common.jni.utils.UdfUtils.JavaUdfDataType; import org.apache.doris.thrift.TJavaUdfExecutorCtorParams; +import com.esotericsoftware.reflectasm.MethodAccess; import com.google.common.base.Preconditions; import org.apache.log4j.Logger; import org.apache.thrift.TDeserializer; @@ -33,6 +34,7 @@ import org.apache.thrift.protocol.TBinaryProtocol; import java.io.IOException; import java.lang.reflect.Array; +import java.lang.reflect.Method; import java.math.BigDecimal; import java.math.BigInteger; import java.math.RoundingMode; @@ -42,6 +44,8 @@ import java.time.LocalDate; import java.time.LocalDateTime; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; public abstract class BaseExecutor { private static final Logger LOG = Logger.getLogger(BaseExecutor.class); @@ -88,12 +92,14 @@ public abstract class BaseExecutor { protected final long outputArrayStringOffsetsPtr; protected final long outputIntermediateStatePtr; protected Class[] argClass; + protected MethodAccess methodAccess; /** * Create a UdfExecutor, using parameters from a serialized thrift object. Used * by * the backend. */ + public BaseExecutor(byte[] thriftParams) throws Exception { TJavaUdfExecutorCtorParams request = new TJavaUdfExecutorCtorParams(); TDeserializer deserializer = new TDeserializer(PROTOCOL_FACTORY); @@ -1320,4 +1326,498 @@ public abstract class BaseExecutor { } return argument; } + + public Object[] buildHashMap(PrimitiveType keyType, PrimitiveType valueType, Object[] keyCol, Object[] valueCol) { + switch (keyType) { + case BOOLEAN: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case TINYINT: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case SMALLINT: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case INT: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case BIGINT: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case LARGEINT: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case FLOAT: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case DOUBLE: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case CHAR: + case VARCHAR: + case STRING: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case DATEV2: + case DATE: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case DATETIMEV2: + case DATETIME: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + case DECIMAL32: + case DECIMAL64: + case DECIMALV2: + case DECIMAL128: { + return new HashMapBuilder().get(keyCol, valueCol, valueType); + } + default: { + LOG.info("Not support: " + keyType); + Preconditions.checkState(false, "Not support type " + keyType.toString()); + break; + } + } + return null; + } + + public static class HashMapBuilder { + public Object[] get(Object[] keyCol, Object[] valueCol, PrimitiveType valueType) { + switch (valueType) { + case BOOLEAN: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case TINYINT: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case SMALLINT: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case INT: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case BIGINT: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case LARGEINT: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case FLOAT: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case DOUBLE: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case CHAR: + case VARCHAR: + case STRING: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case DATEV2: + case DATE: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case DATETIMEV2: + case DATETIME: { + return new BuildMapFromType().get(keyCol, valueCol); + } + case DECIMAL32: + case DECIMAL64: + case DECIMALV2: + case DECIMAL128: { + return new BuildMapFromType().get(keyCol, valueCol); + } + default: { + LOG.info("Not support: " + valueType); + Preconditions.checkState(false, "Not support type " + valueType.toString()); + break; + } + } + return null; + } + } + + public static class BuildMapFromType { + public Object[] get(Object[] keyCol, Object[] valueCol) { + Object[] retHashMap = new HashMap[keyCol.length]; + for (int colIdx = 0; colIdx < keyCol.length; colIdx++) { + HashMap hashMap = new HashMap<>(); + ArrayList keys = (ArrayList) (keyCol[colIdx]); + ArrayList values = (ArrayList) (valueCol[colIdx]); + for (int i = 0; i < keys.size(); i++) { + T1 key = keys.get(i); + T2 value = values.get(i); + if (!hashMap.containsKey(key)) { + hashMap.put(key, value); + } + } + retHashMap[colIdx] = hashMap; + } + return retHashMap; + } + } + + public void copyBatchBasicResultImpl(boolean isNullable, int numRows, Object[] result, long nullMapAddr, + long resColumnAddr, long strOffsetAddr, Method method) { + switch (retType) { + case BOOLEAN: { + UdfConvert.copyBatchBooleanResult(isNullable, numRows, (Boolean[]) result, nullMapAddr, resColumnAddr); + break; + } + case TINYINT: { + UdfConvert.copyBatchTinyIntResult(isNullable, numRows, (Byte[]) result, nullMapAddr, resColumnAddr); + break; + } + case SMALLINT: { + UdfConvert.copyBatchSmallIntResult(isNullable, numRows, (Short[]) result, nullMapAddr, resColumnAddr); + break; + } + case INT: { + UdfConvert.copyBatchIntResult(isNullable, numRows, (Integer[]) result, nullMapAddr, resColumnAddr); + break; + } + case BIGINT: { + UdfConvert.copyBatchBigIntResult(isNullable, numRows, (Long[]) result, nullMapAddr, resColumnAddr); + break; + } + case LARGEINT: { + UdfConvert.copyBatchLargeIntResult(isNullable, numRows, (BigInteger[]) result, nullMapAddr, + resColumnAddr); + break; + } + case FLOAT: { + UdfConvert.copyBatchFloatResult(isNullable, numRows, (Float[]) result, nullMapAddr, resColumnAddr); + break; + } + case DOUBLE: { + UdfConvert.copyBatchDoubleResult(isNullable, numRows, (Double[]) result, nullMapAddr, resColumnAddr); + break; + } + case CHAR: + case VARCHAR: + case STRING: { + UdfConvert.copyBatchStringResult(isNullable, numRows, (String[]) result, nullMapAddr, resColumnAddr, + strOffsetAddr); + break; + } + case DATE: { + UdfConvert.copyBatchDateResult(method.getReturnType(), isNullable, numRows, result, + nullMapAddr, resColumnAddr); + break; + } + case DATETIME: { + UdfConvert + .copyBatchDateTimeResult(method.getReturnType(), isNullable, numRows, result, + nullMapAddr, + resColumnAddr); + break; + } + case DATEV2: { + UdfConvert.copyBatchDateV2Result(method.getReturnType(), isNullable, numRows, result, + nullMapAddr, + resColumnAddr); + break; + } + case DATETIMEV2: { + UdfConvert.copyBatchDateTimeV2Result(method.getReturnType(), isNullable, numRows, + result, nullMapAddr, + resColumnAddr); + break; + } + case DECIMALV2: + case DECIMAL128: { + UdfConvert.copyBatchDecimal128Result(retType.getScale(), isNullable, numRows, (BigDecimal[]) result, + nullMapAddr, + resColumnAddr); + break; + } + case DECIMAL32: { + UdfConvert.copyBatchDecimal32Result(retType.getScale(), isNullable, numRows, (BigDecimal[]) result, + nullMapAddr, + resColumnAddr); + break; + } + case DECIMAL64: { + UdfConvert.copyBatchDecimal64Result(retType.getScale(), isNullable, numRows, (BigDecimal[]) result, + nullMapAddr, + resColumnAddr); + break; + } + default: { + LOG.info("Not support return type: " + retType); + Preconditions.checkState(false, "Not support type: " + retType.toString()); + break; + } + } + } + + public void copyBatchArrayResultImpl(boolean isNullable, int numRows, Object[] result, long nullMapAddr, + long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr, + PrimitiveType type) { + long hasPutElementNum = 0; + for (int row = 0; row < numRows; ++row) { + switch (type) { + case BOOLEAN: { + hasPutElementNum = UdfConvert + .copyBatchArrayBooleanResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case TINYINT: { + hasPutElementNum = UdfConvert + .copyBatchArrayTinyIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case SMALLINT: { + hasPutElementNum = UdfConvert + .copyBatchArraySmallIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case INT: { + hasPutElementNum = UdfConvert + .copyBatchArrayIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case BIGINT: { + hasPutElementNum = UdfConvert + .copyBatchArrayBigIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case LARGEINT: { + hasPutElementNum = UdfConvert + .copyBatchArrayLargeIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case FLOAT: { + hasPutElementNum = UdfConvert + .copyBatchArrayFloatResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DOUBLE: { + hasPutElementNum = UdfConvert + .copyBatchArrayDoubleResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case CHAR: + case VARCHAR: + case STRING: { + hasPutElementNum = UdfConvert + .copyBatchArrayStringResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr, strOffsetAddr); + break; + } + case DATE: { + hasPutElementNum = UdfConvert + .copyBatchArrayDateResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DATETIME: { + hasPutElementNum = UdfConvert + .copyBatchArrayDateTimeResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DATEV2: { + hasPutElementNum = UdfConvert + .copyBatchArrayDateV2Result(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DATETIMEV2: { + hasPutElementNum = UdfConvert + .copyBatchArrayDateTimeV2Result(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DECIMALV2: { + hasPutElementNum = UdfConvert + .copyBatchArrayDecimalResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DECIMAL32: { + hasPutElementNum = UdfConvert + .copyBatchArrayDecimalV3Result(retType.getScale(), 4L, hasPutElementNum, isNullable, row, + result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DECIMAL64: { + hasPutElementNum = UdfConvert + .copyBatchArrayDecimalV3Result(retType.getScale(), 8L, hasPutElementNum, isNullable, row, + result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DECIMAL128: { + hasPutElementNum = UdfConvert + .copyBatchArrayDecimalV3Result(retType.getScale(), 16L, hasPutElementNum, isNullable, row, + result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + default: { + Preconditions.checkState(false, "Not support type in array: " + retType); + break; + } + } + } + } + + public void buildArrayListFromHashMap(Object[] result, PrimitiveType keyType, PrimitiveType valueType, + Object[] keyCol, Object[] valueCol) { + switch (keyType) { + case BOOLEAN: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case TINYINT: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case SMALLINT: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case INT: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case BIGINT: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case LARGEINT: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case FLOAT: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case DOUBLE: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case CHAR: + case VARCHAR: + case STRING: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case DATEV2: + case DATE: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case DATETIMEV2: + case DATETIME: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + case DECIMAL32: + case DECIMAL64: + case DECIMALV2: + case DECIMAL128: { + new ArrayListBuilder().get(result, keyCol, valueCol, valueType); + break; + } + default: { + LOG.info("Not support: " + keyType); + Preconditions.checkState(false, "Not support type " + keyType.toString()); + break; + } + } + } + + public static class ArrayListBuilder { + public void get(Object[] map, Object[] keyCol, Object[] valueCol, PrimitiveType valueType) { + switch (valueType) { + case BOOLEAN: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case TINYINT: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case SMALLINT: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case INT: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case BIGINT: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case LARGEINT: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case FLOAT: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case DOUBLE: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case CHAR: + case VARCHAR: + case STRING: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case DATEV2: + case DATE: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case DATETIMEV2: + case DATETIME: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + case DECIMAL32: + case DECIMAL64: + case DECIMALV2: + case DECIMAL128: { + new BuildArrayFromType().get(map, keyCol, valueCol); + break; + } + default: { + LOG.info("Not support: " + valueType); + Preconditions.checkState(false, "Not support type " + valueType.toString()); + break; + } + } + } + } + + public static class BuildArrayFromType { + public void get(Object[] map, Object[] keyCol, Object[] valueCol) { + for (int colIdx = 0; colIdx < map.length; colIdx++) { + HashMap hashMap = (HashMap) map[colIdx]; + ArrayList keys = new ArrayList<>(); + ArrayList values = new ArrayList<>(); + for (Map.Entry entry : hashMap.entrySet()) { + keys.add(entry.getKey()); + values.add(entry.getValue()); + } + keyCol[colIdx] = keys; + valueCol[colIdx] = values; + } + } + } } diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java index a0736b5a72..dff689ed40 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java @@ -17,6 +17,7 @@ package org.apache.doris.udf; +import org.apache.doris.catalog.PrimitiveType; import org.apache.doris.catalog.Type; import org.apache.doris.common.Pair; import org.apache.doris.common.exception.UdfRuntimeException; @@ -52,7 +53,6 @@ public class UdafExecutor extends BaseExecutor { private HashMap stateObjMap; private Class retClass; private int addIndex; - private MethodAccess methodAccess; /** * Constructor to create an object. @@ -81,6 +81,21 @@ public class UdafExecutor extends BaseExecutor { dataAddr, strOffsetAddr); } + public Object[] convertMapArguments(int argIdx, boolean isNullable, int rowStart, int rowEnd, long nullMapAddr, + long offsetsAddr, long keyNestedNullMapAddr, long keyDataAddr, long keyStrOffsetAddr, + long valueNestedNullMapAddr, long valueDataAddr, long valueStrOffsetAddr) { + PrimitiveType keyType = argTypes[argIdx].getKeyType().getPrimitiveType(); + PrimitiveType valueType = argTypes[argIdx].getValueType().getPrimitiveType(); + Object[] keyCol = convertMapArg(keyType, argIdx, isNullable, rowStart, rowEnd, nullMapAddr, offsetsAddr, + keyNestedNullMapAddr, keyDataAddr, + keyStrOffsetAddr); + Object[] valueCol = convertMapArg(valueType, argIdx, isNullable, rowStart, rowEnd, nullMapAddr, offsetsAddr, + valueNestedNullMapAddr, + valueDataAddr, + valueStrOffsetAddr); + return buildHashMap(keyType, valueType, keyCol, valueCol); + } + public void addBatch(boolean isSinglePlace, int rowStart, int rowEnd, long placeAddr, int offset, Object[] column) throws UdfRuntimeException { if (isSinglePlace) { @@ -111,7 +126,7 @@ public class UdafExecutor extends BaseExecutor { methodAccess.invoke(udf, addIndex, inputArgs); } } catch (Exception e) { - LOG.warn("invoke add function meet some error: " + e.getCause().toString()); + LOG.info("invoke add function meet some error: " + e.getCause().toString()); throw new UdfRuntimeException("UDAF failed to addBatchSingle: ", e); } } @@ -143,7 +158,7 @@ public class UdafExecutor extends BaseExecutor { methodAccess.invoke(udf, addIndex, inputArgs); } } catch (Exception e) { - LOG.warn("invoke add function meet some error: " + Arrays.toString(e.getStackTrace())); + LOG.info("invoke add function meet some error: " + Arrays.toString(e.getStackTrace())); throw new UdfRuntimeException("UDAF failed to addBatchPlaces: ", e); } } diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java index 1140d1824b..2f6ca99fdd 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java @@ -34,14 +34,8 @@ import org.apache.log4j.Logger; import java.lang.reflect.Array; import java.lang.reflect.Constructor; import java.lang.reflect.Method; -import java.math.BigDecimal; -import java.math.BigInteger; import java.net.MalformedURLException; -import java.time.LocalDate; -import java.time.LocalDateTime; import java.util.ArrayList; -import java.util.HashMap; -import java.util.Map; public class UdfExecutor extends BaseExecutor { // private static final java.util.logging.Logger LOG = @@ -60,7 +54,6 @@ public class UdfExecutor extends BaseExecutor { private long batchSizePtr; private int evaluateIndex; - private MethodAccess methodAccess; /** * Create a UdfExecutor, using parameters from a serialized thrift object. Used by @@ -147,57 +140,7 @@ public class UdfExecutor extends BaseExecutor { Object[] valueCol = convertMapArg(valueType, argIdx, isNullable, 0, numRows, nullMapAddr, offsetsAddr, valueNestedNullMapAddr, valueDataAddr, valueStrOffsetAddr); - switch (keyType) { - case BOOLEAN: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case TINYINT: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case SMALLINT: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case INT: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case BIGINT: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case LARGEINT: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case FLOAT: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case DOUBLE: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case CHAR: - case VARCHAR: - case STRING: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case DATEV2: - case DATE: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case DATETIMEV2: - case DATETIME: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - case DECIMAL32: - case DECIMAL64: - case DECIMALV2: - case DECIMAL128: { - return new HashMapBuilder().get(keyCol, valueCol, valueType); - } - default: { - LOG.info("Not support: " + keyType); - Preconditions.checkState(false, "Not support type " + keyType.toString()); - break; - } - } - return null; + return buildHashMap(keyType, valueType, keyCol, valueCol); } /** @@ -223,217 +166,7 @@ public class UdfExecutor extends BaseExecutor { public void copyBatchBasicResult(boolean isNullable, int numRows, Object[] result, long nullMapAddr, long resColumnAddr, long strOffsetAddr) { - switch (retType) { - case BOOLEAN: { - UdfConvert.copyBatchBooleanResult(isNullable, numRows, (Boolean[]) result, nullMapAddr, resColumnAddr); - break; - } - case TINYINT: { - UdfConvert.copyBatchTinyIntResult(isNullable, numRows, (Byte[]) result, nullMapAddr, resColumnAddr); - break; - } - case SMALLINT: { - UdfConvert.copyBatchSmallIntResult(isNullable, numRows, (Short[]) result, nullMapAddr, resColumnAddr); - break; - } - case INT: { - UdfConvert.copyBatchIntResult(isNullable, numRows, (Integer[]) result, nullMapAddr, resColumnAddr); - break; - } - case BIGINT: { - UdfConvert.copyBatchBigIntResult(isNullable, numRows, (Long[]) result, nullMapAddr, resColumnAddr); - break; - } - case LARGEINT: { - UdfConvert.copyBatchLargeIntResult(isNullable, numRows, (BigInteger[]) result, nullMapAddr, - resColumnAddr); - break; - } - case FLOAT: { - UdfConvert.copyBatchFloatResult(isNullable, numRows, (Float[]) result, nullMapAddr, resColumnAddr); - break; - } - case DOUBLE: { - UdfConvert.copyBatchDoubleResult(isNullable, numRows, (Double[]) result, nullMapAddr, resColumnAddr); - break; - } - case CHAR: - case VARCHAR: - case STRING: { - UdfConvert.copyBatchStringResult(isNullable, numRows, (String[]) result, nullMapAddr, resColumnAddr, - strOffsetAddr); - break; - } - case DATE: { - UdfConvert.copyBatchDateResult(method.getReturnType(), isNullable, numRows, result, - nullMapAddr, resColumnAddr); - break; - } - case DATETIME: { - UdfConvert - .copyBatchDateTimeResult(method.getReturnType(), isNullable, numRows, result, - nullMapAddr, - resColumnAddr); - break; - } - case DATEV2: { - UdfConvert.copyBatchDateV2Result(method.getReturnType(), isNullable, numRows, result, - nullMapAddr, - resColumnAddr); - break; - } - case DATETIMEV2: { - UdfConvert.copyBatchDateTimeV2Result(method.getReturnType(), isNullable, numRows, - result, nullMapAddr, - resColumnAddr); - break; - } - case DECIMALV2: - case DECIMAL128: { - UdfConvert.copyBatchDecimal128Result(retType.getScale(), isNullable, numRows, (BigDecimal[]) result, - nullMapAddr, - resColumnAddr); - break; - } - case DECIMAL32: { - UdfConvert.copyBatchDecimal32Result(retType.getScale(), isNullable, numRows, (BigDecimal[]) result, - nullMapAddr, - resColumnAddr); - break; - } - case DECIMAL64: { - UdfConvert.copyBatchDecimal64Result(retType.getScale(), isNullable, numRows, (BigDecimal[]) result, - nullMapAddr, - resColumnAddr); - break; - } - default: { - LOG.info("Not support return type: " + retType); - Preconditions.checkState(false, "Not support type: " + retType.toString()); - break; - } - } - } - - public void copyBatchArrayResultImpl(boolean isNullable, int numRows, Object[] result, long nullMapAddr, - long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr, - PrimitiveType type) { - long hasPutElementNum = 0; - for (int row = 0; row < numRows; ++row) { - switch (type) { - case BOOLEAN: { - hasPutElementNum = UdfConvert - .copyBatchArrayBooleanResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case TINYINT: { - hasPutElementNum = UdfConvert - .copyBatchArrayTinyIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case SMALLINT: { - hasPutElementNum = UdfConvert - .copyBatchArraySmallIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case INT: { - hasPutElementNum = UdfConvert - .copyBatchArrayIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case BIGINT: { - hasPutElementNum = UdfConvert - .copyBatchArrayBigIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case LARGEINT: { - hasPutElementNum = UdfConvert - .copyBatchArrayLargeIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case FLOAT: { - hasPutElementNum = UdfConvert - .copyBatchArrayFloatResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DOUBLE: { - hasPutElementNum = UdfConvert - .copyBatchArrayDoubleResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case CHAR: - case VARCHAR: - case STRING: { - hasPutElementNum = UdfConvert - .copyBatchArrayStringResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr, strOffsetAddr); - break; - } - case DATE: { - hasPutElementNum = UdfConvert - .copyBatchArrayDateResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DATETIME: { - hasPutElementNum = UdfConvert - .copyBatchArrayDateTimeResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DATEV2: { - hasPutElementNum = UdfConvert - .copyBatchArrayDateV2Result(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DATETIMEV2: { - hasPutElementNum = UdfConvert - .copyBatchArrayDateTimeV2Result(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DECIMALV2: { - hasPutElementNum = UdfConvert - .copyBatchArrayDecimalResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DECIMAL32: { - hasPutElementNum = UdfConvert - .copyBatchArrayDecimalV3Result(retType.getScale(), 4L, hasPutElementNum, isNullable, row, - result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DECIMAL64: { - hasPutElementNum = UdfConvert - .copyBatchArrayDecimalV3Result(retType.getScale(), 8L, hasPutElementNum, isNullable, row, - result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DECIMAL128: { - hasPutElementNum = UdfConvert - .copyBatchArrayDecimalV3Result(retType.getScale(), 16L, hasPutElementNum, isNullable, row, - result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - default: { - Preconditions.checkState(false, "Not support type in array: " + retType); - break; - } - } - } + copyBatchBasicResultImpl(isNullable, numRows, result, nullMapAddr, resColumnAddr, strOffsetAddr, getMethod()); } public void copyBatchArrayResult(boolean isNullable, int numRows, Object[] result, long nullMapAddr, @@ -453,68 +186,7 @@ public class UdfExecutor extends BaseExecutor { PrimitiveType valueType = retType.getValueType().getPrimitiveType(); Object[] keyCol = new Object[result.length]; Object[] valueCol = new Object[result.length]; - switch (keyType) { - case BOOLEAN: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case TINYINT: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case SMALLINT: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case INT: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case BIGINT: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case LARGEINT: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case FLOAT: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case DOUBLE: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case CHAR: - case VARCHAR: - case STRING: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case DATEV2: - case DATE: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case DATETIMEV2: - case DATETIME: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - case DECIMAL32: - case DECIMAL64: - case DECIMALV2: - case DECIMAL128: { - new ArrayListBuilder().get(result, keyCol, valueCol, valueType); - break; - } - default: { - LOG.info("Not support: " + keyType); - Preconditions.checkState(false, "Not support type " + keyType.toString()); - break; - } - } + buildArrayListFromHashMap(result, keyType, valueType, keyCol, valueCol); copyBatchArrayResultImpl(isNullable, numRows, valueCol, nullMapAddr, offsetsAddr, valueNsestedNullMapAddr, valueDataAddr, @@ -522,7 +194,6 @@ public class UdfExecutor extends BaseExecutor { copyBatchArrayResultImpl(isNullable, numRows, keyCol, nullMapAddr, offsetsAddr, keyNsestedNullMapAddr, keyDataAddr, keyStrOffsetAddr, keyType); - } /** @@ -669,163 +340,4 @@ public class UdfExecutor extends BaseExecutor { } } - public static class HashMapBuilder { - public Object[] get(Object[] keyCol, Object[] valueCol, PrimitiveType valueType) { - switch (valueType) { - case BOOLEAN: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case TINYINT: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case SMALLINT: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case INT: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case BIGINT: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case LARGEINT: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case FLOAT: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case DOUBLE: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case CHAR: - case VARCHAR: - case STRING: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case DATEV2: - case DATE: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case DATETIMEV2: - case DATETIME: { - return new BuildMapFromType().get(keyCol, valueCol); - } - case DECIMAL32: - case DECIMAL64: - case DECIMALV2: - case DECIMAL128: { - return new BuildMapFromType().get(keyCol, valueCol); - } - default: { - LOG.info("Not support: " + valueType); - Preconditions.checkState(false, "Not support type " + valueType.toString()); - break; - } - } - return null; - } - } - - public static class BuildMapFromType { - public Object[] get(Object[] keyCol, Object[] valueCol) { - Object[] retHashMap = new HashMap[keyCol.length]; - for (int colIdx = 0; colIdx < keyCol.length; colIdx++) { - HashMap hashMap = new HashMap<>(); - ArrayList keys = (ArrayList) (keyCol[colIdx]); - ArrayList values = (ArrayList) (valueCol[colIdx]); - for (int i = 0; i < keys.size(); i++) { - T1 key = keys.get(i); - T2 value = values.get(i); - if (!hashMap.containsKey(key)) { - hashMap.put(key, value); - } - } - retHashMap[colIdx] = hashMap; - } - return retHashMap; - } - } - - public static class ArrayListBuilder { - public void get(Object[] map, Object[] keyCol, Object[] valueCol, PrimitiveType valueType) { - switch (valueType) { - case BOOLEAN: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case TINYINT: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case SMALLINT: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case INT: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case BIGINT: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case LARGEINT: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case FLOAT: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case DOUBLE: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case CHAR: - case VARCHAR: - case STRING: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case DATEV2: - case DATE: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case DATETIMEV2: - case DATETIME: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - case DECIMAL32: - case DECIMAL64: - case DECIMALV2: - case DECIMAL128: { - new BuildArrayFromType().get(map, keyCol, valueCol); - break; - } - default: { - LOG.info("Not support: " + valueType); - Preconditions.checkState(false, "Not support type " + valueType.toString()); - break; - } - } - } - } - - public static class BuildArrayFromType { - public void get(Object[] map, Object[] keyCol, Object[] valueCol) { - for (int colIdx = 0; colIdx < map.length; colIdx++) { - HashMap hashMap = (HashMap) map[colIdx]; - ArrayList keys = new ArrayList<>(); - ArrayList values = new ArrayList<>(); - for (Map.Entry entry : hashMap.entrySet()) { - keys.add(entry.getKey()); - values.add(entry.getValue()); - } - keyCol[colIdx] = keys; - valueCol[colIdx] = values; - } - } - } - } diff --git a/regression-test/data/javaudf_p0/test_javaudf_agg_map.out b/regression-test/data/javaudf_p0/test_javaudf_agg_map.out new file mode 100644 index 0000000000..b4093b461c --- /dev/null +++ b/regression-test/data/javaudf_p0/test_javaudf_agg_map.out @@ -0,0 +1,7 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !select_1 -- +616.0 + +-- !select_2 -- +342 + diff --git a/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumMapInt.java b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumMapInt.java new file mode 100644 index 0000000000..6310355ccb --- /dev/null +++ b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumMapInt.java @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +package org.apache.doris.udf; +import org.apache.log4j.Logger; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +public class MySumMapInt { + private static final Logger LOG = Logger.getLogger(MySumMapInt.class); + public static class State { + public long counter = 0; + } + + public State create() { + return new State(); + } + + public void destroy(State state) { + } + + public void add(State state, HashMap val) { + if (val == null) { + return; + } + for(Map.Entry it : val.entrySet()){ + Integer key = it.getKey(); + Integer value = it.getValue(); + state.counter += key + value; + } + } + + public void serialize(State state, DataOutputStream out) throws IOException { + out.writeLong(state.counter); + } + + public void deserialize(State state, DataInputStream in) throws IOException { + state.counter = in.readLong(); + } + + public void merge(State state, State rhs) { + state.counter += rhs.counter; + } + + public long getValue(State state) { + return state.counter; + } +} \ No newline at end of file diff --git a/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumMapIntDou.java b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumMapIntDou.java new file mode 100644 index 0000000000..7690bba603 --- /dev/null +++ b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumMapIntDou.java @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +package org.apache.doris.udf; +import org.apache.log4j.Logger; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +public class MySumMapIntDou { + private static final Logger LOG = Logger.getLogger(MySumMapIntDou.class); + public static class State { + public Double counter = 0.0; + } + + public State create() { + return new State(); + } + + public void destroy(State state) { + } + + public void add(State state, HashMap val) { + if (val == null) { + return; + } + for(Map.Entry it : val.entrySet()){ + Integer key = it.getKey(); + Double value = it.getValue(); + state.counter += key * value; + } + } + + public void serialize(State state, DataOutputStream out) throws IOException { + out.writeDouble(state.counter); + } + + public void deserialize(State state, DataInputStream in) throws IOException { + state.counter = in.readDouble(); + } + + public void merge(State state, State rhs) { + state.counter += rhs.counter; + } + + public Double getValue(State state) { + return state.counter; + } +} \ No newline at end of file diff --git a/regression-test/suites/javaudf_p0/test_javaudf_agg_map.groovy b/regression-test/suites/javaudf_p0/test_javaudf_agg_map.groovy new file mode 100644 index 0000000000..facd8fe1f9 --- /dev/null +++ b/regression-test/suites/javaudf_p0/test_javaudf_agg_map.groovy @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import org.codehaus.groovy.runtime.IOGroovyMethods + +import java.nio.charset.StandardCharsets +import java.nio.file.Files +import java.nio.file.Paths + +suite("test_javaudf_agg_map") { + def jarPath = """${context.file.parent}/jars/java-udf-case-jar-with-dependencies.jar""" + log.info("Jar path: ${jarPath}".toString()) + try { + try_sql("DROP FUNCTION IF EXISTS mapii(Map);") + try_sql("DROP FUNCTION IF EXISTS mapid(Map);") + try_sql("DROP TABLE IF EXISTS db") + sql """ + CREATE TABLE IF NOT EXISTS db( + `id` INT NULL COMMENT "", + `i` INT NULL COMMENT "", + `d` Double NULL COMMENT "", + `mii` Map NULL COMMENT "", + `mid` Map NULL COMMENT "" + ) ENGINE=OLAP + DUPLICATE KEY(`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 1 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1", + "storage_format" = "V2"); + """ + sql """ INSERT INTO db VALUES(1, 10,1.1,{1:1,10:1,100:1},{1:1.1,11:11.1}); """ + sql """ INSERT INTO db VALUES(2, 20,2.2,{2:2,20:2,200:2},{2:2.2,22:22.2}); """ + + sql """ + + CREATE AGGREGATE FUNCTION mapii(Map) RETURNS BigInt PROPERTIES ( + "file"="file://${jarPath}", + "symbol"="org.apache.doris.udf.MySumMapInt", + "type"="JAVA_UDF" + ); + + """ + + sql """ + + CREATE AGGREGATE FUNCTION mapid(Map) RETURNS Double PROPERTIES ( + "file"="file://${jarPath}", + "symbol"="org.apache.doris.udf.MySumMapIntDou", + "type"="JAVA_UDF" + ); + + """ + + + qt_select_1 """ select mapid(mid) from db; """ + + qt_select_2 """ select mapii(mii) from db; """ + + } finally { + try_sql("DROP FUNCTION IF EXISTS mapii(Map);") + try_sql("DROP FUNCTION IF EXISTS mapid(Map);") + try_sql("DROP TABLE IF EXISTS db") + } +} diff --git a/regression-test/suites/javaudf_p0/test_javaudf_ret_map.groovy b/regression-test/suites/javaudf_p0/test_javaudf_ret_map.groovy index df8baa37f9..8421a6699a 100644 --- a/regression-test/suites/javaudf_p0/test_javaudf_ret_map.groovy +++ b/regression-test/suites/javaudf_p0/test_javaudf_ret_map.groovy @@ -25,6 +25,12 @@ suite("test_javaudf_ret_map") { def jarPath = """${context.file.parent}/jars/java-udf-case-jar-with-dependencies.jar""" log.info("Jar path: ${jarPath}".toString()) try { + try_sql("DROP FUNCTION IF EXISTS retii(map);") + try_sql("DROP FUNCTION IF EXISTS retss(map);") + try_sql("DROP FUNCTION IF EXISTS retid(map);") + try_sql("DROP FUNCTION IF EXISTS retidss(int ,double);") + try_sql("DROP TABLE IF EXISTS db") + try_sql("DROP TABLE IF EXISTS dbss") sql """ CREATE TABLE IF NOT EXISTS db( `id` INT NULL COMMENT "",