[improve](udaf) refactor java-udaf executor by using for loop (#21713)

refactor java-udaf executor by using for loop
This commit is contained in:
zhangstar333
2023-07-14 11:37:19 +08:00
committed by GitHub
parent ea73dd5851
commit c07e2ada43
6 changed files with 554 additions and 406 deletions

View File

@ -128,63 +128,80 @@ public:
return Status::OK();
}
// Feeds rows [row_num_start, row_num_end) of the argument columns to the Java UDAF executor via JNI.
// NOTE(review): this span comes from a rendered diff without +/- markers, so two revisions of
// add() are interleaved below: one that takes an array of place addresses and fills the
// input_*_buffer_ptr members, and one that takes a single address plus place_offset, converts
// every column into a Java Object[] and then calls addBatch. Read it as a diff, not as
// compilable code.
Status add(const int64_t places_address[], bool is_single_place, const IColumn** columns,
size_t row_num_start, size_t row_num_end, const DataTypes& argument_types) {
Status add(int64_t places_address, bool is_single_place, const IColumn** columns,
int row_num_start, int row_num_end, const DataTypes& argument_types,
int place_offset) {
JNIEnv* env = nullptr;
RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf add function");
// arg_objects is an array of Object[]: one slot per argument column.
jclass obj_class = env->FindClass("[Ljava/lang/Object;");
jobjectArray arg_objects = env->NewObjectArray(argument_size, obj_class, nullptr);
// Null-map address of the most recent nullable argument column (0 until one is seen).
int64_t nullmap_address = 0;
for (int arg_idx = 0; arg_idx < argument_size; ++arg_idx) {
bool arg_column_nullable = false;
auto data_col = columns[arg_idx];
// Unwrap nullable columns: keep the nested data column and remember the null-map address.
if (auto* nullable = check_and_get_column<const ColumnNullable>(*columns[arg_idx])) {
arg_column_nullable = true;
auto null_col = nullable->get_null_map_column_ptr();
data_col = nullable->get_nested_column_ptr();
auto null_col = check_and_get_column<ColumnVector<UInt8>>(
nullable->get_null_map_column_ptr());
input_nulls_buffer_ptr.get()[arg_idx] =
reinterpret_cast<int64_t>(null_col->get_data().data());
} else {
input_nulls_buffer_ptr.get()[arg_idx] = -1;
nullmap_address = reinterpret_cast<int64_t>(
check_and_get_column<ColumnVector<UInt8>>(null_col)->get_data().data());
}
if (data_col->is_column_string()) {
const ColumnString* str_col = check_and_get_column<ColumnString>(data_col);
input_values_buffer_ptr.get()[arg_idx] =
reinterpret_cast<int64_t>(str_col->get_chars().data());
input_offsets_ptrs.get()[arg_idx] =
reinterpret_cast<int64_t>(str_col->get_offsets().data());
} else if (data_col->is_numeric() || data_col->is_column_decimal()) {
input_values_buffer_ptr.get()[arg_idx] =
reinterpret_cast<int64_t>(data_col->get_raw_data().data);
// convert argument column data into java type
jobjectArray arr_obj = nullptr;
if (data_col->is_numeric() || data_col->is_column_decimal()) {
// Fixed-width values: only the raw data address is needed, the string offset address is 0.
arr_obj = (jobjectArray)env->CallObjectMethod(
executor_obj, executor_convert_basic_argument_id, arg_idx,
arg_column_nullable, row_num_start, row_num_end, nullmap_address,
reinterpret_cast<int64_t>(data_col->get_raw_data().data), 0);
} else if (data_col->is_column_string()) {
// Strings: pass both the chars buffer and the offsets buffer.
const ColumnString* str_col = assert_cast<const ColumnString*>(data_col);
arr_obj = (jobjectArray)env->CallObjectMethod(
executor_obj, executor_convert_basic_argument_id, arg_idx,
arg_column_nullable, row_num_start, row_num_end, nullmap_address,
reinterpret_cast<int64_t>(str_col->get_chars().data()),
reinterpret_cast<int64_t>(str_col->get_offsets().data()));
} else if (data_col->is_column_array()) {
const ColumnArray* array_col = assert_cast<const ColumnArray*>(data_col);
input_offsets_ptrs.get()[arg_idx] = reinterpret_cast<int64_t>(
array_col->get_offsets_column().get_raw_data().data);
// The array's element column is cast to ColumnNullable unconditionally.
const ColumnNullable& array_nested_nullable =
assert_cast<const ColumnNullable&>(array_col->get_data());
auto data_column_null_map = array_nested_nullable.get_null_map_column_ptr();
auto data_column = array_nested_nullable.get_nested_column_ptr();
input_array_nulls_buffer_ptr.get()[arg_idx] = reinterpret_cast<int64_t>(
auto offset_address = reinterpret_cast<int64_t>(
array_col->get_offsets_column().get_raw_data().data);
auto nested_nullmap_address = reinterpret_cast<int64_t>(
check_and_get_column<ColumnVector<UInt8>>(data_column_null_map)
->get_data()
.data());
// need to pass to FE: nullmap, offset and chars
int64_t nested_data_address = 0, nested_offset_address = 0;
// array type need pass address: [nullmap_address], offset_address, nested_nullmap_address, nested_data_address/nested_char_address,nested_offset_address
if (data_column->is_column_string()) {
const ColumnString* col = assert_cast<const ColumnString*>(data_column.get());
input_values_buffer_ptr.get()[arg_idx] =
reinterpret_cast<int64_t>(col->get_chars().data());
input_array_string_offsets_ptrs.get()[arg_idx] =
reinterpret_cast<int64_t>(col->get_offsets().data());
nested_data_address = reinterpret_cast<int64_t>(col->get_chars().data());
nested_offset_address = reinterpret_cast<int64_t>(col->get_offsets().data());
} else {
input_values_buffer_ptr.get()[arg_idx] =
nested_data_address =
reinterpret_cast<int64_t>(data_column->get_raw_data().data);
}
arr_obj = (jobjectArray)env->CallObjectMethod(
executor_obj, executor_convert_array_argument_id, arg_idx,
arg_column_nullable, row_num_start, row_num_end, nullmap_address,
offset_address, nested_nullmap_address, nested_data_address,
nested_offset_address);
} else {
// NOTE(review): this early return skips the DeleteLocalRef calls for arg_objects and
// obj_class further down - confirm whether these local refs must be released here.
return Status::InvalidArgument(
strings::Substitute("Java UDAF doesn't support type is $0 now !",
argument_types[arg_idx]->get_name()));
}
// Store the converted column, then drop the local ref now that the array holds it.
env->SetObjectArrayElement(arg_objects, arg_idx, arr_obj);
env->DeleteLocalRef(arr_obj);
}
*input_place_ptrs = reinterpret_cast<int64_t>(places_address);
env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_add_id, is_single_place,
row_num_start, row_num_end);
RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env));
// invoke add batch
// NOTE(review): addBatch is registered elsewhere in this file with a void return signature,
// yet it is invoked through CallObjectMethod - confirm CallVoidMethod is not intended.
env->CallObjectMethod(executor_obj, executor_add_batch_id, is_single_place, row_num_start,
row_num_end, places_address, place_offset, arg_objects);
env->DeleteLocalRef(arg_objects);
env->DeleteLocalRef(obj_class);
// Any pending Java exception is turned into the returned Status.
return JniUtil::GetJniExceptionMsg(env);
}
@ -392,6 +409,12 @@ private:
register_id("getValue", UDAF_EXECUTOR_RESULT_SIGNATURE, executor_result_id));
RETURN_IF_ERROR(
register_id("destroy", UDAF_EXECUTOR_DESTROY_SIGNATURE, executor_destroy_id));
RETURN_IF_ERROR(register_id("convertBasicArguments", "(IZIIJJJ)[Ljava/lang/Object;",
executor_convert_basic_argument_id));
RETURN_IF_ERROR(register_id("convertArrayArguments", "(IZIIJJJJJ)[Ljava/lang/Object;",
executor_convert_array_argument_id));
RETURN_IF_ERROR(
register_id("addBatch", "(ZIIJI[Ljava/lang/Object;)V", executor_add_batch_id));
return Status::OK();
}
@ -403,12 +426,15 @@ private:
jmethodID executor_ctor_id;
jmethodID executor_add_id;
jmethodID executor_add_batch_id;
jmethodID executor_merge_id;
jmethodID executor_serialize_id;
jmethodID executor_result_id;
jmethodID executor_reset_id;
jmethodID executor_close_id;
jmethodID executor_destroy_id;
jmethodID executor_convert_basic_argument_id;
jmethodID executor_convert_array_argument_id;
std::unique_ptr<int64_t[]> input_values_buffer_ptr;
std::unique_ptr<int64_t[]> input_nulls_buffer_ptr;
@ -481,11 +507,10 @@ public:
void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
Arena*) const override {
int64_t places_address[1];
places_address[0] = reinterpret_cast<int64_t>(place);
Status st =
this->data(_exec_place)
.add(places_address, true, columns, row_num, row_num + 1, argument_types);
int64_t places_address = reinterpret_cast<int64_t>(place);
Status st = this->data(_exec_place)
.add(places_address, true, columns, row_num, row_num + 1,
argument_types, 0);
if (UNLIKELY(st != Status::OK())) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
}
@ -493,25 +518,20 @@ public:
// Adds batch_size rows, each row going to its own aggregate place, by delegating to the
// executor's add() with is_single_place = false; a non-OK Status is rethrown as doris::Exception.
// NOTE(review): rendered diff without +/- markers - the places_address[batch_size] array with its
// fill loop and the first ".add(" line belong to one revision, while the scalar places_address
// and the ".add(" call that forwards place_offset belong to the other.
void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
const IColumn** columns, Arena* /*arena*/, bool /*agg_many*/) const override {
int64_t places_address[batch_size];
for (size_t i = 0; i < batch_size; ++i) {
places_address[i] = reinterpret_cast<int64_t>(places[i] + place_offset);
}
// Pass the address of the places array itself together with place_offset.
int64_t places_address = reinterpret_cast<int64_t>(places);
Status st = this->data(_exec_place)
.add(places_address, false, columns, 0, batch_size, argument_types);
// NOTE(review): place_offset and batch_size are size_t here, while the add() overload that
// accepts place_offset declares int parameters - confirm the narrowing is acceptable.
.add(places_address, false, columns, 0, batch_size, argument_types,
place_offset);
if (UNLIKELY(st != Status::OK())) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
}
}
// TODO: This method is called through JNI. If the Java side throws, the error currently
// cannot be surfaced to the user; we only return directly and write the error to the log file.
void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
Arena* /*arena*/) const override {
int64_t places_address[1];
places_address[0] = reinterpret_cast<int64_t>(place);
int64_t places_address = reinterpret_cast<int64_t>(place);
Status st = this->data(_exec_place)
.add(places_address, true, columns, 0, batch_size, argument_types);
.add(places_address, true, columns, 0, batch_size, argument_types, 0);
if (UNLIKELY(st != Status::OK())) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
}
@ -522,11 +542,10 @@ public:
Arena* arena) const override {
frame_start = std::max<int64_t>(frame_start, partition_start);
frame_end = std::min<int64_t>(frame_end, partition_end);
int64_t places_address[1];
places_address[0] = reinterpret_cast<int64_t>(place);
Status st =
this->data(_exec_place)
.add(places_address, true, columns, frame_start, frame_end, argument_types);
int64_t places_address = reinterpret_cast<int64_t>(place);
Status st = this->data(_exec_place)
.add(places_address, true, columns, frame_start, frame_end,
argument_types, 0);
if (UNLIKELY(st != Status::OK())) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
}

View File

@ -41,6 +41,12 @@ under the License.
<artifactId>java-common</artifactId>
<version>${project.version}</version>
</dependency>
<!-- https://mvnrepository.com/artifact/com.esotericsoftware/reflectasm -->
<dependency>
<groupId>com.esotericsoftware</groupId>
<artifactId>reflectasm</artifactId>
<version>1.11.9</version>
</dependency>
</dependencies>
<build>

View File

@ -25,12 +25,14 @@ import org.apache.doris.common.jni.utils.UdfUtils;
import org.apache.doris.common.jni.utils.UdfUtils.JavaUdfDataType;
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
import com.google.common.base.Preconditions;
import org.apache.log4j.Logger;
import org.apache.thrift.TDeserializer;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import java.io.IOException;
import java.lang.reflect.Array;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;
@ -1021,4 +1023,183 @@ public abstract class BaseExecutor {
// Empty implementation: the given offset is ignored here.
// NOTE(review): presumably a hook for subclasses to override - confirm against the subclasses.
protected void updateOutputOffset(long offset) {
}
/**
 * Converts rows [rowStart, rowEnd) of one non-array argument column into a Java object array.
 *
 * <p>The conversion routine is chosen from {@code argTypes[argIdx]}. For the date/datetime
 * families the target Java class comes from {@code argClass}: index {@code argIdx} when
 * {@code isUdf} is true, otherwise {@code argIdx + 1}.
 *
 * @param isUdf         selects which {@code argClass} slot describes this argument
 * @param argIdx        index of the argument being converted
 * @param isNullable    whether {@code nullMapAddr} has to be consulted
 * @param rowStart      first row to convert (inclusive)
 * @param rowEnd        last row to convert (exclusive)
 * @param nullMapAddr   native address of the null map
 * @param columnAddr    native address of the column data
 * @param strOffsetAddr native address of the string offsets (only used for string types)
 * @return the converted values; an unsupported type fails the Preconditions check instead
 */
public Object[] convertBasicArg(boolean isUdf, int argIdx, boolean isNullable, int rowStart, int rowEnd,
        long nullMapAddr, long columnAddr, long strOffsetAddr) {
    Object[] converted = null;
    switch (argTypes[argIdx]) {
        case BOOLEAN:
            converted = UdfConvert.convertBooleanArg(isNullable, rowStart, rowEnd, nullMapAddr, columnAddr);
            break;
        case TINYINT:
            converted = UdfConvert.convertTinyIntArg(isNullable, rowStart, rowEnd, nullMapAddr, columnAddr);
            break;
        case SMALLINT:
            converted = UdfConvert.convertSmallIntArg(isNullable, rowStart, rowEnd, nullMapAddr, columnAddr);
            break;
        case INT:
            converted = UdfConvert.convertIntArg(isNullable, rowStart, rowEnd, nullMapAddr, columnAddr);
            break;
        case BIGINT:
            converted = UdfConvert.convertBigIntArg(isNullable, rowStart, rowEnd, nullMapAddr, columnAddr);
            break;
        case LARGEINT:
            converted = UdfConvert.convertLargeIntArg(isNullable, rowStart, rowEnd, nullMapAddr, columnAddr);
            break;
        case FLOAT:
            converted = UdfConvert.convertFloatArg(isNullable, rowStart, rowEnd, nullMapAddr, columnAddr);
            break;
        case DOUBLE:
            converted = UdfConvert.convertDoubleArg(isNullable, rowStart, rowEnd, nullMapAddr, columnAddr);
            break;
        case CHAR:
        case VARCHAR:
        case STRING:
            converted = UdfConvert.convertStringArg(isNullable, rowStart, rowEnd, nullMapAddr, columnAddr,
                    strOffsetAddr);
            break;
        case DATE: {
            // udaf maybe argClass[i + argClassOffset] need add +1
            Class<?> javaClass = isUdf ? argClass[argIdx] : argClass[argIdx + 1];
            converted = UdfConvert.convertDateArg(javaClass, isNullable, rowStart, rowEnd, nullMapAddr,
                    columnAddr);
            break;
        }
        case DATETIME: {
            Class<?> javaClass = isUdf ? argClass[argIdx] : argClass[argIdx + 1];
            converted = UdfConvert.convertDateTimeArg(javaClass, isNullable, rowStart, rowEnd, nullMapAddr,
                    columnAddr);
            break;
        }
        case DATEV2: {
            Class<?> javaClass = isUdf ? argClass[argIdx] : argClass[argIdx + 1];
            converted = UdfConvert.convertDateV2Arg(javaClass, isNullable, rowStart, rowEnd, nullMapAddr,
                    columnAddr);
            break;
        }
        case DATETIMEV2: {
            Class<?> javaClass = isUdf ? argClass[argIdx] : argClass[argIdx + 1];
            converted = UdfConvert.convertDateTimeV2Arg(javaClass, isNullable, rowStart, rowEnd, nullMapAddr,
                    columnAddr);
            break;
        }
        case DECIMALV2:
        case DECIMAL128:
            converted = UdfConvert.convertDecimalArg(argTypes[argIdx].getScale(), 16L, isNullable, rowStart,
                    rowEnd, nullMapAddr, columnAddr);
            break;
        case DECIMAL32:
            converted = UdfConvert.convertDecimalArg(argTypes[argIdx].getScale(), 4L, isNullable, rowStart,
                    rowEnd, nullMapAddr, columnAddr);
            break;
        case DECIMAL64:
            converted = UdfConvert.convertDecimalArg(argTypes[argIdx].getScale(), 8L, isNullable, rowStart,
                    rowEnd, nullMapAddr, columnAddr);
            break;
        default:
            LOG.info("Not support type: " + argTypes[argIdx].toString());
            Preconditions.checkState(false, "Not support type: " + argTypes[argIdx].toString());
            break;
    }
    return converted;
}
/**
 * Converts rows [rowStart, rowEnd) of an array-typed argument column into an array of ArrayList.
 *
 * <p>For every row the element range is taken from the 8-byte offsets at {@code offsetsAddr}:
 * entry {@code row - 1} is the start and entry {@code row} is the end. The element converter is
 * chosen from the primitive type of the array's item type.
 *
 * @param argIdx            index of the argument being converted
 * @param isNullable        forwarded to the element converters together with {@code nullMapAddr}
 * @param rowStart          first row to convert (inclusive)
 * @param rowEnd            last row to convert (exclusive)
 * @param nullMapAddr       native address of the outer null map
 * @param offsetsAddr       native address of the array offsets
 * @param nestedNullMapAddr native address of the element null map
 * @param dataAddr          native address of the element data (chars for string elements)
 * @param strOffsetAddr     native address of the element string offsets (string elements only)
 * @return one converted entry per row, stored at index {@code row - rowStart}
 */
public Object[] convertArrayArg(int argIdx, boolean isNullable, int rowStart, int rowEnd, long nullMapAddr,
        long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) {
    Object[] result = (Object[]) Array.newInstance(ArrayList.class, rowEnd - rowStart);
    for (int row = rowStart; row < rowEnd; ++row) {
        final long begin = UdfUtils.UNSAFE.getLong(null, offsetsAddr + 8L * (row - 1));
        final long end = UdfUtils.UNSAFE.getLong(null, offsetsAddr + 8L * (row));
        final int elemCount = (int) (end - begin);
        Object rowValue = null;
        switch (argTypes[argIdx].getItemType().getPrimitiveType()) {
            case BOOLEAN:
                rowValue = UdfConvert.convertArrayBooleanArg(row, elemCount, begin, isNullable, nullMapAddr,
                        nestedNullMapAddr, dataAddr);
                break;
            case TINYINT:
                rowValue = UdfConvert.convertArrayTinyIntArg(row, elemCount, begin, isNullable, nullMapAddr,
                        nestedNullMapAddr, dataAddr);
                break;
            case SMALLINT:
                rowValue = UdfConvert.convertArraySmallIntArg(row, elemCount, begin, isNullable, nullMapAddr,
                        nestedNullMapAddr, dataAddr);
                break;
            case INT:
                rowValue = UdfConvert.convertArrayIntArg(row, elemCount, begin, isNullable, nullMapAddr,
                        nestedNullMapAddr, dataAddr);
                break;
            case BIGINT:
                rowValue = UdfConvert.convertArrayBigIntArg(row, elemCount, begin, isNullable, nullMapAddr,
                        nestedNullMapAddr, dataAddr);
                break;
            case LARGEINT:
                rowValue = UdfConvert.convertArrayLargeIntArg(row, elemCount, begin, isNullable, nullMapAddr,
                        nestedNullMapAddr, dataAddr);
                break;
            case FLOAT:
                rowValue = UdfConvert.convertArrayFloatArg(row, elemCount, begin, isNullable, nullMapAddr,
                        nestedNullMapAddr, dataAddr);
                break;
            case DOUBLE:
                rowValue = UdfConvert.convertArrayDoubleArg(row, elemCount, begin, isNullable, nullMapAddr,
                        nestedNullMapAddr, dataAddr);
                break;
            case CHAR:
            case VARCHAR:
            case STRING:
                rowValue = UdfConvert.convertArrayStringArg(row, elemCount, begin, isNullable, nullMapAddr,
                        nestedNullMapAddr, dataAddr, strOffsetAddr);
                break;
            case DATE:
                rowValue = UdfConvert.convertArrayDateArg(row, elemCount, begin, isNullable, nullMapAddr,
                        nestedNullMapAddr, dataAddr);
                break;
            case DATETIME:
                rowValue = UdfConvert.convertArrayDateTimeArg(row, elemCount, begin, isNullable, nullMapAddr,
                        nestedNullMapAddr, dataAddr);
                break;
            case DATEV2:
                rowValue = UdfConvert.convertArrayDateV2Arg(row, elemCount, begin, isNullable, nullMapAddr,
                        nestedNullMapAddr, dataAddr);
                break;
            case DATETIMEV2:
                rowValue = UdfConvert.convertArrayDateTimeV2Arg(row, elemCount, begin, isNullable,
                        nullMapAddr, nestedNullMapAddr, dataAddr);
                break;
            case DECIMALV2:
            case DECIMAL128:
                rowValue = UdfConvert.convertArrayDecimalArg(argTypes[argIdx].getScale(), 16L, row, elemCount,
                        begin, isNullable, nullMapAddr, nestedNullMapAddr, dataAddr);
                break;
            case DECIMAL32:
                rowValue = UdfConvert.convertArrayDecimalArg(argTypes[argIdx].getScale(), 4L, row, elemCount,
                        begin, isNullable, nullMapAddr, nestedNullMapAddr, dataAddr);
                break;
            case DECIMAL64:
                rowValue = UdfConvert.convertArrayDecimalArg(argTypes[argIdx].getScale(), 8L, row, elemCount,
                        begin, isNullable, nullMapAddr, nestedNullMapAddr, dataAddr);
                break;
            default:
                LOG.info("Not support: " + argTypes[argIdx]);
                // checkState(false, ...) always throws, so the store below is never reached here
                Preconditions.checkState(false, "Not support type " + argTypes[argIdx].toString());
                break;
        }
        result[row - rowStart] = rowValue;
    }
    return result;
}
}

View File

@ -24,6 +24,7 @@ import org.apache.doris.common.jni.utils.UdfUtils;
import org.apache.doris.common.jni.utils.UdfUtils.JavaUdfDataType;
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
import com.esotericsoftware.reflectasm.MethodAccess;
import com.google.common.base.Joiner;
import com.google.common.collect.Lists;
import org.apache.log4j.Logger;
@ -36,6 +37,7 @@ import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
/**
@ -49,6 +51,8 @@ public class UdafExecutor extends BaseExecutor {
private HashMap<String, Method> allMethods;
private HashMap<Long, Object> stateObjMap;
private Class retClass;
private int addIndex;
private MethodAccess methodAccess;
/**
* Constructor to create an object.
@ -66,6 +70,84 @@ public class UdafExecutor extends BaseExecutor {
super.close();
}
/**
 * Converts one non-array argument column for the UDAF path.
 * Delegates to {@code convertBasicArg} with {@code isUdf} fixed to false.
 */
public Object[] convertBasicArguments(int argIdx, boolean isNullable, int rowStart, int rowEnd, long nullMapAddr,
        long columnAddr, long strOffsetAddr) {
    final boolean isUdf = false;
    Object[] converted = convertBasicArg(isUdf, argIdx, isNullable, rowStart, rowEnd, nullMapAddr, columnAddr,
            strOffsetAddr);
    return converted;
}
/**
 * Converts one array-typed argument column; a straight pass-through to {@code convertArrayArg}.
 */
public Object[] convertArrayArguments(int argIdx, boolean isNullable, int rowStart, int rowEnd, long nullMapAddr,
        long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) {
    Object[] converted = convertArrayArg(argIdx, isNullable, rowStart, rowEnd, nullMapAddr, offsetsAddr,
            nestedNullMapAddr, dataAddr, strOffsetAddr);
    return converted;
}
/**
 * Dispatches a batch of converted rows to the single-place or the per-row-place add path.
 *
 * @param isSinglePlace true when every row aggregates into the state at {@code placeAddr}
 * @param rowStart      first row of the batch (inclusive)
 * @param rowEnd        last row of the batch (exclusive)
 * @param placeAddr     the place address, or the address of the places array when not single
 * @param offset        value added to each place read from the places array (multi-place only)
 * @param column        the converted argument columns
 * @throws UdfRuntimeException if the underlying add invocation fails
 */
public void addBatch(boolean isSinglePlace, int rowStart, int rowEnd, long placeAddr, int offset, Object[] column)
        throws UdfRuntimeException {
    if (!isSinglePlace) {
        addBatchPlaces(rowStart, rowEnd, placeAddr, offset, column);
        return;
    }
    addBatchSingle(rowStart, rowEnd, placeAddr, column);
}
/**
 * Invokes the UDAF add method once per row in [rowStart, rowEnd), always against the aggregate
 * state registered for {@code placeAddr}. The state is created and stored on first use.
 *
 * @param rowStart  first row of the batch (inclusive)
 * @param rowEnd    last row of the batch (exclusive)
 * @param placeAddr key identifying the aggregate state in {@code stateObjMap}
 * @param column    converted argument columns; {@code column[j][i]} is argument j of the i-th row
 * @throws UdfRuntimeException wrapping any exception raised while preparing or invoking add
 */
public void addBatchSingle(int rowStart, int rowEnd, long placeAddr, Object[] column) throws UdfRuntimeException {
    try {
        Long curPlace = placeAddr;
        Object[] inputArgs = new Object[argTypes.length + 1];
        Object state = stateObjMap.get(curPlace);
        if (state == null) {
            state = createAggState();
            stateObjMap.put(curPlace, state);
        }
        // slot 0 always carries the aggregate state, the remaining slots the row's arguments
        inputArgs[0] = state;
        Object[][] inputs = (Object[][]) column;
        for (int i = 0; i < (rowEnd - rowStart); ++i) {
            for (int j = 0; j < column.length; ++j) {
                inputArgs[j + 1] = inputs[j][i];
            }
            methodAccess.invoke(udf, addIndex, inputArgs);
        }
    } catch (Exception e) {
        // Not every exception carries a cause (e.g. a ClassCastException from the cast above);
        // dereferencing a null cause here would raise a NullPointerException that hides the
        // real failure and bypasses the UdfRuntimeException wrapping.
        Throwable reported = e.getCause() != null ? e.getCause() : e;
        LOG.warn("invoke add function meet some error: " + reported.toString());
        throw new UdfRuntimeException("UDAF failed to addBatchSingle: ", e);
    }
}
/**
 * Invokes the UDAF add method once per row in [rowStart, rowEnd), where every row has its own
 * aggregate state. The state key of a row is the 8-byte value read at
 * {@code placeAddr + 8 * row} plus {@code offset}; missing states are created and stored.
 *
 * @param rowStart  first row of the batch (inclusive)
 * @param rowEnd    last row of the batch (exclusive)
 * @param placeAddr native address of the array of places
 * @param offset    value added to each place to form the state key
 * @param column    converted argument columns; {@code column[j][i]} is argument j of the i-th row
 * @throws UdfRuntimeException wrapping any exception raised while preparing or invoking add
 */
public void addBatchPlaces(int rowStart, int rowEnd, long placeAddr, int offset, Object[] column)
        throws UdfRuntimeException {
    try {
        final Object[][] argColumns = (Object[][]) column;
        final int rowCount = rowEnd - rowStart;
        final ArrayList<Object> states = new ArrayList<>(rowCount);
        // first pass: resolve (or create) the state object of every row
        for (int row = rowStart; row < rowEnd; ++row) {
            Long placeKey = UdfUtils.UNSAFE.getLong(null, placeAddr + (8L * row)) + offset;
            Object state = stateObjMap.get(placeKey);
            if (state == null) {
                state = createAggState();
                stateObjMap.put(placeKey, state);
            }
            states.add(state);
        }
        // second pass: invoke add row by row, reusing a single argument array
        final Object[] callArgs = new Object[argTypes.length + 1];
        for (int idx = 0; idx < rowCount; ++idx) {
            callArgs[0] = states.get(idx);
            for (int col = 0; col < column.length; ++col) {
                callArgs[col + 1] = argColumns[col][idx];
            }
            methodAccess.invoke(udf, addIndex, callArgs);
        }
    } catch (Exception e) {
        LOG.warn("invoke add function meet some error: " + Arrays.toString(e.getStackTrace()));
        throw new UdfRuntimeException("UDAF failed to addBatchPlaces: ", e);
    }
}
/**
* invoke add function, add row in loop [rowStart, rowEnd).
*/
@ -224,10 +306,10 @@ public class UdafExecutor extends BaseExecutor {
// Returns, as an unsigned value widened to long, the 32-bit int stored at index (row - 1) of the
// buffer whose base address is read from outputOffsetsPtr. Entries are addressed with an 8-byte
// stride when isArrayType is true and a 4-byte stride otherwise.
// NOTE(review): each UNSAFE.getInt line appears twice because the rendered diff shows both sides
// of what looks like a whitespace-only change - only one of each pair is real code.
protected long getCurrentOutputOffset(long row, boolean isArrayType) {
if (isArrayType) {
return Integer.toUnsignedLong(
UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * (row - 1)));
UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * (row - 1)));
} else {
return Integer.toUnsignedLong(
UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * (row - 1)));
UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * (row - 1)));
}
}
@ -251,6 +333,7 @@ public class UdafExecutor extends BaseExecutor {
loader = ClassLoader.getSystemClassLoader();
}
Class<?> c = Class.forName(className, true, loader);
methodAccess = MethodAccess.get(c);
Constructor<?> ctor = c.getConstructor();
udf = ctor.newInstance();
Method[] methods = c.getDeclaredMethods();
@ -281,7 +364,7 @@ public class UdafExecutor extends BaseExecutor {
}
case UDAF_ADD_FUNCTION: {
allMethods.put(methods[idx].getName(), methods[idx]);
addIndex = methodAccess.getIndex(UDAF_ADD_FUNCTION);
argClass = methods[idx].getParameterTypes();
if (argClass.length != parameterTypes.length + 1) {
LOG.debug("add function parameterTypes length not equal " + argClass.length + " "

View File

@ -37,263 +37,269 @@ import java.util.Arrays;
public class UdfConvert {
private static final Logger LOG = Logger.getLogger(UdfConvert.class);
/**
 * Reads {@code numRows} one-byte booleans starting at {@code columnAddr} into a Boolean array.
 * When {@code isNullable} is set, rows whose null-map byte is non-zero stay null.
 */
public static Object[] convertBooleanArg(boolean isNullable, int numRows, long nullMapAddr, long columnAddr) {
    final Boolean[] values = new Boolean[numRows];
    for (int row = 0; row < numRows; ++row) {
        // the short-circuit keeps the null map unread for non-nullable columns
        if (!isNullable || UdfUtils.UNSAFE.getByte(nullMapAddr + row) == 0) {
            values[row] = UdfUtils.UNSAFE.getBoolean(null, columnAddr + row);
        }
    }
    return values;
}
/**
 * Reads {@code numRows} bytes starting at {@code columnAddr} into a Byte array.
 * When {@code isNullable} is set, rows whose null-map byte is non-zero stay null.
 */
public static Object[] convertTinyIntArg(boolean isNullable, int numRows, long nullMapAddr, long columnAddr) {
    final Byte[] values = new Byte[numRows];
    for (int row = 0; row < numRows; ++row) {
        // the short-circuit keeps the null map unread for non-nullable columns
        if (!isNullable || UdfUtils.UNSAFE.getByte(nullMapAddr + row) == 0) {
            values[row] = UdfUtils.UNSAFE.getByte(null, columnAddr + row);
        }
    }
    return values;
}
/**
 * Reads {@code numRows} 2-byte shorts starting at {@code columnAddr} into a Short array.
 * When {@code isNullable} is set, rows whose null-map byte is non-zero stay null.
 */
public static Object[] convertSmallIntArg(boolean isNullable, int numRows, long nullMapAddr, long columnAddr) {
    final Short[] values = new Short[numRows];
    for (int row = 0; row < numRows; ++row) {
        // the short-circuit keeps the null map unread for non-nullable columns
        if (!isNullable || UdfUtils.UNSAFE.getByte(nullMapAddr + row) == 0) {
            values[row] = UdfUtils.UNSAFE.getShort(null, columnAddr + (row * 2L));
        }
    }
    return values;
}
/**
 * Reads {@code numRows} 4-byte ints starting at {@code columnAddr} into an Integer array.
 * When {@code isNullable} is set, rows whose null-map byte is non-zero stay null.
 */
public static Object[] convertIntArg(boolean isNullable, int numRows, long nullMapAddr, long columnAddr) {
    final Integer[] values = new Integer[numRows];
    for (int row = 0; row < numRows; ++row) {
        // the short-circuit keeps the null map unread for non-nullable columns
        if (!isNullable || UdfUtils.UNSAFE.getByte(nullMapAddr + row) == 0) {
            values[row] = UdfUtils.UNSAFE.getInt(null, columnAddr + (row * 4L));
        }
    }
    return values;
}
/**
 * Reads {@code numRows} 8-byte longs starting at {@code columnAddr} into a Long array.
 * When {@code isNullable} is set, rows whose null-map byte is non-zero stay null.
 */
public static Object[] convertBigIntArg(boolean isNullable, int numRows, long nullMapAddr, long columnAddr) {
    final Long[] values = new Long[numRows];
    for (int row = 0; row < numRows; ++row) {
        // the short-circuit keeps the null map unread for non-nullable columns
        if (!isNullable || UdfUtils.UNSAFE.getByte(nullMapAddr + row) == 0) {
            values[row] = UdfUtils.UNSAFE.getLong(null, columnAddr + (row * 8L));
        }
    }
    return values;
}
/**
 * Reads {@code numRows} 4-byte floats starting at {@code columnAddr} into a Float array.
 * When {@code isNullable} is set, rows whose null-map byte is non-zero stay null.
 */
public static Object[] convertFloatArg(boolean isNullable, int numRows, long nullMapAddr, long columnAddr) {
    final Float[] values = new Float[numRows];
    for (int row = 0; row < numRows; ++row) {
        // the short-circuit keeps the null map unread for non-nullable columns
        if (!isNullable || UdfUtils.UNSAFE.getByte(nullMapAddr + row) == 0) {
            values[row] = UdfUtils.UNSAFE.getFloat(null, columnAddr + (row * 4L));
        }
    }
    return values;
}
/**
 * Reads {@code numRows} 8-byte doubles starting at {@code columnAddr} into a Double array.
 * When {@code isNullable} is set, rows whose null-map byte is non-zero stay null.
 */
public static Object[] convertDoubleArg(boolean isNullable, int numRows, long nullMapAddr, long columnAddr) {
    final Double[] values = new Double[numRows];
    for (int row = 0; row < numRows; ++row) {
        // the short-circuit keeps the null map unread for non-nullable columns
        if (!isNullable || UdfUtils.UNSAFE.getByte(nullMapAddr + row) == 0) {
            values[row] = UdfUtils.UNSAFE.getDouble(null, columnAddr + (row * 8L));
        }
    }
    return values;
}
// Row-range variant of convertBooleanArg: converts rows [rowsStart, rowsEnd) of a one-byte-per-row
// boolean column into a Boolean array indexed from 0; when isNullable is set, rows whose null-map
// byte is non-zero are left as null.
// NOTE(review): rendered diff without +/- markers - the lines that mention argTypeClass or numRows
// are leftovers of a different, numRows-based convertDateArg and are not part of this method.
public static Object[] convertDateArg(Class argTypeClass, boolean isNullable, int numRows, long nullMapAddr,
public static Object[] convertBooleanArg(boolean isNullable, int rowsStart, int rowsEnd, long nullMapAddr,
long columnAddr) {
Object[] argument = (Object[]) Array.newInstance(argTypeClass, numRows);
Boolean[] argument = new Boolean[rowsEnd - rowsStart];
if (isNullable) {
for (int i = 0; i < numRows; ++i) {
for (int i = rowsStart; i < rowsEnd; ++i) {
if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) {
argument[i - rowsStart] = UdfUtils.UNSAFE.getBoolean(null, columnAddr + i);
} // else: the current row is null, its slot stays null
}
} else {
for (int i = rowsStart; i < rowsEnd; ++i) {
argument[i - rowsStart] = UdfUtils.UNSAFE.getBoolean(null, columnAddr + i);
}
}
return argument;
}
/**
 * Converts rows [rowsStart, rowsEnd) of a one-byte-per-row column into a Byte array indexed
 * from 0. When {@code isNullable} is set, rows whose null-map byte is non-zero stay null.
 */
public static Object[] convertTinyIntArg(boolean isNullable, int rowsStart, int rowsEnd, long nullMapAddr,
        long columnAddr) {
    final Byte[] values = new Byte[rowsEnd - rowsStart];
    for (int row = rowsStart; row < rowsEnd; ++row) {
        // the short-circuit keeps the null map unread for non-nullable columns
        if (!isNullable || UdfUtils.UNSAFE.getByte(nullMapAddr + row) == 0) {
            values[row - rowsStart] = UdfUtils.UNSAFE.getByte(null, columnAddr + row);
        }
    }
    return values;
}
/**
 * Converts rows [rowsStart, rowsEnd) of a 2-byte-per-row column into a Short array indexed
 * from 0. When {@code isNullable} is set, rows whose null-map byte is non-zero stay null.
 */
public static Object[] convertSmallIntArg(boolean isNullable, int rowsStart, int rowsEnd, long nullMapAddr,
        long columnAddr) {
    final Short[] values = new Short[rowsEnd - rowsStart];
    for (int row = rowsStart; row < rowsEnd; ++row) {
        // the short-circuit keeps the null map unread for non-nullable columns
        if (!isNullable || UdfUtils.UNSAFE.getByte(nullMapAddr + row) == 0) {
            values[row - rowsStart] = UdfUtils.UNSAFE.getShort(null, columnAddr + (row * 2L));
        }
    }
    return values;
}
/**
 * Converts rows [rowsStart, rowsEnd) of a 4-byte-per-row column into an Integer array indexed
 * from 0. When {@code isNullable} is set, rows whose null-map byte is non-zero stay null.
 */
public static Object[] convertIntArg(boolean isNullable, int rowsStart, int rowsEnd, long nullMapAddr,
        long columnAddr) {
    final Integer[] values = new Integer[rowsEnd - rowsStart];
    for (int row = rowsStart; row < rowsEnd; ++row) {
        // the short-circuit keeps the null map unread for non-nullable columns
        if (!isNullable || UdfUtils.UNSAFE.getByte(nullMapAddr + row) == 0) {
            values[row - rowsStart] = UdfUtils.UNSAFE.getInt(null, columnAddr + (row * 4L));
        }
    }
    return values;
}
/**
 * Converts rows [rowsStart, rowsEnd) of an 8-byte-per-row column into a Long array indexed
 * from 0. When {@code isNullable} is set, rows whose null-map byte is non-zero stay null.
 */
public static Object[] convertBigIntArg(boolean isNullable, int rowsStart, int rowsEnd, long nullMapAddr,
        long columnAddr) {
    final Long[] values = new Long[rowsEnd - rowsStart];
    for (int row = rowsStart; row < rowsEnd; ++row) {
        // the short-circuit keeps the null map unread for non-nullable columns
        if (!isNullable || UdfUtils.UNSAFE.getByte(nullMapAddr + row) == 0) {
            values[row - rowsStart] = UdfUtils.UNSAFE.getLong(null, columnAddr + (row * 8L));
        }
    }
    return values;
}
/**
 * Converts rows [rowsStart, rowsEnd) of a 4-byte-per-row column into a Float array indexed
 * from 0. When {@code isNullable} is set, rows whose null-map byte is non-zero stay null.
 */
public static Object[] convertFloatArg(boolean isNullable, int rowsStart, int rowsEnd, long nullMapAddr,
        long columnAddr) {
    final Float[] values = new Float[rowsEnd - rowsStart];
    for (int row = rowsStart; row < rowsEnd; ++row) {
        // the short-circuit keeps the null map unread for non-nullable columns
        if (!isNullable || UdfUtils.UNSAFE.getByte(nullMapAddr + row) == 0) {
            values[row - rowsStart] = UdfUtils.UNSAFE.getFloat(null, columnAddr + (row * 4L));
        }
    }
    return values;
}
/**
 * Converts rows [rowsStart, rowsEnd) of an 8-byte-per-row column into a Double array indexed
 * from 0. When {@code isNullable} is set, rows whose null-map byte is non-zero stay null.
 */
public static Object[] convertDoubleArg(boolean isNullable, int rowsStart, int rowsEnd, long nullMapAddr,
        long columnAddr) {
    final Double[] values = new Double[rowsEnd - rowsStart];
    for (int row = rowsStart; row < rowsEnd; ++row) {
        // the short-circuit keeps the null map unread for non-nullable columns
        if (!isNullable || UdfUtils.UNSAFE.getByte(nullMapAddr + row) == 0) {
            values[row - rowsStart] = UdfUtils.UNSAFE.getDouble(null, columnAddr + (row * 8L));
        }
    }
    return values;
}
// Converts rows [rowsStart, rowsEnd) of an 8-byte-per-row date column into an array of
// argTypeClass indexed from 0, mapping each raw long through UdfUtils.convertDateToJavaDate.
// When isNullable is set, rows whose null-map byte is non-zero are left as null.
// NOTE(review): rendered diff without +/- markers - the "argument[i] = ..." stores and the
// "i < numRows" loop are leftovers of the numRows-based revision of this method.
public static Object[] convertDateArg(Class argTypeClass, boolean isNullable, int rowsStart, int rowsEnd,
long nullMapAddr, long columnAddr) {
Object[] argument = (Object[]) Array.newInstance(argTypeClass, rowsEnd - rowsStart);
if (isNullable) {
for (int i = rowsStart; i < rowsEnd; ++i) {
if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) {
long value = UdfUtils.UNSAFE.getLong(null, columnAddr + (i * 8L));
argument[i] = UdfUtils.convertDateToJavaDate(value, argTypeClass);
argument[i - rowsStart] = UdfUtils.convertDateToJavaDate(value, argTypeClass);
} // else: the current row is null, its slot stays null
}
} else {
for (int i = 0; i < numRows; ++i) {
for (int i = rowsStart; i < rowsEnd; ++i) {
long value = UdfUtils.UNSAFE.getLong(null, columnAddr + (i * 8L));
argument[i] = UdfUtils.convertDateToJavaDate(value, argTypeClass);
argument[i - rowsStart] = UdfUtils.convertDateToJavaDate(value, argTypeClass);
}
}
return argument;
}
// Converts an 8-byte-per-row datetime column into an array of argTypeClass, mapping each raw long
// through UdfUtils.convertDateTimeToJavaDateTime; when isNullable is set, rows whose null-map byte
// is non-zero are left as null.
// NOTE(review): rendered diff without +/- markers - two revisions are interleaved: the one with a
// numRows parameter (loops over [0, numRows), stores at argument[i]) and the one with
// rowsStart/rowsEnd (loops over [rowsStart, rowsEnd), stores at argument[i - rowsStart]).
public static Object[] convertDateTimeArg(Class argTypeClass, boolean isNullable, int numRows, long nullMapAddr,
long columnAddr) {
Object[] argument = (Object[]) Array.newInstance(argTypeClass, numRows);
public static Object[] convertDateTimeArg(Class argTypeClass, boolean isNullable, int rowsStart, int rowsEnd,
long nullMapAddr, long columnAddr) {
Object[] argument = (Object[]) Array.newInstance(argTypeClass, rowsEnd - rowsStart);
if (isNullable) {
for (int i = 0; i < numRows; ++i) {
for (int i = rowsStart; i < rowsEnd; ++i) {
if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) {
long value = UdfUtils.UNSAFE.getLong(null, columnAddr + (i * 8L));
argument[i] = UdfUtils
argument[i - rowsStart] = UdfUtils
.convertDateTimeToJavaDateTime(value, argTypeClass);
} // else: the current row is null, its slot stays null
}
} else {
for (int i = 0; i < numRows; ++i) {
for (int i = rowsStart; i < rowsEnd; ++i) {
long value = UdfUtils.UNSAFE.getLong(null, columnAddr + (i * 8L));
argument[i] = UdfUtils.convertDateTimeToJavaDateTime(value, argTypeClass);
argument[i - rowsStart] = UdfUtils.convertDateTimeToJavaDateTime(value, argTypeClass);
}
}
return argument;
}
// Converts a 4-byte-per-row datev2 column into an array of argTypeClass, mapping each raw int
// through UdfUtils.convertDateV2ToJavaDate; when isNullable is set, rows whose null-map byte is
// non-zero are left as null.
// NOTE(review): rendered diff without +/- markers - two revisions are interleaved: the one with a
// numRows parameter (loops over [0, numRows), stores at argument[i]) and the one with
// rowsStart/rowsEnd (loops over [rowsStart, rowsEnd), stores at argument[i - rowsStart]).
public static Object[] convertDateV2Arg(Class argTypeClass, boolean isNullable, int numRows, long nullMapAddr,
long columnAddr) {
Object[] argument = (Object[]) Array.newInstance(argTypeClass, numRows);
public static Object[] convertDateV2Arg(Class argTypeClass, boolean isNullable, int rowsStart, int rowsEnd,
long nullMapAddr, long columnAddr) {
Object[] argument = (Object[]) Array.newInstance(argTypeClass, rowsEnd - rowsStart);
if (isNullable) {
for (int i = 0; i < numRows; ++i) {
for (int i = rowsStart; i < rowsEnd; ++i) {
if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) {
int value = UdfUtils.UNSAFE.getInt(null, columnAddr + (i * 4L));
argument[i] = UdfUtils.convertDateV2ToJavaDate(value, argTypeClass);
argument[i - rowsStart] = UdfUtils.convertDateV2ToJavaDate(value, argTypeClass);
} // else: the current row is null, its slot stays null
}
} else {
for (int i = 0; i < numRows; ++i) {
for (int i = rowsStart; i < rowsEnd; ++i) {
int value = UdfUtils.UNSAFE.getInt(null, columnAddr + (i * 4L));
argument[i] = UdfUtils.convertDateV2ToJavaDate(value, argTypeClass);
argument[i - rowsStart] = UdfUtils.convertDateV2ToJavaDate(value, argTypeClass);
}
}
return argument;
}
public static Object[] convertDateTimeV2Arg(Class argTypeClass, boolean isNullable, int numRows, long nullMapAddr,
long columnAddr) {
Object[] argument = (Object[]) Array.newInstance(argTypeClass, numRows);
public static Object[] convertDateTimeV2Arg(Class argTypeClass, boolean isNullable, int rowsStart, int rowsEnd,
long nullMapAddr, long columnAddr) {
Object[] argument = (Object[]) Array.newInstance(argTypeClass, rowsEnd - rowsStart);
if (isNullable) {
for (int i = 0; i < numRows; ++i) {
for (int i = rowsStart; i < rowsEnd; ++i) {
if (UdfUtils.UNSAFE.getByte(null, nullMapAddr + i) == 0) {
long value = UdfUtils.UNSAFE.getLong(columnAddr + (i * 8L));
argument[i] = UdfUtils
argument[i - rowsStart] = UdfUtils
.convertDateTimeV2ToJavaDateTime(value, argTypeClass);
} // else is the current row is null
}
} else {
for (int i = 0; i < numRows; ++i) {
for (int i = rowsStart; i < rowsEnd; ++i) {
long value = UdfUtils.UNSAFE.getLong(null, columnAddr + (i * 8L));
argument[i] = UdfUtils
argument[i - rowsStart] = UdfUtils
.convertDateTimeV2ToJavaDateTime(value, argTypeClass);
}
}
return argument;
}
public static Object[] convertLargeIntArg(boolean isNullable, int numRows, long nullMapAddr, long columnAddr) {
BigInteger[] argument = new BigInteger[numRows];
public static Object[] convertLargeIntArg(boolean isNullable, int rowsStart, int rowsEnd, long nullMapAddr,
long columnAddr) {
BigInteger[] argument = new BigInteger[rowsEnd - rowsStart];
byte[] bytes = new byte[16];
if (isNullable) {
for (int i = 0; i < numRows; ++i) {
for (int i = rowsStart; i < rowsEnd; ++i) {
if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) {
UdfUtils.copyMemory(null, columnAddr + (i * 16L), bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16);
argument[i] = new BigInteger(UdfUtils.convertByteOrder(bytes));
argument[i - rowsStart] = new BigInteger(UdfUtils.convertByteOrder(bytes));
} // else is the current row is null
}
} else {
for (int i = 0; i < numRows; ++i) {
for (int i = rowsStart; i < rowsEnd; ++i) {
UdfUtils.copyMemory(null, columnAddr + (i * 16L), bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16);
argument[i] = new BigInteger(UdfUtils.convertByteOrder(bytes));
argument[i - rowsStart] = new BigInteger(UdfUtils.convertByteOrder(bytes));
}
}
return argument;
}
public static Object[] convertDecimalArg(int scale, long typeLen, boolean isNullable, int numRows, long nullMapAddr,
long columnAddr) {
BigDecimal[] argument = new BigDecimal[numRows];
public static Object[] convertDecimalArg(int scale, long typeLen, boolean isNullable, int rowsStart, int rowsEnd,
long nullMapAddr, long columnAddr) {
BigDecimal[] argument = new BigDecimal[rowsEnd - rowsStart];
byte[] bytes = new byte[(int) typeLen];
if (isNullable) {
for (int i = 0; i < numRows; ++i) {
for (int i = rowsStart; i < rowsEnd; ++i) {
if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) {
UdfUtils.copyMemory(null, columnAddr + (i * typeLen), bytes, UdfUtils.BYTE_ARRAY_OFFSET, typeLen);
BigInteger bigInteger = new BigInteger(UdfUtils.convertByteOrder(bytes));
argument[i] = new BigDecimal(bigInteger, scale); //show to pass scale info
argument[i - rowsStart] = new BigDecimal(bigInteger, scale); //show to pass scale info
} // else is the current row is null
}
} else {
for (int i = 0; i < numRows; ++i) {
for (int i = rowsStart; i < rowsEnd; ++i) {
UdfUtils.copyMemory(null, columnAddr + (i * typeLen), bytes, UdfUtils.BYTE_ARRAY_OFFSET, typeLen);
BigInteger bigInteger = new BigInteger(UdfUtils.convertByteOrder(bytes));
argument[i] = new BigDecimal(bigInteger, scale);
argument[i - rowsStart] = new BigDecimal(bigInteger, scale);
}
}
return argument;
}
public static Object[] convertStringArg(boolean isNullable, int numRows, long nullMapAddr,
public static Object[] convertStringArg(boolean isNullable, int rowsStart, int rowsEnd, long nullMapAddr,
long charsAddr, long offsetsAddr) {
String[] argument = new String[numRows];
String[] argument = new String[rowsEnd - rowsStart];
Preconditions.checkState(UdfUtils.UNSAFE.getInt(null, offsetsAddr + 4L * (0 - 1)) == 0,
"offsetsAddr[-1] should be 0;");
final int totalLen = UdfUtils.UNSAFE.getInt(null, offsetsAddr + (rowsEnd - 1) * 4L);
byte[] bytes = new byte[totalLen];
UdfUtils.copyMemory(null, charsAddr, bytes, UdfUtils.BYTE_ARRAY_OFFSET, totalLen);
if (isNullable) {
for (int row = 0; row < numRows; ++row) {
for (int row = rowsStart; row < rowsEnd; ++row) {
if (UdfUtils.UNSAFE.getByte(nullMapAddr + row) == 0) {
int offset = UdfUtils.UNSAFE.getInt(null, offsetsAddr + row * 4L);
int numBytes = offset - UdfUtils.UNSAFE.getInt(null, offsetsAddr + 4L * (row - 1));
long base = charsAddr + offset - numBytes;
byte[] bytes = new byte[numBytes];
UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes);
argument[row] = new String(bytes, StandardCharsets.UTF_8);
int prevOffset = UdfUtils.UNSAFE.getInt(null, offsetsAddr + 4L * (row - 1));
int currOffset = UdfUtils.UNSAFE.getInt(null, offsetsAddr + row * 4L);
argument[row - rowsStart] = new String(bytes, prevOffset, currOffset - prevOffset,
StandardCharsets.UTF_8);
} // else is the current row is null
}
} else {
for (int row = 0; row < numRows; ++row) {
int offset = UdfUtils.UNSAFE.getInt(null, offsetsAddr + row * 4L);
int numBytes = offset - UdfUtils.UNSAFE.getInt(null, offsetsAddr + 4L * (row - 1));
long base = charsAddr + offset - numBytes;
byte[] bytes = new byte[numBytes];
UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes);
argument[row] = new String(bytes, StandardCharsets.UTF_8);
for (int row = rowsStart; row < rowsEnd; ++row) {
int prevOffset = UdfUtils.UNSAFE.getInt(null, offsetsAddr + 4L * (row - 1));
int currOffset = UdfUtils.UNSAFE.getInt(null, offsetsAddr + 4L * row);
argument[row - rowsStart] = new String(bytes, prevOffset, currOffset - prevOffset,
StandardCharsets.UTF_8);
}
}
return argument;
@ -1314,7 +1320,7 @@ public class UdfConvert {
}
//////////////////////////////////////////convertArray///////////////////////////////////////////////////////////
public static void convertArrayBooleanArg(Object[] argument, int row, int currentRowNum, long offsetStart,
public static ArrayList<Boolean> convertArrayBooleanArg(int row, int currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<Boolean> data = null;
if (isNullable) {
@ -1340,10 +1346,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
argument[row] = data;
return data;
}
public static void convertArrayTinyIntArg(Object[] argument, int row, int currentRowNum, long offsetStart,
public static ArrayList<Byte> convertArrayTinyIntArg(int row, int currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<Byte> data = null;
if (isNullable) {
@ -1369,10 +1375,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
argument[row] = data;
return data;
}
public static void convertArraySmallIntArg(Object[] argument, int row, int currentRowNum, long offsetStart,
public static ArrayList<Short> convertArraySmallIntArg(int row, int currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<Short> data = null;
if (isNullable) {
@ -1398,10 +1404,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
argument[row] = data;
return data;
}
public static void convertArrayIntArg(Object[] argument, int row, int currentRowNum, long offsetStart,
public static ArrayList<Integer> convertArrayIntArg(int row, int currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<Integer> data = null;
if (isNullable) {
@ -1427,10 +1433,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
argument[row] = data;
return data;
}
public static void convertArrayBigIntArg(Object[] argument, int row, int currentRowNum, long offsetStart,
public static ArrayList<Long> convertArrayBigIntArg(int row, int currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<Long> data = null;
if (isNullable) {
@ -1456,10 +1462,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
argument[row] = data;
return data;
}
public static void convertArrayFloatArg(Object[] argument, int row, int currentRowNum, long offsetStart,
public static ArrayList<Float> convertArrayFloatArg(int row, int currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<Float> data = null;
if (isNullable) {
@ -1485,10 +1491,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
argument[row] = data;
return data;
}
public static void convertArrayDoubleArg(Object[] argument, int row, int currentRowNum, long offsetStart,
public static ArrayList<Double> convertArrayDoubleArg(int row, int currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<Double> data = null;
if (isNullable) {
@ -1514,10 +1520,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
argument[row] = data;
return data;
}
public static void convertArrayDateArg(Object[] argument, int row, int currentRowNum, long offsetStart,
public static ArrayList<LocalDate> convertArrayDateArg(int row, int currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<LocalDate> data = null;
if (isNullable) {
@ -1549,10 +1555,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
argument[row] = data;
return data;
}
public static void convertArrayDateTimeArg(Object[] argument, int row, int currentRowNum, long offsetStart,
public static ArrayList<LocalDateTime> convertArrayDateTimeArg(int row, int currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<LocalDateTime> data = null;
if (isNullable) {
@ -1582,10 +1588,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
argument[row] = data;
return data;
}
public static void convertArrayDateV2Arg(Object[] argument, int row, int currentRowNum, long offsetStart,
public static ArrayList<LocalDate> convertArrayDateV2Arg(int row, int currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<LocalDate> data = null;
if (isNullable) {
@ -1613,10 +1619,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
argument[row] = data;
return data;
}
public static void convertArrayDateTimeV2Arg(Object[] argument, int row, int currentRowNum, long offsetStart,
public static ArrayList<LocalDateTime> convertArrayDateTimeV2Arg(int row, int currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<LocalDateTime> data = null;
if (isNullable) {
@ -1646,10 +1652,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
argument[row] = data;
return data;
}
public static void convertArrayLargeIntArg(Object[] argument, int row, int currentRowNum, long offsetStart,
public static ArrayList<BigInteger> convertArrayLargeIntArg(int row, int currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<BigInteger> data = null;
byte[] bytes = new byte[16];
@ -1678,10 +1684,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
argument[row] = data;
return data;
}
public static void convertArrayDecimalArg(int scale, long typeLen, Object[] argument, int row, int currentRowNum,
public static ArrayList<BigDecimal> convertArrayDecimalArg(int scale, long typeLen, int row, int currentRowNum,
long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) {
ArrayList<BigDecimal> data = null;
@ -1713,10 +1719,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
argument[row] = data;
return data;
}
public static void convertArrayStringArg(Object[] argument, int row, int currentRowNum, long offsetStart,
public static ArrayList<String> convertArrayStringArg(int row, int currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) {
ArrayList<String> data = null;
if (isNullable) {
@ -1755,6 +1761,6 @@ public class UdfConvert {
}
}
}
argument[row] = data;
return data;
}
}

View File

@ -24,6 +24,7 @@ import org.apache.doris.common.jni.utils.UdfUtils;
import org.apache.doris.common.jni.utils.UdfUtils.JavaUdfDataType;
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
import com.esotericsoftware.reflectasm.MethodAccess;
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
@ -50,6 +51,8 @@ public class UdfExecutor extends BaseExecutor {
private long rowIdx;
private long batchSizePtr;
private int evaluateIndex;
private MethodAccess methodAccess;
/**
* Create a UdfExecutor, using parameters from a serialized thrift object. Used by
@ -113,166 +116,14 @@ public class UdfExecutor extends BaseExecutor {
/**
 * Converts every row of the {@code argIdx}-th basic-typed input column into Java objects.
 *
 * <p>This delegates to {@code convertBasicArg} over the full row range {@code [0, numRows)},
 * so the per-type dispatch is not repeated in this class.
 *
 * @param argIdx        index of the argument being converted
 * @param isNullable    whether the column carries a null map
 * @param numRows       number of rows to convert
 * @param nullMapAddr   native address of the null map
 * @param columnAddr    native address of the column data
 * @param strOffsetAddr native address of the string offsets, forwarded unchanged
 * @return whatever {@code convertBasicArg} returns for the range {@code [0, numRows)}
 */
public Object[] convertBasicArguments(int argIdx, boolean isNullable, int numRows, long nullMapAddr,
        long columnAddr, long strOffsetAddr) {
    // NOTE(review): the leading `true` presumably selects the UDF (not UDAF) path inside
    // convertBasicArg - confirm against its declaration.
    return convertBasicArg(true, argIdx, isNullable, 0, numRows, nullMapAddr, columnAddr, strOffsetAddr);
}
/**
 * Converts every row of the {@code argIdx}-th array-typed input column into Java objects.
 *
 * <p>This delegates to {@code convertArrayArg} over the full row range {@code [0, numRows)},
 * so the per-item-type dispatch is not repeated in this class.
 *
 * @param argIdx            index of the argument being converted
 * @param isNullable        whether the column carries a null map
 * @param numRows           number of rows to convert
 * @param nullMapAddr       native address of the outer null map
 * @param offsetsAddr       native address of the array offsets
 * @param nestedNullMapAddr native address of the nested (element) null map
 * @param dataAddr          native address of the nested element data
 * @param strOffsetAddr     native address of the string offsets, forwarded unchanged
 * @return whatever {@code convertArrayArg} returns for the range {@code [0, numRows)}
 */
public Object[] convertArrayArguments(int argIdx, boolean isNullable, int numRows, long nullMapAddr,
        long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) {
    return convertArrayArg(argIdx, isNullable, 0, numRows, nullMapAddr, offsetsAddr, nestedNullMapAddr, dataAddr,
            strOffsetAddr);
}
/**
@ -287,7 +138,7 @@ public class UdfExecutor extends BaseExecutor {
for (int j = 0; j < column.length; ++j) {
parameters[j] = inputs[j][i];
}
result[i] = method.invoke(udf, parameters);
result[i] = methodAccess.invoke(udf, evaluateIndex, parameters);
}
return result;
} catch (Exception e) {
@ -581,6 +432,7 @@ public class UdfExecutor extends BaseExecutor {
loader = ClassLoader.getSystemClassLoader();
}
Class<?> c = Class.forName(className, true, loader);
methodAccess = MethodAccess.get(c);
Constructor<?> ctor = c.getConstructor();
udf = ctor.newInstance();
Method[] methods = c.getMethods();
@ -597,6 +449,7 @@ public class UdfExecutor extends BaseExecutor {
continue;
}
method = m;
evaluateIndex = methodAccess.getIndex(UDF_FUNCTION_NAME);
Pair<Boolean, JavaUdfDataType> returnType;
if (argClass.length == 0 && parameterTypes.length == 0) {
// Special case where the UDF doesn't take any input args