[fix](Java UDF) Do not use enum as the data type for JavaUdfDataType. (#24460)

This commit is contained in:
Mryange
2023-09-19 14:06:02 +08:00
committed by GitHub
parent eea84ac36c
commit ee56783629
5 changed files with 251 additions and 177 deletions

View File

@ -0,0 +1,235 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.common.jni.utils;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.exception.InternalException;
import org.apache.doris.thrift.TPrimitiveType;
import com.google.common.collect.Sets;
import org.apache.log4j.Logger;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.HashSet;
import java.util.Set;
// Data types that are supported as return or argument types in Java UDFs.
public class JavaUdfDataType {
public static final Logger LOG = Logger.getLogger(JavaUdfDataType.class);
public static final JavaUdfDataType INVALID_TYPE = new JavaUdfDataType("INVALID_TYPE",
TPrimitiveType.INVALID_TYPE, 0);
public static final JavaUdfDataType BOOLEAN = new JavaUdfDataType("BOOLEAN", TPrimitiveType.BOOLEAN, 1);
public static final JavaUdfDataType TINYINT = new JavaUdfDataType("TINYINT", TPrimitiveType.TINYINT, 1);
public static final JavaUdfDataType SMALLINT = new JavaUdfDataType("SMALLINT", TPrimitiveType.SMALLINT, 2);
public static final JavaUdfDataType INT = new JavaUdfDataType("INT", TPrimitiveType.INT, 4);
public static final JavaUdfDataType BIGINT = new JavaUdfDataType("BIGINT", TPrimitiveType.BIGINT, 8);
public static final JavaUdfDataType FLOAT = new JavaUdfDataType("FLOAT", TPrimitiveType.FLOAT, 4);
public static final JavaUdfDataType DOUBLE = new JavaUdfDataType("DOUBLE", TPrimitiveType.DOUBLE, 8);
public static final JavaUdfDataType CHAR = new JavaUdfDataType("CHAR", TPrimitiveType.CHAR, 0);
public static final JavaUdfDataType VARCHAR = new JavaUdfDataType("VARCHAR", TPrimitiveType.VARCHAR, 0);
public static final JavaUdfDataType STRING = new JavaUdfDataType("STRING", TPrimitiveType.STRING, 0);
public static final JavaUdfDataType DATE = new JavaUdfDataType("DATE", TPrimitiveType.DATE, 8);
public static final JavaUdfDataType DATETIME = new JavaUdfDataType("DATETIME", TPrimitiveType.DATETIME, 8);
public static final JavaUdfDataType LARGEINT = new JavaUdfDataType("LARGEINT", TPrimitiveType.LARGEINT, 16);
public static final JavaUdfDataType DECIMALV2 = new JavaUdfDataType("DECIMALV2", TPrimitiveType.DECIMALV2, 16);
public static final JavaUdfDataType DATEV2 = new JavaUdfDataType("DATEV2", TPrimitiveType.DATEV2, 4);
public static final JavaUdfDataType DATETIMEV2 = new JavaUdfDataType("DATETIMEV2", TPrimitiveType.DATETIMEV2,
8);
public static final JavaUdfDataType DECIMAL32 = new JavaUdfDataType("DECIMAL32", TPrimitiveType.DECIMAL32, 4);
public static final JavaUdfDataType DECIMAL64 = new JavaUdfDataType("DECIMAL64", TPrimitiveType.DECIMAL64, 8);
public static final JavaUdfDataType DECIMAL128 = new JavaUdfDataType("DECIMAL128", TPrimitiveType.DECIMAL128I,
16);
public static final JavaUdfDataType ARRAY_TYPE = new JavaUdfDataType("ARRAY_TYPE", TPrimitiveType.ARRAY, 0);
public static final JavaUdfDataType MAP_TYPE = new JavaUdfDataType("MAP_TYPE", TPrimitiveType.MAP, 0);
private static Set<JavaUdfDataType> JavaUdfDataTypeSet = new HashSet<>();
static {
JavaUdfDataTypeSet.add(INVALID_TYPE);
JavaUdfDataTypeSet.add(BOOLEAN);
JavaUdfDataTypeSet.add(TINYINT);
JavaUdfDataTypeSet.add(SMALLINT);
JavaUdfDataTypeSet.add(INT);
JavaUdfDataTypeSet.add(BIGINT);
JavaUdfDataTypeSet.add(FLOAT);
JavaUdfDataTypeSet.add(DOUBLE);
JavaUdfDataTypeSet.add(CHAR);
JavaUdfDataTypeSet.add(VARCHAR);
JavaUdfDataTypeSet.add(STRING);
JavaUdfDataTypeSet.add(DATE);
JavaUdfDataTypeSet.add(DATETIME);
JavaUdfDataTypeSet.add(LARGEINT);
JavaUdfDataTypeSet.add(DECIMALV2);
JavaUdfDataTypeSet.add(DATEV2);
JavaUdfDataTypeSet.add(DATETIMEV2);
JavaUdfDataTypeSet.add(DECIMAL32);
JavaUdfDataTypeSet.add(DECIMAL64);
JavaUdfDataTypeSet.add(DECIMAL128);
JavaUdfDataTypeSet.add(ARRAY_TYPE);
JavaUdfDataTypeSet.add(MAP_TYPE);
}
private final String description;
private final TPrimitiveType thriftType;
private final int len;
private int precision;
private int scale;
private Type itemType = null;
private Type keyType;
private Type valueType;
private int keyScale;
private int valueScale;
public JavaUdfDataType(String description, TPrimitiveType thriftType, int len) {
this.description = description;
this.thriftType = thriftType;
this.len = len;
}
public JavaUdfDataType(JavaUdfDataType other) {
this.description = other.description;
this.thriftType = other.thriftType;
this.len = other.len;
}
@Override
public String toString() {
return description;
}
public TPrimitiveType getPrimitiveType() {
return thriftType;
}
public int getLen() {
return len;
}
public static Set<JavaUdfDataType> getCandidateTypes(Class<?> c) {
if (c == boolean.class || c == Boolean.class) {
return Sets.newHashSet(JavaUdfDataType.BOOLEAN);
} else if (c == byte.class || c == Byte.class) {
return Sets.newHashSet(JavaUdfDataType.TINYINT);
} else if (c == short.class || c == Short.class) {
return Sets.newHashSet(JavaUdfDataType.SMALLINT);
} else if (c == int.class || c == Integer.class) {
return Sets.newHashSet(JavaUdfDataType.INT);
} else if (c == long.class || c == Long.class) {
return Sets.newHashSet(JavaUdfDataType.BIGINT);
} else if (c == float.class || c == Float.class) {
return Sets.newHashSet(JavaUdfDataType.FLOAT);
} else if (c == double.class || c == Double.class) {
return Sets.newHashSet(JavaUdfDataType.DOUBLE);
} else if (c == char.class || c == Character.class) {
return Sets.newHashSet(JavaUdfDataType.CHAR);
} else if (c == String.class) {
return Sets.newHashSet(JavaUdfDataType.STRING);
} else if (Type.DATE_SUPPORTED_JAVA_TYPE.contains(c)) {
return Sets.newHashSet(JavaUdfDataType.DATE, JavaUdfDataType.DATEV2);
} else if (Type.DATETIME_SUPPORTED_JAVA_TYPE.contains(c)) {
return Sets.newHashSet(JavaUdfDataType.DATETIME, JavaUdfDataType.DATETIMEV2);
} else if (c == BigInteger.class) {
return Sets.newHashSet(JavaUdfDataType.LARGEINT);
} else if (c == BigDecimal.class) {
return Sets.newHashSet(JavaUdfDataType.DECIMALV2, JavaUdfDataType.DECIMAL32, JavaUdfDataType.DECIMAL64,
JavaUdfDataType.DECIMAL128);
} else if (c == java.util.ArrayList.class) {
return Sets.newHashSet(JavaUdfDataType.ARRAY_TYPE);
} else if (c == java.util.HashMap.class) {
return Sets.newHashSet(JavaUdfDataType.MAP_TYPE);
}
return Sets.newHashSet(JavaUdfDataType.INVALID_TYPE);
}
public static boolean isSupported(Type t) {
for (JavaUdfDataType javaType : JavaUdfDataTypeSet) {
if (javaType == JavaUdfDataType.INVALID_TYPE) {
continue;
}
if (javaType.getPrimitiveType() == t.getPrimitiveType().toThrift()) {
return true;
}
}
return false;
}
public int getPrecision() {
return precision;
}
public void setPrecision(int precision) {
this.precision = precision;
}
public int getScale() {
return this.thriftType == TPrimitiveType.DECIMALV2 ? 9 : scale;
}
public void setScale(int scale) {
this.scale = scale;
}
public Type getItemType() {
return itemType;
}
public void setItemType(Type type) throws InternalException {
if (this.itemType == null) {
this.itemType = type;
} else {
if (!this.itemType.matchesType(type)) {
LOG.info("set error");
throw new InternalException("udf type not matches origin type :" + this.itemType.toSql()
+ " set type :" + type.toSql());
}
}
}
public Type getKeyType() {
return keyType;
}
public Type getValueType() {
return valueType;
}
public void setKeyType(Type type) {
this.keyType = type;
}
public void setValueType(Type type) {
this.valueType = type;
}
public void setKeyScale(int scale) {
this.keyScale = scale;
}
public void setValueScale(int scale) {
this.valueScale = scale;
}
public int getKeyScale() {
return keyScale;
}
public int getValueScale() {
return valueScale;
}
}

View File

@ -23,9 +23,7 @@ import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Pair;
import org.apache.doris.common.exception.InternalException;
import org.apache.doris.thrift.TPrimitiveType;
import com.google.common.collect.Sets;
import com.vesoft.nebula.client.graph.data.DateTimeWrapper;
import com.vesoft.nebula.client.graph.data.DateWrapper;
import com.vesoft.nebula.client.graph.data.ValueWrapper;
@ -35,8 +33,6 @@ import sun.misc.Unsafe;
import java.io.File;
import java.io.FileNotFoundException;
import java.lang.reflect.Field;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
@ -69,166 +65,6 @@ public class UdfUtils {
INT_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(int[].class);
}
// Data types that are supported as return or argument types in Java UDFs.
public enum JavaUdfDataType {
INVALID_TYPE("INVALID_TYPE", TPrimitiveType.INVALID_TYPE, 0),
BOOLEAN("BOOLEAN", TPrimitiveType.BOOLEAN, 1),
TINYINT("TINYINT", TPrimitiveType.TINYINT, 1),
SMALLINT("SMALLINT", TPrimitiveType.SMALLINT, 2),
INT("INT", TPrimitiveType.INT, 4),
BIGINT("BIGINT", TPrimitiveType.BIGINT, 8),
FLOAT("FLOAT", TPrimitiveType.FLOAT, 4),
DOUBLE("DOUBLE", TPrimitiveType.DOUBLE, 8),
CHAR("CHAR", TPrimitiveType.CHAR, 0),
VARCHAR("VARCHAR", TPrimitiveType.VARCHAR, 0),
STRING("STRING", TPrimitiveType.STRING, 0),
DATE("DATE", TPrimitiveType.DATE, 8),
DATETIME("DATETIME", TPrimitiveType.DATETIME, 8),
LARGEINT("LARGEINT", TPrimitiveType.LARGEINT, 16),
DECIMALV2("DECIMALV2", TPrimitiveType.DECIMALV2, 16),
DATEV2("DATEV2", TPrimitiveType.DATEV2, 4),
DATETIMEV2("DATETIMEV2", TPrimitiveType.DATETIMEV2, 8),
DECIMAL32("DECIMAL32", TPrimitiveType.DECIMAL32, 4),
DECIMAL64("DECIMAL64", TPrimitiveType.DECIMAL64, 8),
DECIMAL128("DECIMAL128", TPrimitiveType.DECIMAL128I, 16),
ARRAY_TYPE("ARRAY_TYPE", TPrimitiveType.ARRAY, 0),
MAP_TYPE("MAP_TYPE", TPrimitiveType.MAP, 0);
private final String description;
private final TPrimitiveType thriftType;
private final int len;
private int precision;
private int scale;
private Type itemType;
private Type keyType;
private Type valueType;
private int keyScale;
private int valueScale;
JavaUdfDataType(String description, TPrimitiveType thriftType, int len) {
this.description = description;
this.thriftType = thriftType;
this.len = len;
}
@Override
public String toString() {
return description;
}
public TPrimitiveType getPrimitiveType() {
return thriftType;
}
public int getLen() {
return len;
}
public static Set<JavaUdfDataType> getCandidateTypes(Class<?> c) {
if (c == boolean.class || c == Boolean.class) {
return Sets.newHashSet(JavaUdfDataType.BOOLEAN);
} else if (c == byte.class || c == Byte.class) {
return Sets.newHashSet(JavaUdfDataType.TINYINT);
} else if (c == short.class || c == Short.class) {
return Sets.newHashSet(JavaUdfDataType.SMALLINT);
} else if (c == int.class || c == Integer.class) {
return Sets.newHashSet(JavaUdfDataType.INT);
} else if (c == long.class || c == Long.class) {
return Sets.newHashSet(JavaUdfDataType.BIGINT);
} else if (c == float.class || c == Float.class) {
return Sets.newHashSet(JavaUdfDataType.FLOAT);
} else if (c == double.class || c == Double.class) {
return Sets.newHashSet(JavaUdfDataType.DOUBLE);
} else if (c == char.class || c == Character.class) {
return Sets.newHashSet(JavaUdfDataType.CHAR);
} else if (c == String.class) {
return Sets.newHashSet(JavaUdfDataType.STRING);
} else if (Type.DATE_SUPPORTED_JAVA_TYPE.contains(c)) {
return Sets.newHashSet(JavaUdfDataType.DATE, JavaUdfDataType.DATEV2);
} else if (Type.DATETIME_SUPPORTED_JAVA_TYPE.contains(c)) {
return Sets.newHashSet(JavaUdfDataType.DATETIME, JavaUdfDataType.DATETIMEV2);
} else if (c == BigInteger.class) {
return Sets.newHashSet(JavaUdfDataType.LARGEINT);
} else if (c == BigDecimal.class) {
return Sets.newHashSet(JavaUdfDataType.DECIMALV2, JavaUdfDataType.DECIMAL32, JavaUdfDataType.DECIMAL64,
JavaUdfDataType.DECIMAL128);
} else if (c == java.util.ArrayList.class) {
return Sets.newHashSet(JavaUdfDataType.ARRAY_TYPE);
} else if (c == java.util.HashMap.class) {
return Sets.newHashSet(JavaUdfDataType.MAP_TYPE);
}
return Sets.newHashSet(JavaUdfDataType.INVALID_TYPE);
}
public static boolean isSupported(Type t) {
for (JavaUdfDataType javaType : JavaUdfDataType.values()) {
if (javaType == JavaUdfDataType.INVALID_TYPE) {
continue;
}
if (javaType.getPrimitiveType() == t.getPrimitiveType().toThrift()) {
return true;
}
}
return false;
}
public int getPrecision() {
return precision;
}
public void setPrecision(int precision) {
this.precision = precision;
}
public int getScale() {
return this.thriftType == TPrimitiveType.DECIMALV2 ? 9 : scale;
}
public void setScale(int scale) {
this.scale = scale;
}
public Type getItemType() {
return itemType;
}
public void setItemType(Type type) {
this.itemType = type;
}
public Type getKeyType() {
return keyType;
}
public Type getValueType() {
return valueType;
}
public void setKeyType(Type type) {
this.keyType = type;
}
public void setValueType(Type type) {
this.valueType = type;
}
public void setKeyScale(int scale) {
this.keyScale = scale;
}
public void setValueScale(int scale) {
this.valueScale = scale;
}
public int getKeyScale() {
return keyScale;
}
public int getValueScale() {
return valueScale;
}
}
public static void copyMemory(
Object src, long srcOffset, Object dst, long dstOffset, long length) {
// Check if dstOffset is before or after srcOffset to determine if we should copy
@ -282,7 +118,8 @@ public class UdfUtils {
Object[] res = javaTypes.stream().filter(
t -> t.getPrimitiveType() == retType.getPrimitiveType().toThrift()).toArray();
JavaUdfDataType result = res.length == 0 ? javaTypes.iterator().next() : (JavaUdfDataType) res[0];
JavaUdfDataType result = new JavaUdfDataType(
res.length == 0 ? javaTypes.iterator().next() : (JavaUdfDataType) res[0]);
if (retType.isDecimalV3() || retType.isDatetimeV2()) {
result.setPrecision(retType.getPrecision());
result.setScale(((ScalarType) retType).getScalarScale());
@ -313,9 +150,10 @@ public class UdfUtils {
* Sets the argument types of a Java UDF or UDAF. Returns true if the argument types specified
* in the UDF are compatible with the argument types of the evaluate() function loaded
* from the associated JAR file.
* @throws InternalException
*/
public static Pair<Boolean, JavaUdfDataType[]> setArgTypes(Type[] parameterTypes, Class<?>[] udfArgTypes,
boolean isUdaf) {
boolean isUdaf) throws InternalException {
JavaUdfDataType[] inputArgTypes = new JavaUdfDataType[parameterTypes.length];
int firstPos = isUdaf ? 1 : 0;
for (int i = 0; i < parameterTypes.length; ++i) {
@ -323,7 +161,8 @@ public class UdfUtils {
int finalI = i;
Object[] res = javaTypes.stream().filter(
t -> t.getPrimitiveType() == parameterTypes[finalI].getPrimitiveType().toThrift()).toArray();
inputArgTypes[i] = res.length == 0 ? javaTypes.iterator().next() : (JavaUdfDataType) res[0];
inputArgTypes[i] = new JavaUdfDataType(
res.length == 0 ? javaTypes.iterator().next() : (JavaUdfDataType) res[0]);
if (parameterTypes[finalI].isDecimalV3() || parameterTypes[finalI].isDatetimeV2()) {
inputArgTypes[i].setPrecision(parameterTypes[finalI].getPrecision());
inputArgTypes[i].setScale(((ScalarType) parameterTypes[finalI]).getScalarScale());