[refactor](jni) unified jni framework for java udf (#25302)

Use the unified jni framework to refactor java udf.
The unified jni framework uses VectorTable as the container to transfer data between c++ and java, and hides the details of data format conversion.
In addition, the unified framework supports complex and nested types.
The performance of basic types remains consistent, with a 30% improvement in string types and an order of magnitude improvement in complex types.
This commit is contained in:
Ashin Gau
2023-10-18 09:27:54 +08:00
committed by GitHub
parent 26e332c608
commit 47689fd452
23 changed files with 2153 additions and 742 deletions

View File

@ -17,17 +17,18 @@
package org.apache.doris.udf;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Pair;
import org.apache.doris.common.exception.UdfRuntimeException;
import org.apache.doris.common.jni.utils.JavaUdfDataType;
import org.apache.doris.common.jni.utils.UdfUtils;
import org.apache.doris.common.jni.vec.ColumnValueConverter;
import org.apache.doris.common.jni.vec.VectorTable;
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
import org.apache.doris.thrift.TPrimitiveType;
import com.esotericsoftware.reflectasm.MethodAccess;
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.apache.log4j.Logger;
@ -35,7 +36,11 @@ import java.lang.reflect.Array;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
public class UdfExecutor extends BaseExecutor {
// private static final java.util.logging.Logger LOG =
@ -46,6 +51,8 @@ public class UdfExecutor extends BaseExecutor {
private int evaluateIndex;
private VectorTable outputTable = null;
/**
* Create a UdfExecutor, using parameters from a serialized thrift object. Used by
* the backend.
@ -59,98 +66,195 @@ public class UdfExecutor extends BaseExecutor {
*/
@Override
public void close() {
    // Detach the table from the field before releasing it, so that a repeated close()
    // finds nothing left to release and cannot hand the same table to close() twice.
    VectorTable tableToRelease = outputTable;
    outputTable = null;
    try {
        // inputTable is released by c++, only release outputTable
        if (tableToRelease != null) {
            tableToRelease.close();
        }
    } finally {
        // We are now un-usable (because the class loader has been
        // closed), so null out method_ and classLoader_.
        method = null;
        // Run the base-class cleanup even when releasing the output table failed.
        super.close();
    }
}
public Object[] convertBasicArguments(int argIdx, boolean isNullable, int numRows, long nullMapAddr,
long columnAddr, long strOffsetAddr) {
return convertBasicArg(true, argIdx, isNullable, 0, numRows, nullMapAddr, columnAddr, strOffsetAddr);
/**
 * Chooses the converter that adapts the materialized values of one input column to the
 * class declared by the corresponding UDF parameter.
 *
 * <p>The casts below rely on DATE/DATEV2 columns being materialized as {@link LocalDate} and
 * DATETIME/DATETIMEV2 columns as {@link LocalDateTime}. Parameters declared with exactly those
 * classes, and columns of any other primitive type, need no conversion.
 *
 * @param primitiveType primitive type of the input column
 * @param clz class of the UDF parameter bound to that column
 * @return the converter to apply, or {@code null} when the values are passed through as-is
 * @throws RuntimeException if a date or datetime column is bound to an unsupported class
 */
private ColumnValueConverter getInputConverter(TPrimitiveType primitiveType, Class<?> clz) {
    switch (primitiveType) {
        case DATE:
        case DATEV2: {
            if (java.util.Date.class.equals(clz)) {
                return inputConverterOf(java.util.Date.class, value -> {
                    LocalDate v = (LocalDate) value;
                    // This java.util.Date constructor counts years from 1900 and months from 0.
                    return new java.util.Date(v.getYear() - 1900, v.getMonthValue() - 1, v.getDayOfMonth());
                });
            } else if (org.joda.time.LocalDate.class.equals(clz)) {
                return inputConverterOf(org.joda.time.LocalDate.class, value -> {
                    LocalDate v = (LocalDate) value;
                    return new org.joda.time.LocalDate(v.getYear(), v.getMonthValue(), v.getDayOfMonth());
                });
            } else if (!LocalDate.class.equals(clz)) {
                throw new RuntimeException("Unsupported date type: " + clz.getCanonicalName());
            }
            break;
        }
        case DATETIME:
        case DATETIMEV2: {
            if (org.joda.time.DateTime.class.equals(clz)) {
                return inputConverterOf(org.joda.time.DateTime.class, value -> {
                    LocalDateTime v = (LocalDateTime) value;
                    // The last argument is milliseconds, so anything below a millisecond is dropped.
                    return new org.joda.time.DateTime(v.getYear(), v.getMonthValue(), v.getDayOfMonth(),
                            v.getHour(), v.getMinute(), v.getSecond(), v.getNano() / 1000000);
                });
            } else if (org.joda.time.LocalDateTime.class.equals(clz)) {
                return inputConverterOf(org.joda.time.LocalDateTime.class, value -> {
                    LocalDateTime v = (LocalDateTime) value;
                    // The last argument is milliseconds, so anything below a millisecond is dropped.
                    return new org.joda.time.LocalDateTime(v.getYear(), v.getMonthValue(), v.getDayOfMonth(),
                            v.getHour(), v.getMinute(), v.getSecond(), v.getNano() / 1000000);
                });
            } else if (!LocalDateTime.class.equals(clz)) {
                throw new RuntimeException("Unsupported datetime type: " + clz.getCanonicalName());
            }
            break;
        }
        default:
            break;
    }
    // No conversion required for this column.
    return null;
}

/**
 * Builds a converter that maps every non-null element of a column through {@code convertValue}
 * and leaves null elements as null.
 *
 * @param elementType component type of the array the converter returns
 * @param convertValue conversion applied to each non-null element
 * @return converter producing a new array of {@code elementType} with the same length as its input
 */
private static ColumnValueConverter inputConverterOf(Class<?> elementType,
        java.util.function.Function<Object, Object> convertValue) {
    return (Object[] columnData) -> {
        // Allocate an array of the target class so its runtime type matches the UDF parameter.
        Object[] result = (Object[]) Array.newInstance(elementType, columnData.length);
        for (int i = 0; i < columnData.length; ++i) {
            if (columnData[i] != null) {
                result[i] = convertValue.apply(columnData[i]);
            }
        }
        return result;
    };
}
public Object[] convertArrayArguments(int argIdx, boolean isNullable, int numRows, long nullMapAddr,
long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) {
return convertArrayArg(argIdx, isNullable, 0, numRows, nullMapAddr, offsetsAddr, nestedNullMapAddr, dataAddr,
strOffsetAddr);
/**
 * Collects the input converters for columns {@code 0 .. numColumns - 1}, keyed by column index.
 * A column for which {@code getInputConverter} returns {@code null} gets no entry in the map.
 *
 * @param numColumns number of leading argument columns to inspect
 * @return map from column index to the converter selected for that column
 */
private Map<Integer, ColumnValueConverter> getInputConverters(int numColumns) {
    Map<Integer, ColumnValueConverter> byColumn = new HashMap<>();
    for (int col = 0; col < numColumns; col++) {
        ColumnValueConverter candidate = getInputConverter(argTypes[col].getPrimitiveType(), argClass[col]);
        if (candidate == null) {
            // Nothing to convert for this column.
            continue;
        }
        byColumn.put(col, candidate);
    }
    return byColumn;
}
public Object[] convertMapArguments(int argIdx, boolean isNullable, int numRows, long nullMapAddr,
long offsetsAddr, long keyNestedNullMapAddr, long keyDataAddr, long keyStrOffsetAddr,
long valueNestedNullMapAddr, long valueDataAddr, long valueStrOffsetAddr) {
PrimitiveType keyType = argTypes[argIdx].getKeyType().getPrimitiveType();
PrimitiveType valueType = argTypes[argIdx].getValueType().getPrimitiveType();
Object[] keyCol = convertMapArg(keyType, argIdx, isNullable, 0, numRows, nullMapAddr, offsetsAddr,
keyNestedNullMapAddr, keyDataAddr,
keyStrOffsetAddr, argTypes[argIdx].getKeyScale());
Object[] valueCol = convertMapArg(valueType, argIdx, isNullable, 0, numRows, nullMapAddr, offsetsAddr,
valueNestedNullMapAddr, valueDataAddr,
valueStrOffsetAddr, argTypes[argIdx].getValueScale());
return buildHashMap(keyType, valueType, keyCol, valueCol);
/**
 * Chooses the converter that turns the values returned by the UDF into the java.time classes
 * written to the output column.
 *
 * <p>For a DATE/DATEV2 return type the result is converted to {@link LocalDate}; for
 * DATETIME/DATETIMEV2 it is converted to {@link LocalDateTime}. A UDF that already returns
 * those classes, or any other primitive return type, needs no conversion.
 *
 * @return the converter to apply, or {@code null} when the values are written as-is
 * @throws RuntimeException if a date or datetime result is declared with an unsupported class
 */
private ColumnValueConverter getOutputConverter() {
    Class<?> clz = method.getReturnType();
    switch (retType.getPrimitiveType()) {
        case DATE:
        case DATEV2: {
            if (java.util.Date.class.equals(clz)) {
                return outputConverterOf(LocalDate.class, value -> {
                    java.util.Date v = (java.util.Date) value;
                    // java.util.Date reports years relative to 1900 and months starting at 0.
                    return LocalDate.of(v.getYear() + 1900, v.getMonth() + 1, v.getDate());
                });
            } else if (org.joda.time.LocalDate.class.equals(clz)) {
                return outputConverterOf(LocalDate.class, value -> {
                    org.joda.time.LocalDate v = (org.joda.time.LocalDate) value;
                    return LocalDate.of(v.getYear(), v.getMonthOfYear(), v.getDayOfMonth());
                });
            } else if (!LocalDate.class.equals(clz)) {
                throw new RuntimeException("Unsupported date type: " + clz.getCanonicalName());
            }
            break;
        }
        case DATETIME:
        case DATETIMEV2: {
            if (org.joda.time.DateTime.class.equals(clz)) {
                return outputConverterOf(LocalDateTime.class, value -> {
                    org.joda.time.DateTime v = (org.joda.time.DateTime) value;
                    // LocalDateTime.of takes nanoseconds, so scale the milliseconds up.
                    return LocalDateTime.of(v.getYear(), v.getMonthOfYear(), v.getDayOfMonth(),
                            v.getHourOfDay(), v.getMinuteOfHour(), v.getSecondOfMinute(),
                            v.getMillisOfSecond() * 1000000);
                });
            } else if (org.joda.time.LocalDateTime.class.equals(clz)) {
                return outputConverterOf(LocalDateTime.class, value -> {
                    org.joda.time.LocalDateTime v = (org.joda.time.LocalDateTime) value;
                    // LocalDateTime.of takes nanoseconds, so scale the milliseconds up.
                    return LocalDateTime.of(v.getYear(), v.getMonthOfYear(), v.getDayOfMonth(),
                            v.getHourOfDay(), v.getMinuteOfHour(), v.getSecondOfMinute(),
                            v.getMillisOfSecond() * 1000000);
                });
            } else if (!LocalDateTime.class.equals(clz)) {
                throw new RuntimeException("Unsupported datetime type: " + clz.getCanonicalName());
            }
            break;
        }
        default:
            break;
    }
    // No conversion required for the result column.
    return null;
}

/**
 * Builds a converter that maps every non-null element of a result column through
 * {@code convertValue} and leaves null elements as null.
 *
 * @param elementType component type of the array the converter returns
 * @param convertValue conversion applied to each non-null element
 * @return converter producing a new array of {@code elementType} with the same length as its input
 */
private static ColumnValueConverter outputConverterOf(Class<?> elementType,
        java.util.function.Function<Object, Object> convertValue) {
    return (Object[] columnData) -> {
        Object[] result = (Object[]) Array.newInstance(elementType, columnData.length);
        for (int i = 0; i < columnData.length; ++i) {
            if (columnData[i] != null) {
                result[i] = convertValue.apply(columnData[i]);
            }
        }
        return result;
    };
}
/**
* Evaluates the UDF with 'args' as the input to the UDF.
*/
public Object[] evaluate(int numRows, Object[] column) throws UdfRuntimeException {
public long evaluate(Map<String, String> inputParams, Map<String, String> outputParams) throws UdfRuntimeException {
try {
VectorTable inputTable = VectorTable.createReadableTable(inputParams);
int numRows = inputTable.getNumRows();
int numColumns = inputTable.getNumColumns();
Object[] result = (Object[]) Array.newInstance(method.getReturnType(), numRows);
Object[][] inputs = (Object[][]) column;
Object[] parameters = new Object[inputs.length];
Object[][] inputs = inputTable.getMaterializedData(getInputConverters(numColumns));
Object[] parameters = new Object[numColumns];
for (int i = 0; i < numRows; ++i) {
for (int j = 0; j < column.length; ++j) {
for (int j = 0; j < numColumns; ++j) {
parameters[j] = inputs[j][i];
}
result[i] = methodAccess.invoke(udf, evaluateIndex, parameters);
}
return result;
} catch (Exception e) {
LOG.info("evaluate exception: " + debugString());
LOG.info("evaluate(int numRows, Object[] column) Exception: " + e.toString());
throw new UdfRuntimeException("UDF failed to evaluate", e);
}
}
/**
 * Forwards a batch of results to {@code copyBatchBasicResultImpl}, passing all arguments
 * through unchanged and appending the value of {@code getMethod()}.
 */
public void copyBatchBasicResult(boolean isNullable, int numRows, Object[] result, long nullMapAddr,
        long resColumnAddr, long strOffsetAddr) {
    copyBatchBasicResultImpl(
            isNullable, numRows, result,
            nullMapAddr, resColumnAddr, strOffsetAddr,
            getMethod());
}
/**
 * Checks that {@code result} holds exactly {@code numRows} entries, then forwards the batch to
 * {@code copyBatchArrayResultImpl} with the item primitive type and scale taken from
 * {@code retType}.
 *
 * @throws IllegalStateException if {@code result.length} differs from {@code numRows}
 */
public void copyBatchArrayResult(boolean isNullable, int numRows, Object[] result, long nullMapAddr,
        long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) {
    boolean sizeMatches = numRows == result.length;
    Preconditions.checkState(sizeMatches, "copyBatchArrayResult result size should equal;");
    copyBatchArrayResultImpl(
            isNullable, numRows, result,
            nullMapAddr, offsetsAddr, nestedNullMapAddr, dataAddr, strOffsetAddr,
            retType.getItemType().getPrimitiveType(), retType.getScale());
}
/**
 * Splits a batch of map results into a key column and a value column via
 * {@code buildArrayListFromHashMap}, then writes each column with
 * {@code copyBatchArrayResultImpl}.
 *
 * @throws IllegalStateException if {@code result.length} differs from {@code numRows}
 */
public void copyBatchMapResult(boolean isNullable, int numRows, Object[] result, long nullMapAddr,
        long offsetsAddr, long keyNsestedNullMapAddr, long keyDataAddr, long keyStrOffsetAddr,
        long valueNsestedNullMapAddr, long valueDataAddr, long valueStrOffsetAddr) {
    Preconditions.checkState(result.length == numRows,
            "copyBatchMapResult result size should equal;");
    PrimitiveType keyType = retType.getKeyType().getPrimitiveType();
    PrimitiveType valueType = retType.getValueType().getPrimitiveType();
    Object[] keyCol = new Object[result.length];
    Object[] valueCol = new Object[result.length];
    buildArrayListFromHashMap(result, keyType, valueType, keyCol, valueCol);
    // Each column is written with the scale of its own element type; the scales used to be
    // crossed (value column with the key scale and vice versa). The value-then-key write
    // order is kept as it was.
    copyBatchArrayResultImpl(isNullable, numRows, valueCol, nullMapAddr, offsetsAddr, valueNsestedNullMapAddr,
            valueDataAddr,
            valueStrOffsetAddr, valueType, retType.getValueScale());
    copyBatchArrayResultImpl(isNullable, numRows, keyCol, nullMapAddr, offsetsAddr, keyNsestedNullMapAddr,
            keyDataAddr,
            keyStrOffsetAddr, keyType, retType.getKeyScale());
}
/**
* Evaluates the UDF with 'args' as the input to the UDF.
*/
private Object evaluate(Object... args) throws UdfRuntimeException {
try {
return method.invoke(udf, args);
if (outputTable != null) {
outputTable.close();
}
boolean isNullable = Boolean.parseBoolean(outputParams.getOrDefault("is_nullable", "true"));
outputTable = VectorTable.createWritableTable(outputParams, numRows);
outputTable.appendData(0, result, getOutputConverter(), isNullable);
return outputTable.getMetaAddress();
} catch (Exception e) {
LOG.warn("evaluate exception: " + debugString(), e);
throw new UdfRuntimeException("UDF failed to evaluate", e);
}
}
@ -254,5 +358,4 @@ public class UdfExecutor extends BaseExecutor {
throw new UdfRuntimeException("Unable to call create UDF instance.", e);
}
}
}