[feature-wip](udf) support java udf in FE (#8437)

This is the first step toward supporting Java UDFs in Doris. After this PR, Java UDFs can be created in Doris.

For example, a Java UDF can be created with the statement below.
```
CREATE FUNCTION test_udf(int) RETURNS int 
PROPERTIES (
"file"="file:///root/hive-udf-1.0-SNAPSHOT.jar",
"symbol"="udf.Main", 
"type"="JAVA_UDF"
)
```
1. `file` indicates where the user's file (the jar) is located.
2. `symbol`, for a Java UDF, is the name of the UDF class inside this jar.
3. `type` indicates that this function is a Java UDF.
This commit is contained in:
Gabriel
2022-03-15 11:42:39 +08:00
committed by GitHub
parent 571f0b688d
commit 7d1d45d6dc
2 changed files with 112 additions and 9 deletions

View File

@ -21,6 +21,7 @@ import org.apache.doris.catalog.AggregateFunction;
import org.apache.doris.catalog.AliasFunction;
import org.apache.doris.catalog.Catalog;
import org.apache.doris.catalog.Function;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.ScalarFunction;
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.Type;
@ -41,7 +42,9 @@ import org.apache.doris.thrift.TFunctionBinaryType;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.collect.Sets;
import org.apache.commons.codec.binary.Hex;
import org.apache.commons.lang3.StringUtils;
@ -50,10 +53,17 @@ import org.apache.logging.log4j.Logger;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Parameter;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.List;
import java.util.Map;
import java.util.Set;
import io.grpc.ManagedChannel;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
@ -61,7 +71,9 @@ import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
// create a user define function
public class CreateFunctionStmt extends DdlStmt {
private final static Logger LOG = LogManager.getLogger(CreateFunctionStmt.class);
@Deprecated
public static final String OBJECT_FILE_KEY = "object_file";
public static final String FILE_KEY = "file";
public static final String SYMBOL_KEY = "symbol";
public static final String PREPARE_SYMBOL_KEY = "prepare_fn";
public static final String CLOSE_SYMBOL_KEY = "close_fn";
@ -74,6 +86,7 @@ public class CreateFunctionStmt extends DdlStmt {
public static final String GET_VALUE_KEY = "get_value_fn";
public static final String REMOVE_KEY = "remove_fn";
public static final String BINARY_TYPE = "type";
public static final String EVAL_METHOD_KEY = "evaluate";
private final FunctionName functionName;
private final boolean isAggregate;
@ -87,7 +100,7 @@ public class CreateFunctionStmt extends DdlStmt {
TFunctionBinaryType binaryType = TFunctionBinaryType.NATIVE;
// needed item set after analyzed
private String objectFile;
private String userFile;
private Function function;
private String checksum = "";
@ -183,9 +196,9 @@ public class CreateFunctionStmt extends DdlStmt {
throw new AnalysisException("unknown function type");
}
objectFile = properties.get(OBJECT_FILE_KEY);
if (Strings.isNullOrEmpty(objectFile)) {
throw new AnalysisException("No 'object_file' in properties");
userFile = properties.getOrDefault(FILE_KEY, properties.get(OBJECT_FILE_KEY));
if (Strings.isNullOrEmpty(userFile)) {
throw new AnalysisException("No 'file' or 'object_file' in properties");
}
if (binaryType != TFunctionBinaryType.RPC) {
try {
@ -206,7 +219,7 @@ public class CreateFunctionStmt extends DdlStmt {
return;
}
try (InputStream inputStream = Util.getInputStreamFromUrl(objectFile, null, HTTP_TIMEOUT_MS, HTTP_TIMEOUT_MS)) {
try (InputStream inputStream = Util.getInputStreamFromUrl(userFile, null, HTTP_TIMEOUT_MS, HTTP_TIMEOUT_MS)) {
MessageDigest digest = MessageDigest.getInstance("MD5");
byte[] buf = new byte[4096];
int bytesRead = 0;
@ -229,7 +242,7 @@ public class CreateFunctionStmt extends DdlStmt {
AggregateFunction.AggregateFunctionBuilder builder = AggregateFunction.AggregateFunctionBuilder.createUdfBuilder();
builder.name(functionName).argsType(argsDef.getArgTypes()).retType(returnType.getType()).
hasVarArgs(argsDef.isVariadic()).intermediateType(intermediateType.getType()).location(URI.create(objectFile));
hasVarArgs(argsDef.isVariadic()).intermediateType(intermediateType.getType()).location(URI.create(userFile));
String initFnSymbol = properties.get(INIT_KEY);
if (initFnSymbol == null) {
throw new AnalysisException("No 'init_fn' in properties");
@ -259,11 +272,11 @@ public class CreateFunctionStmt extends DdlStmt {
String closeFnSymbol = properties.get(CLOSE_SYMBOL_KEY);
// TODO(yangzhg) support check function in FE when function service behind load balancer
// the format for load balance can ref https://github.com/apache/incubator-brpc/blob/master/docs/en/client.md#connect-to-a-cluster
if (binaryType == TFunctionBinaryType.RPC && !objectFile.contains("://")) {
if (binaryType == TFunctionBinaryType.RPC && !userFile.contains("://")) {
if (StringUtils.isNotBlank(prepareFnSymbol) || StringUtils.isNotBlank(closeFnSymbol)) {
throw new AnalysisException(" prepare and close in RPC UDF are not supported.");
}
String[] url = objectFile.split(":");
String[] url = userFile.split(":");
if (url.length != 2) {
throw new AnalysisException("function server address invalid.");
}
@ -288,8 +301,10 @@ public class CreateFunctionStmt extends DdlStmt {
if (response.getStatus().getStatusCode() != 0) {
throw new AnalysisException("check function [" + symbol + "] failed: " + response.getStatus());
}
} else if (binaryType == TFunctionBinaryType.JAVA_UDF) {
analyzeJavaUdf(symbol);
}
URI location = URI.create(objectFile);
URI location = URI.create(userFile);
function = ScalarFunction.createUdf(binaryType,
functionName, argsDef.getArgTypes(),
returnType.getType(), argsDef.isVariadic(),
@ -297,6 +312,92 @@ public class CreateFunctionStmt extends DdlStmt {
function.setChecksum(checksum);
}
/**
 * Loads the UDF class named {@code clazz} from the jar referenced by {@code userFile} and
 * verifies that it declares exactly one usable {@code evaluate} method whose signature
 * matches the declared argument and return types of this statement.
 *
 * <p>The method must be declared directly on the UDF class, be public and non-static, and
 * take the same number of parameters as {@code argsDef}; every parameter type and the
 * return type are validated through {@code checkUdfType}.
 *
 * @param clazz fully qualified name of the UDF class inside the jar
 * @throws AnalysisException if the jar cannot be opened, the class cannot be found, or the
 *                           {@code evaluate} method is missing, duplicated or mismatched
 */
private void analyzeJavaUdf(String clazz) throws AnalysisException {
    try {
        URL[] urls = {new URL("jar:" + userFile + "!/")};
        // The class loader keeps the jar open; close it once validation is done so that
        // repeated CREATE FUNCTION statements do not leak file handles.
        try (URLClassLoader cl = URLClassLoader.newInstance(urls)) {
            Class<?> udfClass = cl.loadClass(clazz);

            Method eval = null;
            for (Method m : udfClass.getMethods()) {
                // Only consider methods declared on the UDF class itself, not inherited ones.
                if (!m.getDeclaringClass().equals(udfClass)) {
                    continue;
                }
                String name = m.getName();
                if (EVAL_METHOD_KEY.equals(name) && eval == null) {
                    eval = m;
                } else if (EVAL_METHOD_KEY.equals(name)) {
                    // Overloads are rejected: the signature to bind would be ambiguous.
                    throw new AnalysisException(String.format(
                            "UDF class '%s' has multiple methods with name '%s' ", udfClass.getCanonicalName(),
                            EVAL_METHOD_KEY));
                }
            }
            if (eval == null) {
                throw new AnalysisException(String.format(
                        "No method '%s' in class '%s'!", EVAL_METHOD_KEY, udfClass.getCanonicalName()));
            }
            if (Modifier.isStatic(eval.getModifiers())) {
                throw new AnalysisException(
                        String.format("Method '%s' in class '%s' should be non-static", eval.getName(),
                                udfClass.getCanonicalName()));
            }
            if (!Modifier.isPublic(eval.getModifiers())) {
                throw new AnalysisException(
                        String.format("Method '%s' in class '%s' should be public", eval.getName(),
                                udfClass.getCanonicalName()));
            }

            // Fetch the parameter array once; getParameters() returns a fresh copy per call.
            Parameter[] params = eval.getParameters();
            if (params.length != argsDef.getArgTypes().length) {
                throw new AnalysisException(
                        String.format("The number of parameters for method '%s' in class '%s' should be %d",
                                eval.getName(), udfClass.getCanonicalName(), argsDef.getArgTypes().length));
            }

            checkUdfType(udfClass, eval, returnType.getType(), eval.getReturnType(), "return");
            for (int i = 0; i < params.length; i++) {
                Parameter p = params[i];
                checkUdfType(udfClass, eval, argsDef.getArgTypes()[i], p.getType(), p.getName());
            }
        }
    } catch (IOException e) {
        // Covers MalformedURLException (a subclass) as well as a failure to close the loader.
        throw new AnalysisException("Failed to load file: " + userFile);
    } catch (ClassNotFoundException e) {
        throw new AnalysisException("Class [" + clazz + "] not found in file :" + userFile);
    }
}
// Lookup table used by checkUdfType: for each supported Doris primitive type, the Java
// classes (boxed and, where one exists, unboxed) accepted in a UDF method signature.
// Primitive types without an entry here are rejected by checkUdfType.
private static final ImmutableMap<PrimitiveType, Set<Class>> PrimitiveTypeToJavaClassType =
        ImmutableMap.<PrimitiveType, Set<Class>>builder()
                // boolean
                .put(PrimitiveType.BOOLEAN, Sets.newHashSet(Boolean.class, boolean.class))
                // integral types
                .put(PrimitiveType.TINYINT, Sets.newHashSet(Byte.class, byte.class))
                .put(PrimitiveType.SMALLINT, Sets.newHashSet(Short.class, short.class))
                .put(PrimitiveType.INT, Sets.newHashSet(Integer.class, int.class))
                // floating point types
                .put(PrimitiveType.FLOAT, Sets.newHashSet(Float.class, float.class))
                .put(PrimitiveType.DOUBLE, Sets.newHashSet(Double.class, double.class))
                .put(PrimitiveType.BIGINT, Sets.newHashSet(Long.class, long.class))
                // string types both map to java.lang.String
                .put(PrimitiveType.CHAR, Sets.newHashSet(String.class))
                .put(PrimitiveType.VARCHAR, Sets.newHashSet(String.class))
                .build();
/**
 * Checks that one position of a UDF method signature (a parameter or the return value)
 * uses a Java class accepted for the declared Doris type.
 *
 * @param clazz   the UDF class, used only for error messages
 * @param method  the UDF method, used only for error messages
 * @param expType the type declared in the CREATE FUNCTION statement
 * @param pType   the Java class found at this position of the method signature
 * @param pname   label of the position being checked, used only for error messages
 * @throws AnalysisException if {@code expType} is not a scalar type, has no entry in
 *                           {@code PrimitiveTypeToJavaClassType}, or {@code pType} is not
 *                           one of the Java classes accepted for it
 */
private void checkUdfType(Class clazz, Method method, Type expType, Class pType, String pname)
        throws AnalysisException {
    final String methodName = method.getName();
    final String className = clazz.getCanonicalName();

    if (expType instanceof ScalarType) {
        ScalarType declared = (ScalarType) expType;
        Set<Class> accepted = PrimitiveTypeToJavaClassType.get(declared.getPrimitiveType());
        if (accepted == null) {
            // The primitive type has no Java counterpart in the lookup table.
            throw new AnalysisException(
                    String.format("Method '%s' in class '%s' does not support type '%s'",
                            methodName, className, declared));
        }
        if (accepted.contains(pType)) {
            return;
        }
        throw new AnalysisException(
                String.format("UDF class '%s' method '%s' %s[%s] type is not supported!",
                        className, methodName, pname, pType.getCanonicalName()));
    }

    // Anything that is not a ScalarType is rejected outright.
    throw new AnalysisException(
            String.format("Method '%s' in class '%s' does not support non-scalar type '%s'",
                    methodName, className, expType));
}
private Types.PGenericType convertToPParameterType(Type arg) throws AnalysisException {
Types.PGenericType.Builder typeBuilder = Types.PGenericType.newBuilder();
switch (arg.getPrimitiveType()) {

View File

@ -269,6 +269,8 @@ enum TFunctionBinaryType {
// call udfs by rpc service
RPC,
JAVA_UDF,
}
// Represents a fully qualified function name.