diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java index 76f3638dde..0ecd3b1ffb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java @@ -21,6 +21,7 @@ import org.apache.doris.catalog.AggregateFunction; import org.apache.doris.catalog.AliasFunction; import org.apache.doris.catalog.Catalog; import org.apache.doris.catalog.Function; +import org.apache.doris.catalog.PrimitiveType; import org.apache.doris.catalog.ScalarFunction; import org.apache.doris.catalog.ScalarType; import org.apache.doris.catalog.Type; @@ -41,7 +42,9 @@ import org.apache.doris.thrift.TFunctionBinaryType; import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSortedMap; +import com.google.common.collect.Sets; import org.apache.commons.codec.binary.Hex; import org.apache.commons.lang3.StringUtils; @@ -50,10 +53,17 @@ import org.apache.logging.log4j.Logger; import java.io.IOException; import java.io.InputStream; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.lang.reflect.Parameter; +import java.net.MalformedURLException; +import java.net.URL; +import java.net.URLClassLoader; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.List; import java.util.Map; +import java.util.Set; import io.grpc.ManagedChannel; import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder; @@ -61,7 +71,9 @@ import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder; // create a user define function public class CreateFunctionStmt extends DdlStmt { private final static Logger LOG = LogManager.getLogger(CreateFunctionStmt.class); + @Deprecated public static final String OBJECT_FILE_KEY = "object_file"; + public static final String FILE_KEY = "file"; public static final String SYMBOL_KEY = "symbol"; public static final String PREPARE_SYMBOL_KEY = "prepare_fn"; public static final String CLOSE_SYMBOL_KEY = "close_fn"; @@ -74,6 +86,7 @@ public class CreateFunctionStmt extends DdlStmt { public static final String GET_VALUE_KEY = "get_value_fn"; public static final String REMOVE_KEY = "remove_fn"; public static final String BINARY_TYPE = "type"; + public static final String EVAL_METHOD_KEY = "evaluate"; private final FunctionName functionName; private final boolean isAggregate; @@ -87,7 +100,7 @@ public class CreateFunctionStmt extends DdlStmt { TFunctionBinaryType binaryType = TFunctionBinaryType.NATIVE; // needed item set after analyzed - private String objectFile; + private String userFile; private Function function; private String checksum = ""; @@ -183,9 +196,9 @@ public class CreateFunctionStmt extends DdlStmt { throw new AnalysisException("unknown function type"); } - objectFile = properties.get(OBJECT_FILE_KEY); - if (Strings.isNullOrEmpty(objectFile)) { - throw new AnalysisException("No 'object_file' in properties"); + userFile = properties.getOrDefault(FILE_KEY, properties.get(OBJECT_FILE_KEY)); + if (Strings.isNullOrEmpty(userFile)) { + throw new AnalysisException("No 'file' or 'object_file' in properties"); } if (binaryType != TFunctionBinaryType.RPC) { try { @@ -206,7 +219,7 @@ public class CreateFunctionStmt extends DdlStmt { return; } - try (InputStream inputStream = Util.getInputStreamFromUrl(objectFile, null, HTTP_TIMEOUT_MS, HTTP_TIMEOUT_MS)) { + try (InputStream inputStream = Util.getInputStreamFromUrl(userFile, null, HTTP_TIMEOUT_MS, HTTP_TIMEOUT_MS)) { MessageDigest digest = MessageDigest.getInstance("MD5"); byte[] buf = new byte[4096]; int bytesRead = 0; @@ -229,7 +242,7 @@ public class CreateFunctionStmt extends DdlStmt { AggregateFunction.AggregateFunctionBuilder builder = AggregateFunction.AggregateFunctionBuilder.createUdfBuilder(); builder.name(functionName).argsType(argsDef.getArgTypes()).retType(returnType.getType()). - hasVarArgs(argsDef.isVariadic()).intermediateType(intermediateType.getType()).location(URI.create(objectFile)); + hasVarArgs(argsDef.isVariadic()).intermediateType(intermediateType.getType()).location(URI.create(userFile)); String initFnSymbol = properties.get(INIT_KEY); if (initFnSymbol == null) { throw new AnalysisException("No 'init_fn' in properties"); @@ -259,11 +272,11 @@ public class CreateFunctionStmt extends DdlStmt { String closeFnSymbol = properties.get(CLOSE_SYMBOL_KEY); // TODO(yangzhg) support check function in FE when function service behind load balancer // the format for load balance can ref https://github.com/apache/incubator-brpc/blob/master/docs/en/client.md#connect-to-a-cluster - if (binaryType == TFunctionBinaryType.RPC && !objectFile.contains("://")) { + if (binaryType == TFunctionBinaryType.RPC && !userFile.contains("://")) { if (StringUtils.isNotBlank(prepareFnSymbol) || StringUtils.isNotBlank(closeFnSymbol)) { throw new AnalysisException(" prepare and close in RPC UDF are not supported."); } - String[] url = objectFile.split(":"); + String[] url = userFile.split(":"); if (url.length != 2) { throw new AnalysisException("function server address invalid."); } @@ -288,8 +301,10 @@ public class CreateFunctionStmt extends DdlStmt { if (response.getStatus().getStatusCode() != 0) { throw new AnalysisException("check function [" + symbol + "] failed: " + response.getStatus()); } + } else if (binaryType == TFunctionBinaryType.JAVA_UDF) { + analyzeJavaUdf(symbol); } - URI location = URI.create(objectFile); + URI location = URI.create(userFile); function = ScalarFunction.createUdf(binaryType, functionName, argsDef.getArgTypes(), returnType.getType(), argsDef.isVariadic(), @@ -297,6 +312,92 @@ public class CreateFunctionStmt extends DdlStmt { function.setChecksum(checksum); } + private void analyzeJavaUdf(String clazz) throws AnalysisException { + try { + URL[] urls = {new URL("jar:" + userFile + "!/")}; + URLClassLoader cl = URLClassLoader.newInstance(urls); + Class udfClass = cl.loadClass(clazz); + + Method eval = null; + for (Method m : udfClass.getMethods()) { + if (!m.getDeclaringClass().equals(udfClass)) { + continue; + } + String name = m.getName(); + if (EVAL_METHOD_KEY.equals(name) && eval == null) { + eval = m; + } else if (EVAL_METHOD_KEY.equals(name)) { + throw new AnalysisException(String.format( + "UDF class '%s' has multiple methods with name '%s' ", udfClass.getCanonicalName(), + EVAL_METHOD_KEY)); + } + } + if (eval == null) { + throw new AnalysisException(String.format( + "No method '%s' in class '%s'!", EVAL_METHOD_KEY, udfClass.getCanonicalName())); + } + if (Modifier.isStatic(eval.getModifiers())) { + throw new AnalysisException( + String.format("Method '%s' in class '%s' should be non-static", eval.getName(), + udfClass.getCanonicalName())); + } + if (!Modifier.isPublic(eval.getModifiers())) { + throw new AnalysisException( + String.format("Method '%s' in class '%s' should be public", eval.getName(), + udfClass.getCanonicalName())); + } + if (eval.getParameters().length != argsDef.getArgTypes().length) { + throw new AnalysisException( + String.format("The number of parameters for method '%s' in class '%s' should be %d", + eval.getName(), udfClass.getCanonicalName(), argsDef.getArgTypes().length)); + } + + checkUdfType(udfClass, eval, returnType.getType(), eval.getReturnType(), "return"); + for (int i = 0; i < eval.getParameters().length; i++) { + Parameter p = eval.getParameters()[i]; + checkUdfType(udfClass, eval, argsDef.getArgTypes()[i], p.getType(), p.getName()); + } + } catch (MalformedURLException e) { + throw new AnalysisException("Failed to load file: " + userFile); + } catch (ClassNotFoundException e) { + throw new AnalysisException("Class [" + clazz + "] not found in file :" + userFile); + } + } + + private static final ImmutableMap> PrimitiveTypeToJavaClassType = + new ImmutableMap.Builder>() + .put(PrimitiveType.BOOLEAN, Sets.newHashSet(Boolean.class, boolean.class)) + .put(PrimitiveType.TINYINT, Sets.newHashSet(Byte.class, byte.class)) + .put(PrimitiveType.SMALLINT, Sets.newHashSet(Short.class, short.class)) + .put(PrimitiveType.INT, Sets.newHashSet(Integer.class, int.class)) + .put(PrimitiveType.FLOAT, Sets.newHashSet(Float.class, float.class)) + .put(PrimitiveType.DOUBLE, Sets.newHashSet(Double.class, double.class)) + .put(PrimitiveType.BIGINT, Sets.newHashSet(Long.class, long.class)) + .put(PrimitiveType.CHAR, Sets.newHashSet(String.class)) + .put(PrimitiveType.VARCHAR, Sets.newHashSet(String.class)) + .build(); + + private void checkUdfType(Class clazz, Method method, Type expType, Class pType, String pname) + throws AnalysisException { + if (!(expType instanceof ScalarType)) { + throw new AnalysisException( + String.format("Method '%s' in class '%s' does not support non-scalar type '%s'", + method.getName(), clazz.getCanonicalName(), expType)); + } + ScalarType scalarType = (ScalarType) expType; + Set javaTypes = PrimitiveTypeToJavaClassType.get(scalarType.getPrimitiveType()); + if (javaTypes == null) { + throw new AnalysisException( + String.format("Method '%s' in class '%s' does not support type '%s'", + method.getName(), clazz.getCanonicalName(), scalarType)); + } + if (!javaTypes.contains(pType)) { + throw new AnalysisException( + String.format("UDF class '%s' method '%s' %s[%s] type is not supported!", + clazz.getCanonicalName(), method.getName(), pname, pType.getCanonicalName())); + } + } + private Types.PGenericType convertToPParameterType(Type arg) throws AnalysisException { Types.PGenericType.Builder typeBuilder = Types.PGenericType.newBuilder(); switch (arg.getPrimitiveType()) { diff --git a/gensrc/thrift/Types.thrift b/gensrc/thrift/Types.thrift index 88992a7375..71dd38a613 100644 --- a/gensrc/thrift/Types.thrift +++ b/gensrc/thrift/Types.thrift @@ -269,6 +269,8 @@ enum TFunctionBinaryType { // call udfs by rpc service RPC, + + JAVA_UDF, } // Represents a fully qualified function name.