[feature-wip](udf) support java udf in FE (#8437)
First step to support Java UDFs in Doris. After this PR, Java UDFs can be created in Doris. For example, a Java UDF is created with the statement below. ``` CREATE FUNCTION test_udf(int) RETURNS int PROPERTIES ( "file"="file:///root/hive-udf-1.0-SNAPSHOT.jar", "symbol"="udf.Main", "type"="JAVA_UDF" ) ``` 1. `file` indicates where the user's file is located. 2. `symbol`, for a Java UDF, is the name of the UDF class inside this jar. 3. `type` indicates that this function is a Java UDF.
This commit is contained in:
@ -21,6 +21,7 @@ import org.apache.doris.catalog.AggregateFunction;
|
||||
import org.apache.doris.catalog.AliasFunction;
|
||||
import org.apache.doris.catalog.Catalog;
|
||||
import org.apache.doris.catalog.Function;
|
||||
import org.apache.doris.catalog.PrimitiveType;
|
||||
import org.apache.doris.catalog.ScalarFunction;
|
||||
import org.apache.doris.catalog.ScalarType;
|
||||
import org.apache.doris.catalog.Type;
|
||||
@ -41,7 +42,9 @@ import org.apache.doris.thrift.TFunctionBinaryType;
|
||||
|
||||
import com.google.common.base.Strings;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.ImmutableSortedMap;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import org.apache.commons.codec.binary.Hex;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
@ -50,10 +53,17 @@ import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.lang.reflect.Method;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.lang.reflect.Parameter;
|
||||
import java.net.MalformedURLException;
|
||||
import java.net.URL;
|
||||
import java.net.URLClassLoader;
|
||||
import java.security.MessageDigest;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
|
||||
@ -61,7 +71,9 @@ import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
|
||||
// create a user define function
|
||||
public class CreateFunctionStmt extends DdlStmt {
|
||||
private final static Logger LOG = LogManager.getLogger(CreateFunctionStmt.class);
|
||||
@Deprecated
|
||||
public static final String OBJECT_FILE_KEY = "object_file";
|
||||
public static final String FILE_KEY = "file";
|
||||
public static final String SYMBOL_KEY = "symbol";
|
||||
public static final String PREPARE_SYMBOL_KEY = "prepare_fn";
|
||||
public static final String CLOSE_SYMBOL_KEY = "close_fn";
|
||||
@ -74,6 +86,7 @@ public class CreateFunctionStmt extends DdlStmt {
|
||||
public static final String GET_VALUE_KEY = "get_value_fn";
|
||||
public static final String REMOVE_KEY = "remove_fn";
|
||||
public static final String BINARY_TYPE = "type";
|
||||
public static final String EVAL_METHOD_KEY = "evaluate";
|
||||
|
||||
private final FunctionName functionName;
|
||||
private final boolean isAggregate;
|
||||
@ -87,7 +100,7 @@ public class CreateFunctionStmt extends DdlStmt {
|
||||
TFunctionBinaryType binaryType = TFunctionBinaryType.NATIVE;
|
||||
|
||||
// needed item set after analyzed
|
||||
private String objectFile;
|
||||
private String userFile;
|
||||
private Function function;
|
||||
private String checksum = "";
|
||||
|
||||
@ -183,9 +196,9 @@ public class CreateFunctionStmt extends DdlStmt {
|
||||
throw new AnalysisException("unknown function type");
|
||||
}
|
||||
|
||||
objectFile = properties.get(OBJECT_FILE_KEY);
|
||||
if (Strings.isNullOrEmpty(objectFile)) {
|
||||
throw new AnalysisException("No 'object_file' in properties");
|
||||
userFile = properties.getOrDefault(FILE_KEY, properties.get(OBJECT_FILE_KEY));
|
||||
if (Strings.isNullOrEmpty(userFile)) {
|
||||
throw new AnalysisException("No 'file' or 'object_file' in properties");
|
||||
}
|
||||
if (binaryType != TFunctionBinaryType.RPC) {
|
||||
try {
|
||||
@ -206,7 +219,7 @@ public class CreateFunctionStmt extends DdlStmt {
|
||||
return;
|
||||
}
|
||||
|
||||
try (InputStream inputStream = Util.getInputStreamFromUrl(objectFile, null, HTTP_TIMEOUT_MS, HTTP_TIMEOUT_MS)) {
|
||||
try (InputStream inputStream = Util.getInputStreamFromUrl(userFile, null, HTTP_TIMEOUT_MS, HTTP_TIMEOUT_MS)) {
|
||||
MessageDigest digest = MessageDigest.getInstance("MD5");
|
||||
byte[] buf = new byte[4096];
|
||||
int bytesRead = 0;
|
||||
@ -229,7 +242,7 @@ public class CreateFunctionStmt extends DdlStmt {
|
||||
AggregateFunction.AggregateFunctionBuilder builder = AggregateFunction.AggregateFunctionBuilder.createUdfBuilder();
|
||||
|
||||
builder.name(functionName).argsType(argsDef.getArgTypes()).retType(returnType.getType()).
|
||||
hasVarArgs(argsDef.isVariadic()).intermediateType(intermediateType.getType()).location(URI.create(objectFile));
|
||||
hasVarArgs(argsDef.isVariadic()).intermediateType(intermediateType.getType()).location(URI.create(userFile));
|
||||
String initFnSymbol = properties.get(INIT_KEY);
|
||||
if (initFnSymbol == null) {
|
||||
throw new AnalysisException("No 'init_fn' in properties");
|
||||
@ -259,11 +272,11 @@ public class CreateFunctionStmt extends DdlStmt {
|
||||
String closeFnSymbol = properties.get(CLOSE_SYMBOL_KEY);
|
||||
// TODO(yangzhg) support check function in FE when function service behind load balancer
|
||||
// the format for load balance can ref https://github.com/apache/incubator-brpc/blob/master/docs/en/client.md#connect-to-a-cluster
|
||||
if (binaryType == TFunctionBinaryType.RPC && !objectFile.contains("://")) {
|
||||
if (binaryType == TFunctionBinaryType.RPC && !userFile.contains("://")) {
|
||||
if (StringUtils.isNotBlank(prepareFnSymbol) || StringUtils.isNotBlank(closeFnSymbol)) {
|
||||
throw new AnalysisException(" prepare and close in RPC UDF are not supported.");
|
||||
}
|
||||
String[] url = objectFile.split(":");
|
||||
String[] url = userFile.split(":");
|
||||
if (url.length != 2) {
|
||||
throw new AnalysisException("function server address invalid.");
|
||||
}
|
||||
@ -288,8 +301,10 @@ public class CreateFunctionStmt extends DdlStmt {
|
||||
if (response.getStatus().getStatusCode() != 0) {
|
||||
throw new AnalysisException("check function [" + symbol + "] failed: " + response.getStatus());
|
||||
}
|
||||
} else if (binaryType == TFunctionBinaryType.JAVA_UDF) {
|
||||
analyzeJavaUdf(symbol);
|
||||
}
|
||||
URI location = URI.create(objectFile);
|
||||
URI location = URI.create(userFile);
|
||||
function = ScalarFunction.createUdf(binaryType,
|
||||
functionName, argsDef.getArgTypes(),
|
||||
returnType.getType(), argsDef.isVariadic(),
|
||||
@ -297,6 +312,92 @@ public class CreateFunctionStmt extends DdlStmt {
|
||||
function.setChecksum(checksum);
|
||||
}
|
||||
|
||||
private void analyzeJavaUdf(String clazz) throws AnalysisException {
|
||||
try {
|
||||
URL[] urls = {new URL("jar:" + userFile + "!/")};
|
||||
URLClassLoader cl = URLClassLoader.newInstance(urls);
|
||||
Class udfClass = cl.loadClass(clazz);
|
||||
|
||||
Method eval = null;
|
||||
for (Method m : udfClass.getMethods()) {
|
||||
if (!m.getDeclaringClass().equals(udfClass)) {
|
||||
continue;
|
||||
}
|
||||
String name = m.getName();
|
||||
if (EVAL_METHOD_KEY.equals(name) && eval == null) {
|
||||
eval = m;
|
||||
} else if (EVAL_METHOD_KEY.equals(name)) {
|
||||
throw new AnalysisException(String.format(
|
||||
"UDF class '%s' has multiple methods with name '%s' ", udfClass.getCanonicalName(),
|
||||
EVAL_METHOD_KEY));
|
||||
}
|
||||
}
|
||||
if (eval == null) {
|
||||
throw new AnalysisException(String.format(
|
||||
"No method '%s' in class '%s'!", EVAL_METHOD_KEY, udfClass.getCanonicalName()));
|
||||
}
|
||||
if (Modifier.isStatic(eval.getModifiers())) {
|
||||
throw new AnalysisException(
|
||||
String.format("Method '%s' in class '%s' should be non-static", eval.getName(),
|
||||
udfClass.getCanonicalName()));
|
||||
}
|
||||
if (!Modifier.isPublic(eval.getModifiers())) {
|
||||
throw new AnalysisException(
|
||||
String.format("Method '%s' in class '%s' should be public", eval.getName(),
|
||||
udfClass.getCanonicalName()));
|
||||
}
|
||||
if (eval.getParameters().length != argsDef.getArgTypes().length) {
|
||||
throw new AnalysisException(
|
||||
String.format("The number of parameters for method '%s' in class '%s' should be %d",
|
||||
eval.getName(), udfClass.getCanonicalName(), argsDef.getArgTypes().length));
|
||||
}
|
||||
|
||||
checkUdfType(udfClass, eval, returnType.getType(), eval.getReturnType(), "return");
|
||||
for (int i = 0; i < eval.getParameters().length; i++) {
|
||||
Parameter p = eval.getParameters()[i];
|
||||
checkUdfType(udfClass, eval, argsDef.getArgTypes()[i], p.getType(), p.getName());
|
||||
}
|
||||
} catch (MalformedURLException e) {
|
||||
throw new AnalysisException("Failed to load file: " + userFile);
|
||||
} catch (ClassNotFoundException e) {
|
||||
throw new AnalysisException("Class [" + clazz + "] not found in file :" + userFile);
|
||||
}
|
||||
}
|
||||
|
||||
private static final ImmutableMap<PrimitiveType, Set<Class>> PrimitiveTypeToJavaClassType =
|
||||
new ImmutableMap.Builder<PrimitiveType, Set<Class>>()
|
||||
.put(PrimitiveType.BOOLEAN, Sets.newHashSet(Boolean.class, boolean.class))
|
||||
.put(PrimitiveType.TINYINT, Sets.newHashSet(Byte.class, byte.class))
|
||||
.put(PrimitiveType.SMALLINT, Sets.newHashSet(Short.class, short.class))
|
||||
.put(PrimitiveType.INT, Sets.newHashSet(Integer.class, int.class))
|
||||
.put(PrimitiveType.FLOAT, Sets.newHashSet(Float.class, float.class))
|
||||
.put(PrimitiveType.DOUBLE, Sets.newHashSet(Double.class, double.class))
|
||||
.put(PrimitiveType.BIGINT, Sets.newHashSet(Long.class, long.class))
|
||||
.put(PrimitiveType.CHAR, Sets.newHashSet(String.class))
|
||||
.put(PrimitiveType.VARCHAR, Sets.newHashSet(String.class))
|
||||
.build();
|
||||
|
||||
private void checkUdfType(Class clazz, Method method, Type expType, Class pType, String pname)
|
||||
throws AnalysisException {
|
||||
if (!(expType instanceof ScalarType)) {
|
||||
throw new AnalysisException(
|
||||
String.format("Method '%s' in class '%s' does not support non-scalar type '%s'",
|
||||
method.getName(), clazz.getCanonicalName(), expType));
|
||||
}
|
||||
ScalarType scalarType = (ScalarType) expType;
|
||||
Set<Class> javaTypes = PrimitiveTypeToJavaClassType.get(scalarType.getPrimitiveType());
|
||||
if (javaTypes == null) {
|
||||
throw new AnalysisException(
|
||||
String.format("Method '%s' in class '%s' does not support type '%s'",
|
||||
method.getName(), clazz.getCanonicalName(), scalarType));
|
||||
}
|
||||
if (!javaTypes.contains(pType)) {
|
||||
throw new AnalysisException(
|
||||
String.format("UDF class '%s' method '%s' %s[%s] type is not supported!",
|
||||
clazz.getCanonicalName(), method.getName(), pname, pType.getCanonicalName()));
|
||||
}
|
||||
}
|
||||
|
||||
private Types.PGenericType convertToPParameterType(Type arg) throws AnalysisException {
|
||||
Types.PGenericType.Builder typeBuilder = Types.PGenericType.newBuilder();
|
||||
switch (arg.getPrimitiveType()) {
|
||||
|
||||
@ -269,6 +269,8 @@ enum TFunctionBinaryType {
|
||||
|
||||
// call udfs by rpc service
|
||||
RPC,
|
||||
|
||||
JAVA_UDF,
|
||||
}
|
||||
|
||||
// Represents a fully qualified function name.
|
||||
|
||||
Reference in New Issue
Block a user