diff --git a/be/src/vec/functions/function_java_udf.cpp b/be/src/vec/functions/function_java_udf.cpp index 5a47a5c0c2..cddf95661b 100644 --- a/be/src/vec/functions/function_java_udf.cpp +++ b/be/src/vec/functions/function_java_udf.cpp @@ -77,12 +77,15 @@ Status JavaFunctionCall::prepare(FunctionContext* context, FunctionContext::Func TJavaUdfExecutorCtorParams ctor_params; ctor_params.__set_fn(fn_); ctor_params.__set_location(local_location); - ctor_params.__set_input_byte_offsets(jni_ctx->input_byte_offsets_ptr); - ctor_params.__set_input_buffer_ptrs(jni_ctx->input_values_buffer_ptr); - ctor_params.__set_input_nulls_ptrs(jni_ctx->input_nulls_buffer_ptr); - ctor_params.__set_output_buffer_ptr(jni_ctx->output_value_buffer); - ctor_params.__set_output_null_ptr(jni_ctx->output_null_value); - ctor_params.__set_batch_size_ptr(jni_ctx->batch_size_ptr); + ctor_params.__set_input_offsets_ptrs((int64_t) jni_ctx->input_offsets_ptrs.get()); + ctor_params.__set_input_buffer_ptrs((int64_t) jni_ctx->input_values_buffer_ptr.get()); + ctor_params.__set_input_nulls_ptrs((int64_t) jni_ctx->input_nulls_buffer_ptr.get()); + ctor_params.__set_output_buffer_ptr((int64_t) jni_ctx->output_value_buffer.get()); + ctor_params.__set_output_null_ptr((int64_t) jni_ctx->output_null_value.get()); + ctor_params.__set_output_offsets_ptr((int64_t) jni_ctx->output_offsets_ptr.get()); + ctor_params.__set_output_intermediate_state_ptr( + (int64_t) jni_ctx->output_intermediate_state_ptr.get()); + ctor_params.__set_batch_size_ptr((int64_t) jni_ctx->batch_size_ptr.get()); jbyteArray ctor_params_bytes; @@ -100,11 +103,6 @@ Status JavaFunctionCall::prepare(FunctionContext* context, FunctionContext::Func Status JavaFunctionCall::execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t num_rows, bool dry_run) { - auto return_type = block.get_data_type(result); - if (!return_type->have_maximum_size_of_value()) { - return Status::InvalidArgument(strings::Substitute( - "Java UDF doesn't support return type $0 now !", return_type->get_name())); - } JNIEnv* env; RETURN_IF_ERROR(JniUtil::GetJNIEnv(&env)); JniContext* jni_ctx = reinterpret_cast( @@ -119,50 +117,94 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, const C arg_idx, column.type->get_name(), _argument_types[arg_idx]->get_name())); } - if (!column.type->have_maximum_size_of_value()) { - return Status::InvalidArgument(strings::Substitute( - "Java UDF doesn't support input type $0 now !", return_type->get_name())); - } auto data_col = col; if (auto* nullable = check_and_get_column(*col)) { data_col = nullable->get_nested_column_ptr(); auto null_col = check_and_get_column>(nullable->get_null_map_column_ptr()); - ((int64_t*) jni_ctx->input_nulls_buffer_ptr)[arg_idx] = + jni_ctx->input_nulls_buffer_ptr.get()[arg_idx] = reinterpret_cast(null_col->get_data().data()); + } else { + jni_ctx->input_nulls_buffer_ptr.get()[arg_idx] = -1; + } + if (const ColumnString* str_col = check_and_get_column(data_col.get())) { + jni_ctx->input_values_buffer_ptr.get()[arg_idx] = + reinterpret_cast(str_col->get_chars().data()); + jni_ctx->input_offsets_ptrs.get()[arg_idx] = + reinterpret_cast(str_col->get_offsets().data()); + } else if (data_col->is_numeric()) { + jni_ctx->input_values_buffer_ptr.get()[arg_idx] = + reinterpret_cast(data_col->get_raw_data().data); + } else { + return Status::InvalidArgument(strings::Substitute( + "Java UDF doesn't support type $0 now !", _argument_types[arg_idx]->get_name())); } - ((int64_t*) jni_ctx->input_values_buffer_ptr)[arg_idx] = - reinterpret_cast(data_col->get_raw_data().data); arg_idx++; } - + *(jni_ctx->batch_size_ptr) = num_rows; + auto return_type = block.get_data_type(result); if (return_type->is_nullable()) { auto null_type = std::reinterpret_pointer_cast(return_type); auto data_col = null_type->get_nested_type()->create_column(); auto null_col = ColumnUInt8::create(data_col->size(), 0); null_col->reserve(num_rows); null_col->resize(num_rows); - data_col->reserve(num_rows); - data_col->resize(num_rows); - *((int64_t*) jni_ctx->output_null_value) = + *(jni_ctx->output_null_value) = reinterpret_cast(null_col->get_data().data()); - *((int64_t*) jni_ctx->output_value_buffer) = reinterpret_cast(data_col->get_raw_data().data); +#ifndef EVALUATE_JAVA_UDF +#define EVALUATE_JAVA_UDF \ + if (const ColumnString* str_col = check_and_get_column(data_col.get())) { \ + ColumnString::Chars& chars = const_cast(str_col->get_chars()); \ + ColumnString::Offsets& offsets = \ + const_cast(str_col->get_offsets()); \ + int increase_buffer_size = 0; \ + int32_t buffer_size = \ + JavaFunctionCall::IncreaseReservedBufferSize(increase_buffer_size); \ + chars.reserve(buffer_size); \ + chars.resize(buffer_size); \ + offsets.reserve(num_rows); \ + offsets.resize(num_rows); \ + *(jni_ctx->output_value_buffer) = \ + reinterpret_cast(chars.data()); \ + *(jni_ctx->output_offsets_ptr) = \ + reinterpret_cast(offsets.data()); \ + jni_ctx->output_intermediate_state_ptr->row_idx = 0; \ + jni_ctx->output_intermediate_state_ptr->buffer_size = buffer_size; \ + env->CallNonvirtualVoidMethodA( \ + jni_ctx->executor, executor_cl_, executor_evaluate_id_, nullptr); \ + while (jni_ctx->output_intermediate_state_ptr->row_idx < num_rows) { \ + increase_buffer_size++; \ + int32_t buffer_size = \ + JavaFunctionCall::IncreaseReservedBufferSize(increase_buffer_size); \ + chars.resize(buffer_size); \ + *(jni_ctx->output_value_buffer) = \ + reinterpret_cast(chars.data()); \ + jni_ctx->output_intermediate_state_ptr->buffer_size = buffer_size; \ + env->CallNonvirtualVoidMethodA( \ + jni_ctx->executor, executor_cl_, executor_evaluate_id_, nullptr); \ + } \ + } else if (data_col->is_numeric()) { \ + data_col->reserve(num_rows); \ + data_col->resize(num_rows); \ + *(jni_ctx->output_value_buffer) = \ + reinterpret_cast(data_col->get_raw_data().data); \ + env->CallNonvirtualVoidMethodA( \ + jni_ctx->executor, executor_cl_, executor_evaluate_id_, nullptr); \ + } else { \ + return Status::InvalidArgument(strings::Substitute( \ + "Java UDF doesn't support return type $0 now !", return_type->get_name())); \ + } +#endif + EVALUATE_JAVA_UDF; block.replace_by_position(result, ColumnNullable::create(std::move(data_col), std::move(null_col))); } else { + *(jni_ctx->output_null_value) = -1; auto data_col = return_type->create_column(); - data_col->reserve(num_rows); - data_col->resize(num_rows); - - *((int64_t*) jni_ctx->output_value_buffer) = reinterpret_cast(data_col->get_raw_data().data); + EVALUATE_JAVA_UDF; block.replace_by_position(result, std::move(data_col)); } - *((int32_t*) jni_ctx->batch_size_ptr) = num_rows; - // Using this version of Call has the lowest overhead. This eliminates the - // vtable lookup and setting up return stacks. - env->CallNonvirtualVoidMethodA( - jni_ctx->executor, executor_cl_, executor_evaluate_id_, nullptr); return JniUtil::GetJniExceptionMsg(env); } diff --git a/be/src/vec/functions/function_java_udf.h b/be/src/vec/functions/function_java_udf.h index 2bc8ce88d8..8c0fcbe8e7 100644 --- a/be/src/vec/functions/function_java_udf.h +++ b/be/src/vec/functions/function_java_udf.h @@ -76,27 +76,36 @@ private: jmethodID executor_evaluate_id_; jmethodID executor_close_id_; + struct IntermediateState { + size_t buffer_size; + size_t row_idx; + }; + struct JniContext { JavaFunctionCall* parent = nullptr; jobject executor = nullptr; - int64_t input_values_buffer_ptr; - int64_t input_nulls_buffer_ptr; - int64_t input_byte_offsets_ptr; - int64_t output_value_buffer; - int64_t output_null_value; - int64_t batch_size_ptr; + std::unique_ptr input_values_buffer_ptr; + std::unique_ptr input_nulls_buffer_ptr; + std::unique_ptr input_offsets_ptrs; + std::unique_ptr output_value_buffer; + std::unique_ptr output_null_value; + std::unique_ptr output_offsets_ptr; + std::unique_ptr batch_size_ptr; + // intermediate_state includes two parts: reserved / used buffer size and rows + std::unique_ptr output_intermediate_state_ptr; JniContext(int64_t num_args, JavaFunctionCall* parent): parent(parent) { - input_values_buffer_ptr = (int64_t) new int64_t[num_args]; - input_nulls_buffer_ptr = (int64_t) new int64_t[num_args]; - input_byte_offsets_ptr = (int64_t) new int64_t[num_args]; - - output_value_buffer = (int64_t) malloc(sizeof(int64_t)); - output_null_value = (int64_t) malloc(sizeof(int64_t)); - batch_size_ptr = (int64_t) malloc(sizeof(int32_t)); + input_values_buffer_ptr.reset(new int64_t[num_args]); + input_nulls_buffer_ptr.reset(new int64_t[num_args]); + input_offsets_ptrs.reset(new int64_t[num_args]); + output_value_buffer.reset((int64_t*) malloc(sizeof(int64_t))); + output_null_value.reset((int64_t*) malloc(sizeof(int64_t))); + batch_size_ptr.reset((int32_t*) malloc(sizeof(int32_t))); + output_offsets_ptr.reset((int64_t*) malloc(sizeof(int64_t))); + output_intermediate_state_ptr.reset((IntermediateState*) malloc(sizeof(IntermediateState))); } ~JniContext() { @@ -109,12 +118,6 @@ private: Status s = JniUtil::GetJniExceptionMsg(env); if (!s.ok()) LOG(WARNING) << s.get_error_msg(); env->DeleteGlobalRef(executor); - delete[] ((int64*) input_values_buffer_ptr); - delete[] ((int64*) input_nulls_buffer_ptr); - delete[] ((int64*) input_byte_offsets_ptr); - free((int64*) output_value_buffer); - free((int64*) output_null_value); - free((int32*) batch_size_ptr); } /// These functions are cross-compiled to IR and used by codegen. @@ -122,6 +125,12 @@ private: JniContext* jni_ctx, int index, uint8_t value); static uint8_t* GetInputValuesBufferAtOffset(JniContext* jni_ctx, int offset); }; + + static const int32_t INITIAL_RESERVED_BUFFER_SIZE = 1024; + // TODO: we need a heuristic strategy to increase buffer size for variable-size output. + static inline int32_t IncreaseReservedBufferSize(int n) { + return INITIAL_RESERVED_BUFFER_SIZE << n; + } }; } // namespace vectorized diff --git a/bin/start_be.sh b/bin/start_be.sh index c5b8da5c61..8bb8a6ff30 100755 --- a/bin/start_be.sh +++ b/bin/start_be.sh @@ -92,6 +92,7 @@ jdk_version() { } jvm_arch="amd64" +MACHINE_TYPE=$(uname -m) if [[ "${MACHINE_TYPE}" == "aarch64" ]]; then jvm_arch="aarch64" fi diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java index 0ecd3b1ffb..d2516954f7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java @@ -375,6 +375,7 @@ public class CreateFunctionStmt extends DdlStmt { .put(PrimitiveType.BIGINT, Sets.newHashSet(Long.class, long.class)) .put(PrimitiveType.CHAR, Sets.newHashSet(String.class)) .put(PrimitiveType.VARCHAR, Sets.newHashSet(String.class)) + .put(PrimitiveType.STRING, Sets.newHashSet(String.class)) .build(); private void checkUdfType(Class clazz, Method method, Type expType, Class pType, String pname) diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java index 89b6ea79a2..cca787400a 100644 --- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java +++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java @@ -29,36 +29,18 @@ import org.apache.thrift.TDeserializer; import org.apache.thrift.TException; import org.apache.thrift.protocol.TBinaryProtocol; -import sun.misc.Unsafe; - import java.io.File; import java.io.IOException; import java.lang.reflect.Constructor; -import java.lang.reflect.Field; import java.lang.reflect.Method; import java.net.MalformedURLException; import java.net.URL; import java.net.URLClassLoader; -import java.security.AccessController; -import java.security.PrivilegedAction; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; public class UdfExecutor { private static final Logger LOG = Logger.getLogger(UdfExecutor.class); - public static final Unsafe UNSAFE; - - static { - UNSAFE = (Unsafe) AccessController.doPrivileged( - (PrivilegedAction) () -> { - try { - Field f = Unsafe.class.getDeclaredField("theUnsafe"); - f.setAccessible(true); - return f.get(null); - } catch (NoSuchFieldException | IllegalAccessException e) { - throw new Error(); - } - }); - } // By convention, the function in the class must be called evaluate() public static final String UDF_FUNCTION_NAME = "evaluate"; @@ -82,10 +64,13 @@ public class UdfExecutor { // These buffers are allocated in the BE. private final long inputBufferPtrs_; private final long inputNullsPtrs_; + private final long inputOffsetsPtrs_; // Output buffer to return non-string values. These buffers are allocated in the BE. private final long outputBufferPtr_; private final long outputNullPtr_; + private final long outputOffsetsPtr_; + private final long outputIntermediateStatePtr_; // Pre-constructed input objects for the UDF. This minimizes object creation overhead // as these objects are reused across calls to evaluate(). @@ -93,6 +78,9 @@ public class UdfExecutor { // inputArgs_[i] is either inputObjects_[i] or null private Object[] inputArgs_; + private long outputOffset_; + private long row_idx_; + private final long batch_size_ptr_; // Data types that are supported as return or argument types in Java UDFs. @@ -104,7 +92,10 @@ public class UdfExecutor { INT("INT", TPrimitiveType.INT, 4), BIGINT("BIGINT", TPrimitiveType.BIGINT, 8), FLOAT("FLOAT", TPrimitiveType.FLOAT, 4), - DOUBLE("DOUBLE", TPrimitiveType.DOUBLE, 8); + DOUBLE("DOUBLE", TPrimitiveType.DOUBLE, 8), + CHAR("CHAR", TPrimitiveType.CHAR, 0), + VARCHAR("VARCHAR", TPrimitiveType.VARCHAR, 0), + STRING("STRING", TPrimitiveType.STRING, 0); private final String description_; private final TPrimitiveType thriftType_; @@ -144,6 +135,10 @@ public class UdfExecutor { return JavaUdfDataType.FLOAT; } else if (c == double.class || c == Double.class) { return JavaUdfDataType.DOUBLE; + } else if (c == char.class || c == Character.class) { + return JavaUdfDataType.CHAR; + } else if (c == String.class) { + return JavaUdfDataType.STRING; } return JavaUdfDataType.INVALID_TYPE; } @@ -183,8 +178,15 @@ public class UdfExecutor { batch_size_ptr_ = request.batch_size_ptr; inputBufferPtrs_ = request.input_buffer_ptrs; inputNullsPtrs_ = request.input_nulls_ptrs; + inputOffsetsPtrs_ = request.input_offsets_ptrs; + outputBufferPtr_ = request.output_buffer_ptr; outputNullPtr_ = request.output_null_ptr; + outputOffsetsPtr_ = request.output_offsets_ptr; + outputIntermediateStatePtr_ = request.output_intermediate_state_ptr; + + outputOffset_ = 0L; + row_idx_ = 0L; init(jarFile, className, retType, parameterTypes); } @@ -218,22 +220,52 @@ public class UdfExecutor { * been serialized to 'input' */ public void evaluate() throws UdfRuntimeException { + int batch_size = UdfUtils.UNSAFE.getInt(null, batch_size_ptr_); try { - int batch_size = UNSAFE.getInt(null, batch_size_ptr_); - for (int row = 0; row < batch_size; row++) { - allocateInputObjects(row); + if (retType_.equals(JavaUdfDataType.STRING) || retType_.equals(JavaUdfDataType.VARCHAR) + || retType_.equals(JavaUdfDataType.CHAR)) { + // If this udf return variable-size type (e.g.) String, we have to allocate output + // buffer multiple times until buffer size is enough to store output column. So we + // always begin with the last evaluated row instead of beginning of this batch. + row_idx_ = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr_ + 8); + if (row_idx_ == 0) { + outputOffset_ = 0L; + } + } else { + row_idx_ = 0; + } + for (; row_idx_ < batch_size; row_idx_++) { + allocateInputObjects(row_idx_); for (int i = 0; i < argTypes_.length; ++i) { - if (UNSAFE.getByte(null, UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputNullsPtrs_, i)) + row * 1L) == 0) { + // Currently, -1 indicates this column is not nullable. So input argument is + // null iff inputNullsPtrs_ != -1 and nullCol[row_idx] != 0. + if (UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputNullsPtrs_, i)) == -1 || + UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputNullsPtrs_, i)) + row_idx_) == 0) { inputArgs_[i] = inputObjects_[i]; } else { inputArgs_[i] = null; } } - storeUdfResult(evaluate(inputArgs_), row); + // `storeUdfResult` is called to store udf result to output column. If true + // is returned, current value is stored successfully. Otherwise, current result is + // not processed successfully (e.g. current output buffer is not large enough) so + // we break this loop directly. + if (!storeUdfResult(evaluate(inputArgs_), row_idx_)) { + UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr_ + 8, row_idx_); + return; + } } } catch (Exception e) { + if (retType_.equals(JavaUdfDataType.STRING)) { + UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr_ + 8, batch_size); + } throw new UdfRuntimeException("UDF::evaluate() ran into a problem.", e); } + if (retType_.equals(JavaUdfDataType.STRING)) { + UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr_ + 8, row_idx_); + } } /** @@ -252,42 +284,73 @@ public class UdfExecutor { } // Sets the result object 'obj' into the outputBufferPtr_ and outputNullPtr_ - private void storeUdfResult(Object obj, int row) throws UdfRuntimeException { + private boolean storeUdfResult(Object obj, long row) throws UdfRuntimeException { if (obj == null) { - UNSAFE.putByte(null, UNSAFE.getLong(null, outputNullPtr_) + row * 1L, (byte) 1); - return; + assert (UdfUtils.UNSAFE.getLong(null, outputNullPtr_) != -1); + UdfUtils.UNSAFE.putByte(null, UdfUtils.UNSAFE.getLong(null, outputNullPtr_) + row, (byte) 1); + if (retType_.equals(JavaUdfDataType.STRING)) { + long bufferSize = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr_); + if (outputOffset_ + 1 > bufferSize) { + return false; + } + outputOffset_ += 1; + UdfUtils.UNSAFE.putChar(null, UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + + outputOffset_ - 1, UdfUtils.END_OF_STRING); + UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr_) + + 4L * row, Integer.parseUnsignedInt(String.valueOf(outputOffset_))); + } + return true; + } + if (UdfUtils.UNSAFE.getLong(null, outputNullPtr_) != -1) { + UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputNullPtr_) + row, (byte) 0); } - UNSAFE.putByte(UNSAFE.getLong(null, outputNullPtr_) + row * 1L, (byte) 0); switch (retType_) { case BOOLEAN: { boolean val = (boolean) obj; - UNSAFE.putByte(UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), val ? (byte) 1 : 0); - return; + UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), val ? (byte) 1 : 0); + return true; } case TINYINT: { - UNSAFE.putByte(UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), (byte) obj); - return; + UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), (byte) obj); + return true; } case SMALLINT: { - UNSAFE.putShort(UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), (short) obj); - return; + UdfUtils.UNSAFE.putShort(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), (short) obj); + return true; } case INT: { - UNSAFE.putInt(UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), (int) obj); - return; + UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), (int) obj); + return true; } case BIGINT: { - UNSAFE.putLong(UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), (long) obj); - return; + UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), (long) obj); + return true; } case FLOAT: { - UNSAFE.putFloat(UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), (float) obj); - return; + UdfUtils.UNSAFE.putFloat(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), (float) obj); + return true; } case DOUBLE: { - UNSAFE.putDouble(UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), (double) obj); - return; + UdfUtils.UNSAFE.putDouble(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), (double) obj); + return true; } + case CHAR: + case VARCHAR: + case STRING: + long bufferSize = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr_); + byte[] bytes = ((String) obj).getBytes(StandardCharsets.UTF_8); + if (outputOffset_ + bytes.length + 1 > bufferSize) { + return false; + } + outputOffset_ += (bytes.length + 1); + UdfUtils.UNSAFE.putChar(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + + outputOffset_ - 1, UdfUtils.END_OF_STRING); + UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr_) + 4L * row, + Integer.parseUnsignedInt(String.valueOf(outputOffset_))); + UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null, + UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + + outputOffset_ - bytes.length - 1, bytes.length); + return true; default: throw new UdfRuntimeException("Unsupported return type: " + retType_); } @@ -295,32 +358,47 @@ public class UdfExecutor { // Preallocate the input objects that will be passed to the underlying UDF. // These objects are allocated once and reused across calls to evaluate() - private void allocateInputObjects(int row) throws UdfRuntimeException { + private void allocateInputObjects(long row) throws UdfRuntimeException { inputObjects_ = new Object[argTypes_.length]; inputArgs_ = new Object[argTypes_.length]; for (int i = 0; i < argTypes_.length; ++i) { switch (argTypes_[i]) { case BOOLEAN: - inputObjects_[i] = UNSAFE.getBoolean(null, UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 1L * row); + inputObjects_[i] = UdfUtils.UNSAFE.getBoolean(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + row); break; case TINYINT: - inputObjects_[i] = UNSAFE.getByte(null, UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 1L * row); + inputObjects_[i] = UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + row); break; case SMALLINT: - inputObjects_[i] = UNSAFE.getShort(null, UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 2L * row); + inputObjects_[i] = UdfUtils.UNSAFE.getShort(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 2L * row); break; case INT: - inputObjects_[i] = UNSAFE.getInt(null, UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 4L * row); + inputObjects_[i] = UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 4L * row); break; case BIGINT: - inputObjects_[i] = UNSAFE.getLong(null, UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 8L * row); + inputObjects_[i] = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 8L * row); break; case FLOAT: - inputObjects_[i] = UNSAFE.getFloat(null, UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 4L * row); + inputObjects_[i] = UdfUtils.UNSAFE.getFloat(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 4L * row); break; case DOUBLE: - inputObjects_[i] = UNSAFE.getDouble(null, UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 8L * row); + inputObjects_[i] = UdfUtils.UNSAFE.getDouble(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 8L * row); + break; + case CHAR: + case VARCHAR: + case STRING: + long offset = Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, + UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputOffsetsPtrs_, i)) + 4L * row)); + long numBytes = row == 0 ? offset - 1 : offset - Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, + UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputOffsetsPtrs_, i)) + 4L * (row - 1))) - 1; + long base = row == 0 ? UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) : + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + offset - numBytes - 1; + byte[] bytes = new byte[(int) numBytes]; + UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes); + inputObjects_[i] = new String(bytes, StandardCharsets.UTF_8); break; default: throw new UdfRuntimeException("Unsupported argument type: " + argTypes_[i]); diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java index 9f33977df3..f412d8593f 100644 --- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java +++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java @@ -18,17 +18,42 @@ package org.apache.doris.udf; import com.google.common.base.Preconditions; + import org.apache.doris.catalog.PrimitiveType; import org.apache.doris.catalog.ScalarType; import org.apache.doris.catalog.Type; import org.apache.doris.common.Pair; - import org.apache.doris.thrift.TPrimitiveType; import org.apache.doris.thrift.TScalarType; import org.apache.doris.thrift.TTypeDesc; import org.apache.doris.thrift.TTypeNode; +import sun.misc.Unsafe; + +import java.lang.reflect.Field; +import java.security.AccessController; +import java.security.PrivilegedAction; + public class UdfUtils { + public static final Unsafe UNSAFE; + private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L; + public static final long BYTE_ARRAY_OFFSET; + public static final char END_OF_STRING = '\0'; + + static { + UNSAFE = (Unsafe) AccessController.doPrivileged( + (PrivilegedAction) () -> { + try { + Field f = Unsafe.class.getDeclaredField("theUnsafe"); + f.setAccessible(true); + return f.get(null); + } catch (NoSuchFieldException | IllegalAccessException e) { + throw new Error(); + } + }); + BYTE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(byte[].class); + } + protected static Pair fromThrift(TTypeDesc typeDesc, int nodeIdx) throws InternalException { TTypeNode node = typeDesc.getTypes().get(nodeIdx); Type type = null; @@ -62,4 +87,30 @@ public class UdfUtils { protected static long getAddressAtOffset(long base, int offset) { return base + 8L * offset; } + + public static void copyMemory( + Object src, long srcOffset, Object dst, long dstOffset, long length) { + // Check if dstOffset is before or after srcOffset to determine if we should copy + // forward or backwards. This is necessary in case src and dst overlap. + if (dstOffset < srcOffset) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } else { + srcOffset += length; + dstOffset += length; + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + srcOffset -= size; + dstOffset -= size; + UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + } + + } + } } diff --git a/fe/java-udf/src/test/java/org/apache/doris/udf/StringConcatUdf.java b/fe/java-udf/src/test/java/org/apache/doris/udf/StringConcatUdf.java new file mode 100644 index 0000000000..2fa6c2754d --- /dev/null +++ b/fe/java-udf/src/test/java/org/apache/doris/udf/StringConcatUdf.java @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.udf; + +public class StringConcatUdf { + public String evaluate(String a, String b) { + return a == null || b == null? null: a + b; + } +} diff --git a/fe/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java b/fe/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java index 6b1c487604..6113d94e84 100644 --- a/fe/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java +++ b/fe/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java @@ -22,6 +22,7 @@ import org.apache.thrift.TSerializer; import org.apache.thrift.protocol.TBinaryProtocol; import org.junit.Test; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -42,24 +43,24 @@ public class UdfExecutorTest { fn.name = new TFunctionName("ConstantOne"); - long batchSizePtr = UdfExecutor.UNSAFE.allocateMemory(32); + long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(4); int batchSize = 10; - UdfExecutor.UNSAFE.putInt(batchSizePtr, batchSize); + UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize); TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams(); - params.batch_size_ptr = batchSizePtr; - params.fn = fn; + params.setBatchSizePtr(batchSizePtr); + params.setFn(fn); - long outputBuffer = UdfExecutor.UNSAFE.allocateMemory(32 * batchSize); - long outputNull = UdfExecutor.UNSAFE.allocateMemory(8 * batchSize); - long outputBufferPtr = UdfExecutor.UNSAFE.allocateMemory(64); - UdfExecutor.UNSAFE.putLong(outputBufferPtr, outputBuffer); - long outputNullPtr = UdfExecutor.UNSAFE.allocateMemory(64); - UdfExecutor.UNSAFE.putLong(outputNullPtr, outputNull); - params.output_buffer_ptr = outputBufferPtr; - params.output_null_ptr = outputNullPtr; - params.input_buffer_ptrs = 0; - params.input_nulls_ptrs = 0; - params.input_byte_offsets = 0; + long outputBuffer = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); + long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize); + long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8); + UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer); + long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8); + UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull); + params.setOutputBufferPtr(outputBufferPtr); + params.setOutputNullPtr(outputNullPtr); + params.setInputBufferPtrs(0); + params.setInputNullsPtrs(0); + params.setInputOffsetsPtrs(0); TBinaryProtocol.Factory factory = new TBinaryProtocol.Factory(); @@ -70,8 +71,8 @@ public class UdfExecutorTest { executor.evaluate(); for (int i = 0; i < 10; i ++) { - assert (UdfExecutor.UNSAFE.getByte(outputNull + 8 * i) == 0); - assert (UdfExecutor.UNSAFE.getInt(outputBuffer + 32 * i) == 1); + assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0); + assert (UdfUtils.UNSAFE.getInt(outputBuffer + 4 * i) == 1); } } @@ -91,52 +92,52 @@ public class UdfExecutorTest { fn.name = new TFunctionName("SimpleAdd"); - long batchSizePtr = UdfExecutor.UNSAFE.allocateMemory(32); + long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(4); int batchSize = 10; - UdfExecutor.UNSAFE.putInt(batchSizePtr, batchSize); + UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize); TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams(); - params.batch_size_ptr = batchSizePtr; - params.fn = fn; + params.setBatchSizePtr(batchSizePtr); + params.setFn(fn); - long outputBufferPtr = UdfExecutor.UNSAFE.allocateMemory(64); - long outputNullPtr = UdfExecutor.UNSAFE.allocateMemory(64); - long outputBuffer = UdfExecutor.UNSAFE.allocateMemory(32 * batchSize); - long outputNull = UdfExecutor.UNSAFE.allocateMemory(8 * batchSize); - UdfExecutor.UNSAFE.putLong(outputBufferPtr, outputBuffer); - UdfExecutor.UNSAFE.putLong(outputNullPtr, outputNull); + long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8); + long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8); + long outputBuffer = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); + long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize); + UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer); + UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull); - params.output_buffer_ptr = outputBufferPtr; - params.output_null_ptr = outputNullPtr; + params.setOutputBufferPtr(outputBufferPtr); + params.setOutputNullPtr(outputNullPtr); int numCols = 2; - long inputBufferPtr = UdfExecutor.UNSAFE.allocateMemory(64 * numCols); - long inputNullPtr = UdfExecutor.UNSAFE.allocateMemory(64 * numCols); + long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); + long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - long inputBuffer1 = UdfExecutor.UNSAFE.allocateMemory(32 * batchSize); - long inputNull1 = UdfExecutor.UNSAFE.allocateMemory(8 * batchSize); - long inputBuffer2 = UdfExecutor.UNSAFE.allocateMemory(32 * batchSize); - long inputNull2 = UdfExecutor.UNSAFE.allocateMemory(8 * batchSize); + long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); + long inputNull1 = UdfUtils.UNSAFE.allocateMemory(batchSize); + long inputBuffer2 = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); + long inputNull2 = UdfUtils.UNSAFE.allocateMemory(batchSize); - UdfExecutor.UNSAFE.putLong(inputBufferPtr, inputBuffer1); - UdfExecutor.UNSAFE.putLong(inputBufferPtr + 64, inputBuffer2); - UdfExecutor.UNSAFE.putLong(inputNullPtr, inputNull1); - UdfExecutor.UNSAFE.putLong(inputNullPtr + 64, inputNull2); + UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1); + UdfUtils.UNSAFE.putLong(inputBufferPtr + 8, inputBuffer2); + UdfUtils.UNSAFE.putLong(inputNullPtr, inputNull1); + UdfUtils.UNSAFE.putLong(inputNullPtr + 8, inputNull2); for (int i = 0; i < batchSize; i ++) { - UdfExecutor.UNSAFE.putInt(null, inputBuffer1 + i * 32, i); - UdfExecutor.UNSAFE.putInt(null, inputBuffer2 + i * 32, i); + UdfUtils.UNSAFE.putInt(null, inputBuffer1 + i * 4, i); + UdfUtils.UNSAFE.putInt(null, inputBuffer2 + i * 4, i); if (i % 2 == 0) { - UdfExecutor.UNSAFE.putByte(null, inputNull1 + i * 8, (byte) 1); + UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 1); } else { - UdfExecutor.UNSAFE.putByte(null, inputNull1 + i * 8, (byte) 0); + UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 0); } - UdfExecutor.UNSAFE.putByte(null, inputNull2 + i * 8, (byte) 0); + UdfUtils.UNSAFE.putByte(null, inputNull2 + i, (byte) 0); } - params.input_buffer_ptrs = inputBufferPtr; - params.input_nulls_ptrs = inputNullPtr; - params.input_byte_offsets = 0; + params.setInputBufferPtrs(inputBufferPtr); + params.setInputNullsPtrs(inputNullPtr); + params.setInputOffsetsPtrs(0); TBinaryProtocol.Factory factory = new TBinaryProtocol.Factory(); @@ -148,11 +149,148 @@ public class UdfExecutorTest { executor.evaluate(); for (int i = 0; i < batchSize; i ++) { if (i % 2 == 0) { - assert (UdfExecutor.UNSAFE.getByte(outputNull + 8 * i) == 1); + assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 1); } else { - assert (UdfExecutor.UNSAFE.getByte(outputNull + 8 * i) == 0); - assert (UdfExecutor.UNSAFE.getInt(outputBuffer + 32 * i) == i * 2); + assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0); + assert (UdfUtils.UNSAFE.getInt(outputBuffer + 4 * i) == i * 2); } } } + + @Test + public void testStringConcatUdf() throws Exception { + TScalarFunction scalarFunction = new TScalarFunction(); + scalarFunction.symbol = "org.apache.doris.udf.StringConcatUdf"; + + TFunction fn = new TFunction(); + fn.binary_type = TFunctionBinaryType.JAVA_UDF; + TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR); + typeNode.scalar_type = new TScalarType(TPrimitiveType.STRING); + TTypeDesc typeDesc = new TTypeDesc(Collections.singletonList(typeNode)); + fn.ret_type = typeDesc; + fn.arg_types = Arrays.asList(typeDesc, typeDesc); + fn.scalar_fn = scalarFunction; + fn.name = new TFunctionName("StringConcat"); + + + long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(32); + int batchSize = 10; + UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize); + + TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams(); + params.setBatchSizePtr(batchSizePtr); + params.setFn(fn); + + long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8); + long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8); + long outputOffsetsPtr = UdfUtils.UNSAFE.allocateMemory(8); + long outputIntermediateStatePtr = UdfUtils.UNSAFE.allocateMemory(8 * 2); + + String[] input1 = new String[batchSize]; + String[] input2 = new String[batchSize]; + long[] inputOffsets1 = new long[batchSize]; + long[] inputOffsets2 = new long[batchSize]; + long inputBufferSize1 = 0; + long inputBufferSize2 = 0; + for (int i = 0; i < batchSize; i ++) { + input1[i] = "Input1_" + i; + input2[i] = "Input2_" + i; + inputOffsets1[i] = i == 0? input1[i].getBytes(StandardCharsets.UTF_8).length + 1: + inputOffsets1[i - 1] + input1[i].getBytes(StandardCharsets.UTF_8).length + 1; + inputOffsets2[i] = i == 0? input2[i].getBytes(StandardCharsets.UTF_8).length + 1: + inputOffsets2[i - 1] + input2[i].getBytes(StandardCharsets.UTF_8).length + 1; + inputBufferSize1 += input1[i].getBytes(StandardCharsets.UTF_8).length; + inputBufferSize2 += input2[i].getBytes(StandardCharsets.UTF_8).length; + } + // In our test case, output buffer is (8 + 1) bytes * batchSize + long outputBuffer = UdfUtils.UNSAFE.allocateMemory(inputBufferSize1 + inputBufferSize2 + batchSize); + long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize); + long outputOffset = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); + + UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer); + UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull); + UdfUtils.UNSAFE.putLong(outputOffsetsPtr, outputOffset); + // reserved buffer size + UdfUtils.UNSAFE.putLong(outputIntermediateStatePtr, inputBufferSize1 + inputBufferSize2 + batchSize); + // current row id + UdfUtils.UNSAFE.putLong(outputIntermediateStatePtr + 8, 0); + + params.setOutputBufferPtr(outputBufferPtr); + params.setOutputNullPtr(outputNullPtr); + params.setOutputOffsetsPtr(outputOffsetsPtr); + params.setOutputIntermediateStatePtr(outputIntermediateStatePtr); + + int numCols = 2; + long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); + long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); + long inputOffsetsPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); + + long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(inputBufferSize1 + batchSize); + long inputOffset1 = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); + long inputBuffer2 = UdfUtils.UNSAFE.allocateMemory(inputBufferSize2 + batchSize); + long inputOffset2 = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); + + UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1); + UdfUtils.UNSAFE.putLong(inputBufferPtr + 8, inputBuffer2); + UdfUtils.UNSAFE.putLong(inputNullPtr, -1); + UdfUtils.UNSAFE.putLong(inputNullPtr + 8, -1); + UdfUtils.UNSAFE.putLong(inputOffsetsPtr, inputOffset1); + UdfUtils.UNSAFE.putLong(inputOffsetsPtr + 8, inputOffset2); + + for (int i = 0; i < batchSize; i ++) { + if (i == 0) { + UdfUtils.copyMemory(input1[i].getBytes(StandardCharsets.UTF_8), + UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer1, + input1[i].getBytes(StandardCharsets.UTF_8).length); + UdfUtils.copyMemory(input2[i].getBytes(StandardCharsets.UTF_8), + UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer2, + input2[i].getBytes(StandardCharsets.UTF_8).length); + } else { + UdfUtils.copyMemory(input1[i].getBytes(StandardCharsets.UTF_8), + UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer1 + inputOffsets1[i - 1], + input1[i].getBytes(StandardCharsets.UTF_8).length); + UdfUtils.copyMemory(input2[i].getBytes(StandardCharsets.UTF_8), + UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer2 + inputOffsets2[i - 1], + input2[i].getBytes(StandardCharsets.UTF_8).length); + } + UdfUtils.UNSAFE.putInt(null, inputOffset1 + 4L * i, + Integer.parseUnsignedInt(String.valueOf(inputOffsets1[i]))); + UdfUtils.UNSAFE.putInt(null, inputOffset2 + 4L * i, + Integer.parseUnsignedInt(String.valueOf(inputOffsets2[i]))); + UdfUtils.UNSAFE.putChar(null, inputBuffer1 + inputOffsets1[i] - 1, + UdfUtils.END_OF_STRING); + UdfUtils.UNSAFE.putChar(null, inputBuffer2 + inputOffsets2[i] - 1, + UdfUtils.END_OF_STRING); + + } + params.setInputBufferPtrs(inputBufferPtr); + params.setInputNullsPtrs(inputNullPtr); + params.setInputOffsetsPtrs(inputOffsetsPtr); + + TBinaryProtocol.Factory factory = + new TBinaryProtocol.Factory(); + TSerializer serializer = new TSerializer(factory); + + UdfExecutor executor; + executor = new UdfExecutor(serializer.serialize(params)); + + executor.evaluate(); + for (int i = 0; i < batchSize; i ++) { + byte[] bytes = new byte[input1[i].getBytes(StandardCharsets.UTF_8).length + + input2[i].getBytes(StandardCharsets.UTF_8).length]; + assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0); + if (i == 0) { + UdfUtils.copyMemory(null, outputBuffer, bytes, UdfUtils.BYTE_ARRAY_OFFSET, + bytes.length); + } else { + long lastOffset = UdfUtils.UNSAFE.getInt(null, outputOffset + 4 * (i - 1)); + UdfUtils.copyMemory(null, outputBuffer + lastOffset, bytes, UdfUtils.BYTE_ARRAY_OFFSET, + bytes.length); + } + long curOffset = UdfUtils.UNSAFE.getInt(null, outputOffset + 4 * i); + assert (new String(bytes, StandardCharsets.UTF_8).equals(input1[i] + input2[i])); + assert (UdfUtils.UNSAFE.getByte(null, outputBuffer + curOffset - 1) == UdfUtils.END_OF_STRING); + assert (UdfUtils.UNSAFE.getByte(null, outputNull + i) == 0); + } + } } diff --git a/gensrc/thrift/Types.thrift b/gensrc/thrift/Types.thrift index d34f5499eb..145cc45555 100644 --- a/gensrc/thrift/Types.thrift +++ b/gensrc/thrift/Types.thrift @@ -350,7 +350,7 @@ struct TJavaUdfExecutorCtorParams { // call the Java executor with a buffer for all the inputs. // input_byte_offsets[0] is the byte offset in the buffer for the first // argument; input_byte_offsets[1] is the second, etc. - 3: optional i64 input_byte_offsets + 3: optional i64 input_offsets_ptrs // Native input buffer ptr (cast as i64) for the inputs. The input arguments // are written to this buffer directly and read from java with no copies @@ -365,8 +365,10 @@ struct TJavaUdfExecutorCtorParams { // NULL. 6: optional i64 output_null_ptr 7: optional i64 output_buffer_ptr + 8: optional i64 output_offsets_ptr + 9: optional i64 output_intermediate_state_ptr - 8: optional i64 batch_size_ptr + 10: optional i64 batch_size_ptr } // Contains all interesting statistics from a single 'memory pool' in the JVM. diff --git a/run-be-ut.sh b/run-be-ut.sh index e86b289952..2b9d873527 100755 --- a/run-be-ut.sh +++ b/run-be-ut.sh @@ -176,6 +176,50 @@ done export DORIS_TEST_BINARY_DIR=${DORIS_TEST_BINARY_DIR}/test/ +# prepare jvm if needed +jdk_version() { + local result + local java_cmd=$JAVA_HOME/bin/java + local IFS=$'\n' + # remove \r for Cygwin + local lines=$("$java_cmd" -Xms32M -Xmx32M -version 2>&1 | tr '\r' '\n') + if [[ -z $java_cmd ]] + then + result=no_java + else + for line in $lines; do + if [[ (-z $result) && ($line = *"version \""*) ]] + then + local ver=$(echo $line | sed -e 's/.*version "\(.*\)"\(.*\)/\1/; 1q') + # on macOS, sed doesn't support '?' + if [[ $ver = "1."* ]] + then + result=$(echo $ver | sed -e 's/1\.\([0-9]*\)\(.*\)/\1/; 1q') + else + result=$(echo $ver | sed -e 's/\([0-9]*\)\(.*\)/\1/; 1q') + fi + fi + done + fi + echo "$result" +} + +jvm_arch="amd64" +MACHINE_TYPE=$(uname -m) +if [[ "${MACHINE_TYPE}" == "aarch64" ]]; then + jvm_arch="aarch64" +fi +java_version=$(jdk_version) +if [[ $java_version -gt 8 ]]; then + export LD_LIBRARY_PATH=$JAVA_HOME/lib/server:$JAVA_HOME/lib:$LD_LIBRARY_PATH +# JAVA_HOME is jdk +elif [[ -d "$JAVA_HOME/jre" ]]; then + export LD_LIBRARY_PATH=$JAVA_HOME/jre/lib/$jvm_arch/server:$JAVA_HOME/jre/lib/$jvm_arch:$LD_LIBRARY_PATH +# JAVA_HOME is jre +else + export LD_LIBRARY_PATH=$JAVA_HOME/lib/$jvm_arch/server:$JAVA_HOME/lib/$jvm_arch:$LD_LIBRARY_PATH +fi + # prepare gtest output dir GTEST_OUTPUT_DIR=${CMAKE_BUILD_DIR}/gtest_output rm -rf ${GTEST_OUTPUT_DIR} && mkdir ${GTEST_OUTPUT_DIR}