[refactor](jni) unified jni framework for java udf (#25302)
Use the unified JNI framework to refactor the Java UDF. The unified JNI framework uses VectorTable as the container to transfer data between C++ and Java and hides the details of data format conversion. In addition, the unified framework supports complex and nested types. The performance of basic types remains consistent, with a 30% improvement for string types and an order-of-magnitude improvement for complex types.
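For context, below is a minimal, self-contained sketch of the per-row dispatch that the refactored executor performs after materializing input columns from a VectorTable. The MyUpperUdf class, the literal input array, and the main method are hypothetical placeholders for illustration; only the reflectasm MethodAccess usage and the row loop mirror the evaluate() code in this patch. The real executor obtains its columns via VectorTable.createReadableTable(inputParams).getMaterializedData(...) and hands results back through VectorTable.createWritableTable(outputParams, numRows) and getMetaAddress().

import com.esotericsoftware.reflectasm.MethodAccess;

import java.util.Arrays;

public class EvaluateSketch {
    // Hypothetical user UDF; real UDFs are loaded from the user's jar by UdfExecutor.
    public static class MyUpperUdf {
        public String evaluate(String s) {
            return s == null ? null : s.toUpperCase();
        }
    }

    public static void main(String[] args) {
        MyUpperUdf udf = new MyUpperUdf();
        MethodAccess methodAccess = MethodAccess.get(MyUpperUdf.class);
        // Resolved once and cached, like the evaluateIndex field in UdfExecutor.
        int evaluateIndex = methodAccess.getIndex("evaluate");

        // Stand-in for inputTable.getMaterializedData(...): one Object[] per input column.
        Object[][] inputs = { { "a", null, "c" } };
        int numRows = 3;
        int numColumns = inputs.length;

        Object[] result = new Object[numRows];
        Object[] parameters = new Object[numColumns];
        for (int i = 0; i < numRows; ++i) {
            for (int j = 0; j < numColumns; ++j) {
                parameters[j] = inputs[j][i];
            }
            // Same per-row call the executor makes: methodAccess.invoke(udf, evaluateIndex, parameters).
            result[i] = methodAccess.invoke(udf, evaluateIndex, parameters);
        }
        // The real executor appends the result array to a writable VectorTable and returns its meta address.
        System.out.println(Arrays.toString(result)); // prints [A, null, C]
    }
}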
@@ -17,17 +17,18 @@
package org.apache.doris.udf;

import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Pair;
import org.apache.doris.common.exception.UdfRuntimeException;
import org.apache.doris.common.jni.utils.JavaUdfDataType;
import org.apache.doris.common.jni.utils.UdfUtils;
import org.apache.doris.common.jni.vec.ColumnValueConverter;
import org.apache.doris.common.jni.vec.VectorTable;
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
import org.apache.doris.thrift.TPrimitiveType;

import com.esotericsoftware.reflectasm.MethodAccess;
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.apache.log4j.Logger;
@@ -35,7 +36,11 @@ import java.lang.reflect.Array;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

public class UdfExecutor extends BaseExecutor {
    // private static final java.util.logging.Logger LOG =
@@ -46,6 +51,8 @@ public class UdfExecutor extends BaseExecutor {

    private int evaluateIndex;

    private VectorTable outputTable = null;

    /**
     * Create a UdfExecutor, using parameters from a serialized thrift object. Used by
     * the backend.
@@ -59,98 +66,195 @@ public class UdfExecutor extends BaseExecutor {
     */
    @Override
    public void close() {
        // inputTable is released by c++, only release outputTable
        if (outputTable != null) {
            outputTable.close();
        }
        // We are now un-usable (because the class loader has been
        // closed), so null out method_ and classLoader_.
        method = null;
        super.close();
    }

    public Object[] convertBasicArguments(int argIdx, boolean isNullable, int numRows, long nullMapAddr,
            long columnAddr, long strOffsetAddr) {
        return convertBasicArg(true, argIdx, isNullable, 0, numRows, nullMapAddr, columnAddr, strOffsetAddr);
    private ColumnValueConverter getInputConverter(TPrimitiveType primitiveType, Class clz) {
        switch (primitiveType) {
            case DATE:
            case DATEV2: {
                if (java.util.Date.class.equals(clz)) {
                    return (Object[] columnData) -> {
                        Object[] result = new java.util.Date[columnData.length];
                        for (int i = 0; i < columnData.length; ++i) {
                            if (columnData[i] != null) {
                                LocalDate v = (LocalDate) columnData[i];
                                result[i] = new java.util.Date(v.getYear() - 1900, v.getMonthValue() - 1,
                                        v.getDayOfMonth());
                            }
                        }
                        return result;
                    };
                } else if (org.joda.time.LocalDate.class.equals(clz)) {
                    return (Object[] columnData) -> {
                        Object[] result = new org.joda.time.LocalDate[columnData.length];
                        for (int i = 0; i < columnData.length; ++i) {
                            if (columnData[i] != null) {
                                LocalDate v = (LocalDate) columnData[i];
                                result[i] = new org.joda.time.LocalDate(v.getYear(), v.getMonthValue(),
                                        v.getDayOfMonth());
                            }
                        }
                        return result;
                    };
                } else if (!LocalDate.class.equals(clz)) {
                    throw new RuntimeException("Unsupported date type: " + clz.getCanonicalName());
                }
                break;
            }
            case DATETIME:
            case DATETIMEV2: {
                if (org.joda.time.DateTime.class.equals(clz)) {
                    return (Object[] columnData) -> {
                        Object[] result = new org.joda.time.DateTime[columnData.length];
                        for (int i = 0; i < columnData.length; ++i) {
                            if (columnData[i] != null) {
                                LocalDateTime v = (LocalDateTime) columnData[i];
                                result[i] = new org.joda.time.DateTime(v.getYear(), v.getMonthValue(),
                                        v.getDayOfMonth(), v.getHour(),
                                        v.getMinute(), v.getSecond(), v.getNano() / 1000000);
                            }
                        }
                        return result;
                    };
                } else if (org.joda.time.LocalDateTime.class.equals(clz)) {
                    return (Object[] columnData) -> {
                        Object[] result = new org.joda.time.LocalDateTime[columnData.length];
                        for (int i = 0; i < columnData.length; ++i) {
                            if (columnData[i] != null) {
                                LocalDateTime v = (LocalDateTime) columnData[i];
                                result[i] = new org.joda.time.LocalDateTime(v.getYear(), v.getMonthValue(),
                                        v.getDayOfMonth(), v.getHour(),
                                        v.getMinute(), v.getSecond(), v.getNano() / 1000000);
                            }
                        }
                        return result;
                    };
                } else if (!LocalDateTime.class.equals(clz)) {
                    throw new RuntimeException("Unsupported date type: " + clz.getCanonicalName());
                }
                break;
            }
            default:
                break;
        }
        return null;
    }

    public Object[] convertArrayArguments(int argIdx, boolean isNullable, int numRows, long nullMapAddr,
            long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) {
        return convertArrayArg(argIdx, isNullable, 0, numRows, nullMapAddr, offsetsAddr, nestedNullMapAddr, dataAddr,
                strOffsetAddr);
    private Map<Integer, ColumnValueConverter> getInputConverters(int numColumns) {
        Map<Integer, ColumnValueConverter> converters = new HashMap<>();
        for (int j = 0; j < numColumns; ++j) {
            ColumnValueConverter converter = getInputConverter(argTypes[j].getPrimitiveType(), argClass[j]);
            if (converter != null) {
                converters.put(j, converter);
            }
        }
        return converters;
    }

    public Object[] convertMapArguments(int argIdx, boolean isNullable, int numRows, long nullMapAddr,
            long offsetsAddr, long keyNestedNullMapAddr, long keyDataAddr, long keyStrOffsetAddr,
            long valueNestedNullMapAddr, long valueDataAddr, long valueStrOffsetAddr) {
        PrimitiveType keyType = argTypes[argIdx].getKeyType().getPrimitiveType();
        PrimitiveType valueType = argTypes[argIdx].getValueType().getPrimitiveType();
        Object[] keyCol = convertMapArg(keyType, argIdx, isNullable, 0, numRows, nullMapAddr, offsetsAddr,
                keyNestedNullMapAddr, keyDataAddr,
                keyStrOffsetAddr, argTypes[argIdx].getKeyScale());
        Object[] valueCol = convertMapArg(valueType, argIdx, isNullable, 0, numRows, nullMapAddr, offsetsAddr,
                valueNestedNullMapAddr, valueDataAddr,
                valueStrOffsetAddr, argTypes[argIdx].getValueScale());
        return buildHashMap(keyType, valueType, keyCol, valueCol);
    private ColumnValueConverter getOutputConverter() {
        Class clz = method.getReturnType();
        switch (retType.getPrimitiveType()) {
            case DATE:
            case DATEV2: {
                if (java.util.Date.class.equals(clz)) {
                    return (Object[] columnData) -> {
                        Object[] result = new LocalDate[columnData.length];
                        for (int i = 0; i < columnData.length; ++i) {
                            if (columnData[i] != null) {
                                java.util.Date v = (java.util.Date) columnData[i];
                                result[i] = LocalDate.of(v.getYear() + 1900, v.getMonth() + 1, v.getDate());
                            }
                        }
                        return result;
                    };
                } else if (org.joda.time.LocalDate.class.equals(clz)) {
                    return (Object[] columnData) -> {
                        Object[] result = new LocalDate[columnData.length];
                        for (int i = 0; i < columnData.length; ++i) {
                            if (columnData[i] != null) {
                                org.joda.time.LocalDate v = (org.joda.time.LocalDate) columnData[i];
                                result[i] = LocalDate.of(v.getYear(), v.getMonthOfYear(), v.getDayOfMonth());
                            }
                        }
                        return result;
                    };
                } else if (!LocalDate.class.equals(clz)) {
                    throw new RuntimeException("Unsupported date type: " + clz.getCanonicalName());
                }
                break;
            }
            case DATETIME:
            case DATETIMEV2: {
                if (org.joda.time.DateTime.class.equals(clz)) {
                    return (Object[] columnData) -> {
                        Object[] result = new LocalDateTime[columnData.length];
                        for (int i = 0; i < columnData.length; ++i) {
                            if (columnData[i] != null) {
                                org.joda.time.DateTime v = (org.joda.time.DateTime) columnData[i];
                                result[i] = LocalDateTime.of(v.getYear(), v.getMonthOfYear(), v.getDayOfMonth(),
                                        v.getHourOfDay(),
                                        v.getMinuteOfHour(), v.getSecondOfMinute(), v.getMillisOfSecond() * 1000000);
                            }
                        }
                        return result;
                    };
                } else if (org.joda.time.LocalDateTime.class.equals(clz)) {
                    return (Object[] columnData) -> {
                        Object[] result = new LocalDateTime[columnData.length];
                        for (int i = 0; i < columnData.length; ++i) {
                            if (columnData[i] != null) {
                                org.joda.time.LocalDateTime v = (org.joda.time.LocalDateTime) columnData[i];
                                result[i] = LocalDateTime.of(v.getYear(), v.getMonthOfYear(), v.getDayOfMonth(),
                                        v.getHourOfDay(),
                                        v.getMinuteOfHour(), v.getSecondOfMinute(), v.getMillisOfSecond() * 1000000);
                            }
                        }
                        return result;
                    };
                } else if (!LocalDateTime.class.equals(clz)) {
                    throw new RuntimeException("Unsupported date type: " + clz.getCanonicalName());
                }
                break;
            }
            default:
                break;
        }
        return null;
    }

    /**
     * Evaluates the UDF with 'args' as the input to the UDF.
     */
    public Object[] evaluate(int numRows, Object[] column) throws UdfRuntimeException {
    public long evaluate(Map<String, String> inputParams, Map<String, String> outputParams) throws UdfRuntimeException {
        try {
            VectorTable inputTable = VectorTable.createReadableTable(inputParams);
            int numRows = inputTable.getNumRows();
            int numColumns = inputTable.getNumColumns();
            Object[] result = (Object[]) Array.newInstance(method.getReturnType(), numRows);
            Object[][] inputs = (Object[][]) column;
            Object[] parameters = new Object[inputs.length];
            Object[][] inputs = inputTable.getMaterializedData(getInputConverters(numColumns));
            Object[] parameters = new Object[numColumns];
            for (int i = 0; i < numRows; ++i) {
                for (int j = 0; j < column.length; ++j) {
                for (int j = 0; j < numColumns; ++j) {
                    parameters[j] = inputs[j][i];
                }
                result[i] = methodAccess.invoke(udf, evaluateIndex, parameters);
            }
            return result;
        } catch (Exception e) {
            LOG.info("evaluate exception: " + debugString());
            LOG.info("evaluate(int numRows, Object[] column) Exception: " + e.toString());
            throw new UdfRuntimeException("UDF failed to evaluate", e);
        }
    }

    public void copyBatchBasicResult(boolean isNullable, int numRows, Object[] result, long nullMapAddr,
            long resColumnAddr, long strOffsetAddr) {
        copyBatchBasicResultImpl(isNullable, numRows, result, nullMapAddr, resColumnAddr, strOffsetAddr, getMethod());
    }

    public void copyBatchArrayResult(boolean isNullable, int numRows, Object[] result, long nullMapAddr,
            long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) {
        Preconditions.checkState(result.length == numRows,
                "copyBatchArrayResult result size should equal;");
        copyBatchArrayResultImpl(isNullable, numRows, result, nullMapAddr, offsetsAddr, nestedNullMapAddr, dataAddr,
                strOffsetAddr, retType.getItemType().getPrimitiveType(), retType.getScale());
    }

    public void copyBatchMapResult(boolean isNullable, int numRows, Object[] result, long nullMapAddr,
            long offsetsAddr, long keyNsestedNullMapAddr, long keyDataAddr, long keyStrOffsetAddr,
            long valueNsestedNullMapAddr, long valueDataAddr, long valueStrOffsetAddr) {
        Preconditions.checkState(result.length == numRows,
                "copyBatchMapResult result size should equal;");
        PrimitiveType keyType = retType.getKeyType().getPrimitiveType();
        PrimitiveType valueType = retType.getValueType().getPrimitiveType();
        Object[] keyCol = new Object[result.length];
        Object[] valueCol = new Object[result.length];
        buildArrayListFromHashMap(result, keyType, valueType, keyCol, valueCol);

        copyBatchArrayResultImpl(isNullable, numRows, valueCol, nullMapAddr, offsetsAddr, valueNsestedNullMapAddr,
                valueDataAddr,
                valueStrOffsetAddr, valueType, retType.getKeyScale());
        copyBatchArrayResultImpl(isNullable, numRows, keyCol, nullMapAddr, offsetsAddr, keyNsestedNullMapAddr,
                keyDataAddr,
                keyStrOffsetAddr, keyType, retType.getValueScale());
    }

    /**
     * Evaluates the UDF with 'args' as the input to the UDF.
     */
    private Object evaluate(Object... args) throws UdfRuntimeException {
        try {
            return method.invoke(udf, args);

            if (outputTable != null) {
                outputTable.close();
            }

            boolean isNullable = Boolean.parseBoolean(outputParams.getOrDefault("is_nullable", "true"));
            outputTable = VectorTable.createWritableTable(outputParams, numRows);
            outputTable.appendData(0, result, getOutputConverter(), isNullable);
            return outputTable.getMetaAddress();
        } catch (Exception e) {
            LOG.warn("evaluate exception: " + debugString(), e);
            throw new UdfRuntimeException("UDF failed to evaluate", e);
        }
    }
@@ -254,5 +358,4 @@ public class UdfExecutor extends BaseExecutor {
            throw new UdfRuntimeException("Unable to call create UDF instance.", e);
        }
    }

}