[refactor](jni) unified jni framework for java udf (#25302)

Refactor Java UDF to use the unified JNI framework.
The unified JNI framework uses VectorTable as the container for transferring data between C++ and Java, and hides the details of data format conversion.
In addition, the unified framework supports complex and nested types.
Performance for basic types remains unchanged, while string types improve by 30% and complex types improve by an order of magnitude.
This commit is contained in:
Ashin Gau
2023-10-18 09:27:54 +08:00
committed by GitHub
parent 26e332c608
commit 47689fd452
23 changed files with 2153 additions and 742 deletions

View File

@ -81,7 +81,7 @@ public abstract class JniScanner {
public long getNextBatchMeta() throws IOException {
if (vectorTable == null) {
vectorTable = new VectorTable(types, fields, predicates, batchSize);
vectorTable = VectorTable.createWritableTable(types, fields, batchSize);
}
int numRows;
try {
@ -107,7 +107,7 @@ public abstract class JniScanner {
}
private long getMetaAddress(int numRows) {
vectorTable.setNumRows(numRows);
assert (numRows == vectorTable.getNumRows());
return vectorTable.getMetaAddress();
}

View File

@ -34,7 +34,7 @@ public class OffHeap {
private static boolean IS_TESTING = false;
private static final Unsafe UNSAFE;
public static final Unsafe UNSAFE;
public static final int BOOLEAN_ARRAY_OFFSET;
@ -78,6 +78,12 @@ public class OffHeap {
return UNSAFE.getInt(object, offset);
}
/**
 * Bulk-reads {@code length} ints into a freshly allocated array.
 *
 * @param object base object passed through to {@code Unsafe.copyMemory}
 *               (NOTE(review): presumably {@code null} when {@code offset} is an absolute address - confirm)
 * @param offset source offset relative to {@code object}
 * @param length number of ints to copy
 * @return a new array holding the copied values
 */
public static int[] getInt(Object object, long offset, int length) {
    final int[] values = new int[length];
    // widen before multiplying so the byte count cannot overflow int
    final long byteCount = (long) length * Integer.BYTES;
    UNSAFE.copyMemory(object, offset, values, INT_ARRAY_OFFSET, byteCount);
    return values;
}
/**
 * Writes a single int by delegating to {@code UNSAFE.putInt}.
 *
 * @param object base object handed to Unsafe
 * @param offset target offset relative to {@code object}
 * @param value  the int to store
 */
public static void putInt(Object object, long offset, int value) {
UNSAFE.putInt(object, offset, value);
}
@ -86,6 +92,12 @@ public class OffHeap {
return UNSAFE.getBoolean(object, offset);
}
/**
 * Bulk-reads {@code length} booleans into a freshly allocated array.
 *
 * @param object base object passed through to {@code Unsafe.copyMemory}
 * @param offset source offset relative to {@code object}
 * @param length number of booleans to copy; also used directly as the byte count
 * @return a new array holding the copied values
 */
public static boolean[] getBoolean(Object object, long offset, int length) {
    final boolean[] values = new boolean[length];
    final long byteCount = length;
    UNSAFE.copyMemory(object, offset, values, BOOLEAN_ARRAY_OFFSET, byteCount);
    return values;
}
/**
 * Writes a single boolean by delegating to {@code UNSAFE.putBoolean}.
 *
 * @param object base object handed to Unsafe
 * @param offset target offset relative to {@code object}
 * @param value  the boolean to store
 */
public static void putBoolean(Object object, long offset, boolean value) {
UNSAFE.putBoolean(object, offset, value);
}
@ -94,6 +106,12 @@ public class OffHeap {
return UNSAFE.getByte(object, offset);
}
/**
 * Bulk-reads {@code length} bytes into a freshly allocated array.
 *
 * @param object base object passed through to {@code Unsafe.copyMemory}
 * @param offset source offset relative to {@code object}
 * @param length number of bytes to copy
 * @return a new array holding the copied values
 */
public static byte[] getByte(Object object, long offset, int length) {
    final byte[] values = new byte[length];
    // Byte.BYTES is 1, so the byte count equals the element count
    final long byteCount = length * Byte.BYTES;
    UNSAFE.copyMemory(object, offset, values, BYTE_ARRAY_OFFSET, byteCount);
    return values;
}
/**
 * Writes a single byte by delegating to {@code UNSAFE.putByte}.
 *
 * @param object base object handed to Unsafe
 * @param offset target offset relative to {@code object}
 * @param value  the byte to store
 */
public static void putByte(Object object, long offset, byte value) {
UNSAFE.putByte(object, offset, value);
}
@ -102,6 +120,12 @@ public class OffHeap {
return UNSAFE.getShort(object, offset);
}
/**
 * Bulk-reads {@code length} shorts into a freshly allocated array.
 *
 * @param object base object passed through to {@code Unsafe.copyMemory}
 * @param offset source offset relative to {@code object}
 * @param length number of shorts to copy
 * @return a new array holding the copied values
 */
public static short[] getShort(Object object, long offset, int length) {
    final short[] values = new short[length];
    // widen before multiplying so the byte count cannot overflow int
    final long byteCount = (long) length * Short.BYTES;
    UNSAFE.copyMemory(object, offset, values, SHORT_ARRAY_OFFSET, byteCount);
    return values;
}
/**
 * Writes a single short by delegating to {@code UNSAFE.putShort}.
 *
 * @param object base object handed to Unsafe
 * @param offset target offset relative to {@code object}
 * @param value  the short to store
 */
public static void putShort(Object object, long offset, short value) {
UNSAFE.putShort(object, offset, value);
}
@ -110,6 +134,12 @@ public class OffHeap {
return UNSAFE.getLong(object, offset);
}
/**
 * Bulk-reads {@code length} longs into a freshly allocated array.
 *
 * @param object base object passed through to {@code Unsafe.copyMemory}
 * @param offset source offset relative to {@code object}
 * @param length number of longs to copy
 * @return a new array holding the copied values
 */
public static long[] getLong(Object object, long offset, int length) {
    final long[] values = new long[length];
    // widen before multiplying so the byte count cannot overflow int
    final long byteCount = (long) length * Long.BYTES;
    UNSAFE.copyMemory(object, offset, values, LONG_ARRAY_OFFSET, byteCount);
    return values;
}
/**
 * Writes a single long by delegating to {@code UNSAFE.putLong}.
 *
 * @param object base object handed to Unsafe
 * @param offset target offset relative to {@code object}
 * @param value  the long to store
 */
public static void putLong(Object object, long offset, long value) {
UNSAFE.putLong(object, offset, value);
}
@ -118,6 +148,12 @@ public class OffHeap {
return UNSAFE.getFloat(object, offset);
}
/**
 * Bulk-reads {@code length} floats into a freshly allocated array.
 *
 * @param object base object passed through to {@code Unsafe.copyMemory}
 * @param offset source offset relative to {@code object}
 * @param length number of floats to copy
 * @return a new array holding the copied values
 */
public static float[] getFloat(Object object, long offset, int length) {
    final float[] values = new float[length];
    // widen before multiplying so the byte count cannot overflow int
    final long byteCount = (long) length * Float.BYTES;
    UNSAFE.copyMemory(object, offset, values, FLOAT_ARRAY_OFFSET, byteCount);
    return values;
}
/**
 * Writes a single float by delegating to {@code UNSAFE.putFloat}.
 *
 * @param object base object handed to Unsafe
 * @param offset target offset relative to {@code object}
 * @param value  the float to store
 */
public static void putFloat(Object object, long offset, float value) {
UNSAFE.putFloat(object, offset, value);
}
@ -126,6 +162,12 @@ public class OffHeap {
return UNSAFE.getDouble(object, offset);
}
/**
 * Bulk-reads {@code length} doubles into a freshly allocated array.
 *
 * @param object base object passed through to {@code Unsafe.copyMemory}
 * @param offset source offset relative to {@code object}
 * @param length number of doubles to copy
 * @return a new array holding the copied values
 */
public static double[] getDouble(Object object, long offset, int length) {
    final double[] values = new double[length];
    // widen before multiplying so the byte count cannot overflow int
    final long byteCount = (long) length * Double.BYTES;
    UNSAFE.copyMemory(object, offset, values, DOUBLE_ARRAY_OFFSET, byteCount);
    return values;
}
/**
 * Writes a single double by delegating to {@code UNSAFE.putDouble}.
 *
 * @param object base object handed to Unsafe
 * @param offset target offset relative to {@code object}
 * @param value  the double to store
 */
public static void putDouble(Object object, long offset, double value) {
UNSAFE.putDouble(object, offset, value);
}

View File

@ -40,25 +40,21 @@ public class TypeNativeBytes {
}
public static byte[] getBigIntegerBytes(BigInteger v) {
byte[] bytes = v.toByteArray();
// If the BigInteger is not negative and the first byte is 0, remove the first byte
if (v.signum() >= 0 && bytes[0] == 0) {
bytes = Arrays.copyOfRange(bytes, 1, bytes.length);
byte[] bytes = convertByteOrder(v.toByteArray());
// here value is 16 bytes, so if result data greater than the maximum of 16
// bytes, it will return a wrong num to backend;
byte[] value = new byte[16];
// check data is negative
if (v.signum() == -1) {
Arrays.fill(value, (byte) -1);
}
// Convert the byte order if necessary
return convertByteOrder(bytes);
System.arraycopy(bytes, 0, value, 0, Math.min(bytes.length, value.length));
return value;
}
public static BigInteger getBigInteger(byte[] bytes) {
// Convert the byte order back if necessary
byte[] originalBytes = convertByteOrder(bytes);
// If the first byte has the sign bit set, add a 0 byte at the start
if ((originalBytes[0] & 0x80) != 0) {
byte[] extendedBytes = new byte[originalBytes.length + 1];
extendedBytes[0] = 0;
System.arraycopy(originalBytes, 0, extendedBytes, 1, originalBytes.length);
originalBytes = extendedBytes;
}
return new BigInteger(originalBytes);
}
@ -80,6 +76,22 @@ public class TypeNativeBytes {
return new BigDecimal(value, scale);
}
/**
 * Packs date and time fields into a single 64-bit value.
 *
 * <p>Resulting layout, from the most significant end: year (from bit 48), month (bit 40),
 * day (bit 32), hour (bit 24), minute (bit 16), second (bit 4, 12 bits wide),
 * a 3-bit type tag (bit 1), and the lowest bit left as 0 for the "neg" flag.
 *
 * @param isDate selects the type tag: 2 when {@code true}, 3 otherwise
 * @return the packed value
 */
public static long convertToDateTime(int year, int month, int day, int hour, int minute, int second,
        boolean isDate) {
    final long typeTag = isDate ? 2 : 3;
    // Fields are summed (not OR-ed) so out-of-range inputs carry exactly as in a shift-and-add chain.
    return ((long) year << 48)
            + ((long) month << 40)
            + ((long) day << 32)
            + ((long) hour << 24)
            + ((long) minute << 16)
            + ((long) second << 4)
            + (typeTag << 1);
}
/**
 * Packs a calendar date into a 32-bit value: day in the low 5 bits, month shifted left by 5,
 * year shifted left by 9.
 *
 * @return the packed date
 */
public static int convertToDateV2(int year, int month, int day) {
    final long yearBits = (long) year << 9;
    final long monthBits = (long) month << 5;
    final long packed = day | monthBits | yearBits;
    return (int) packed;
}
@ -95,20 +107,128 @@ public class TypeNativeBytes {
| (long) day << 37 | (long) month << 42 | (long) year << 46;
}
public static LocalDate convertToJavaDate(int date) {
/**
 * Decodes a packed 64-bit date into a {@link LocalDate}.
 *
 * <p>The year is read from bit 48 upwards, the month from bits 40-47 and the day from bits 32-39;
 * the lower 32 bits are ignored.
 *
 * @param date the packed value
 * @return the decoded date, or {@code null} when the fields do not form a valid date
 */
public static LocalDate convertToJavaDateV1(long date) {
    final int year = (int) (date >> 48);
    final int month = (int) ((date >> 40) & 0xFF);
    final int day = (int) ((date >> 32) & 0xFF);
    try {
        return LocalDate.of(year, month, day);
    } catch (DateTimeException invalidDate) {
        return null;
    }
}
/**
 * Decodes a packed 64-bit date and returns it as an instance of the requested class.
 *
 * <p>Handled targets: {@code java.time.LocalDate}, {@code java.util.Date} and
 * {@code org.joda.time.LocalDate}.
 *
 * @param date the packed value (year from bit 48, month in bits 40-47, day in bits 32-39)
 * @param clz  the desired result class
 * @return the decoded date, or {@code null} for any other class or when construction throws
 */
public static Object convertToJavaDateV1(long date, Class clz) {
    final int year = (int) (date >> 48);
    final int month = (int) ((date >> 40) & 0xFF);
    final int day = (int) ((date >> 32) & 0xFF);
    try {
        if (LocalDate.class.equals(clz)) {
            return LocalDate.of(year, month, day);
        }
        if (java.util.Date.class.equals(clz)) {
            // this constructor expects a 1900-based year and a 0-based month
            return new java.util.Date(year - 1900, month - 1, day);
        }
        if (org.joda.time.LocalDate.class.equals(clz)) {
            return new org.joda.time.LocalDate(year, month, day);
        }
        return null;
    } catch (Exception conversionFailure) {
        return null;
    }
}
public static LocalDate convertToJavaDateV2(int date) {
int year = date >> 9;
int month = (date >> 5) & 0XF;
int day = date & 0X1F;
LocalDate value;
try {
value = LocalDate.of(year, month, day);
return LocalDate.of(year, month, day);
} catch (DateTimeException e) {
value = LocalDate.MAX;
return null;
}
return value;
}
public static LocalDateTime convertToJavaDateTime(long time) {
/**
 * Decodes a packed 32-bit date and returns it as an instance of the requested class.
 *
 * <p>Handled targets: {@code java.time.LocalDate}, {@code java.util.Date} and
 * {@code org.joda.time.LocalDate}.
 *
 * @param date the packed value (day in the low 5 bits, month in the next 4 bits, year above bit 9)
 * @param clz  the desired result class
 * @return the decoded date, or {@code null} for any other class or when construction throws
 */
public static Object convertToJavaDateV2(int date, Class clz) {
    final int day = date & 0x1F;
    final int month = (date >> 5) & 0xF;
    final int year = date >> 9;
    try {
        if (LocalDate.class.equals(clz)) {
            return LocalDate.of(year, month, day);
        }
        if (java.util.Date.class.equals(clz)) {
            // this constructor expects a 1900-based year and a 0-based month
            return new java.util.Date(year - 1900, month - 1, day);
        }
        if (org.joda.time.LocalDate.class.equals(clz)) {
            return new org.joda.time.LocalDate(year, month, day);
        }
        return null;
    } catch (Exception conversionFailure) {
        return null;
    }
}
/**
 * Decodes a packed 64-bit datetime into a {@link LocalDateTime}.
 *
 * <p>Year, month and day come from the upper 32 bits (bit 48 upwards, bits 40-47, bits 32-39).
 * Hour, minute and second come from the lower part; the trailing type and neg bits are not needed.
 *
 * @param time the packed value
 * @return the decoded datetime, or {@code null} when the fields do not form a valid datetime
 */
public static LocalDateTime convertToJavaDateTimeV1(long time) {
    final int year = (int) (time >> 48);
    final int month = (int) ((time >> 40) & 0xFF);
    final int day = (int) ((time >> 32) & 0xFF);
    // `1 << 31` is an int expression (Integer.MIN_VALUE); the remainder keeps the sign of `time`,
    // so for non-negative input this yields the low 31 bits.
    final int lowPart = (int) (time % (1 << 31));
    final int hour = lowPart >> 24;
    final int minute = (lowPart >> 16) & 0xFF;
    // the low 16 bits hold second (shifted by 4) followed by the type and neg bits
    final int second = (lowPart % (1 << 16)) >> 4;
    try {
        return LocalDateTime.of(year, month, day, hour, minute, second);
    } catch (DateTimeException invalidDateTime) {
        return null;
    }
}
/**
 * Decodes a packed 64-bit datetime and returns it as an instance of the requested class.
 *
 * <p>Handled targets: {@code java.time.LocalDateTime}, {@code org.joda.time.DateTime} and
 * {@code org.joda.time.LocalDateTime}.
 *
 * @param time the packed value
 * @param clz  the desired result class
 * @return the decoded datetime, or {@code null} for any other class or when construction throws
 */
public static Object convertToJavaDateTimeV1(long time, Class clz) {
    final int year = (int) (time >> 48);
    final int month = (int) ((time >> 40) & 0xFF);
    final int day = (int) ((time >> 32) & 0xFF);
    // `1 << 31` is an int expression (Integer.MIN_VALUE); the remainder keeps the sign of `time`.
    final int lowPart = (int) (time % (1 << 31));
    final int hour = lowPart >> 24;
    final int minute = (lowPart >> 16) & 0xFF;
    // the low 16 bits hold second (shifted by 4) followed by the type and neg bits
    final int second = (lowPart % (1 << 16)) >> 4;
    try {
        if (LocalDateTime.class.equals(clz)) {
            return LocalDateTime.of(year, month, day, hour, minute, second);
        }
        if (org.joda.time.DateTime.class.equals(clz)) {
            return new org.joda.time.DateTime(year, month, day, hour, minute, second);
        }
        if (org.joda.time.LocalDateTime.class.equals(clz)) {
            return new org.joda.time.LocalDateTime(year, month, day, hour, minute, second);
        }
        return null;
    } catch (Exception conversionFailure) {
        return null;
    }
}
public static LocalDateTime convertToJavaDateTimeV2(long time) {
int year = (int) (time >> 46);
int yearMonth = (int) (time >> 42);
int yearMonthDay = (int) (time >> 37);
@ -121,12 +241,38 @@ public class TypeNativeBytes {
int second = (int) ((time >> 20) & 0X3F);
int microsecond = (int) (time & 0XFFFFF);
LocalDateTime value;
try {
value = LocalDateTime.of(year, month, day, hour, minute, second, microsecond * 1000);
return LocalDateTime.of(year, month, day, hour, minute, second, microsecond * 1000);
} catch (DateTimeException e) {
value = LocalDateTime.MAX;
return null;
}
}
public static Object convertToJavaDateTimeV2(long time, Class clz) {
int year = (int) (time >> 46);
int yearMonth = (int) (time >> 42);
int yearMonthDay = (int) (time >> 37);
int month = (yearMonth & 0XF);
int day = (yearMonthDay & 0X1F);
int hour = (int) ((time >> 32) & 0X1F);
int minute = (int) ((time >> 26) & 0X3F);
int second = (int) ((time >> 20) & 0X3F);
int microsecond = (int) (time & 0XFFFFF);
try {
if (LocalDateTime.class.equals(clz)) {
return LocalDateTime.of(year, month, day, hour, minute, second, microsecond * 1000);
} else if (org.joda.time.DateTime.class.equals(clz)) {
return new org.joda.time.DateTime(year, month, day, hour, minute, second, microsecond / 1000);
} else if (org.joda.time.LocalDateTime.class.equals(clz)) {
return new org.joda.time.LocalDateTime(year, month, day, hour, minute, second, microsecond / 1000);
} else {
return null;
}
} catch (Exception e) {
return null;
}
return value;
}
}

View File

@ -367,11 +367,11 @@ public class UdfUtils {
0, 0, true);
} else if (java.util.Date.class.equals(clz)) {
java.util.Date date = (java.util.Date) obj;
return convertToDateTime(date.getYear() + 1900, date.getMonth(), date.getDay(), 0,
return convertToDateTime(date.getYear() + 1900, date.getMonth() + 1, date.getDay(), 0,
0, 0, true);
} else if (org.joda.time.LocalDate.class.equals(clz)) {
org.joda.time.LocalDate date = (org.joda.time.LocalDate) obj;
return convertToDateTime(date.getYear(), date.getDayOfMonth(), date.getDayOfMonth(), 0,
return convertToDateTime(date.getYear(), date.getMonthOfYear(), date.getDayOfMonth(), 0,
0, 0, true);
} else {
return 0;

View File

@ -45,7 +45,9 @@ public class ColumnType {
LARGEINT(16),
FLOAT(4),
DOUBLE(8),
DATE(8),
DATEV2(4),
DATETIME(8),
DATETIMEV2(8),
CHAR(-1),
VARCHAR(-1),
@ -161,6 +163,14 @@ public class ColumnType {
return type == Type.STRUCT;
}
/**
 * @return {@code true} when this column's type is {@code Type.DATEV2}
 */
public boolean isDateV2() {
return type == Type.DATEV2;
}
/**
 * @return {@code true} when this column's type is {@code Type.DATETIMEV2}
 */
public boolean isDateTimeV2() {
return type == Type.DATETIMEV2;
}
/**
 * @return the {@code Type} of this column
 */
public Type getType() {
return type;
}
@ -264,9 +274,16 @@ public class ColumnType {
case "double":
type = Type.DOUBLE;
break;
case "datev1":
type = Type.DATE;
break;
case "date":
case "datev2":
type = Type.DATEV2;
break;
case "datetimev1":
type = Type.DATETIME;
break;
case "binary":
case "bytes":
type = Type.BINARY;
@ -275,7 +292,9 @@ public class ColumnType {
type = Type.STRING;
break;
default:
if (lowerCaseType.startsWith("timestamp")) {
if (lowerCaseType.startsWith("timestamp")
|| lowerCaseType.startsWith("datetime")
|| lowerCaseType.startsWith("datetimev2")) {
type = Type.DATETIMEV2;
precision = 6; // default
Matcher match = digitPattern.matcher(lowerCaseType);

View File

@ -0,0 +1,25 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.common.jni.vec;
/**
 * Converts the values of a column when their type is not one defined in {@code ColumnType}.
 */
public interface ColumnValueConverter {
/**
 * Converts one column of values.
 *
 * @param column the values of a single column
 * @return the converted values
 */
Object[] convert(Object[] column);
}

View File

@ -21,6 +21,9 @@ package org.apache.doris.common.jni.vec;
import org.apache.doris.common.jni.utils.OffHeap;
import org.apache.doris.common.jni.vec.ColumnType.Type;
import java.util.Collections;
import java.util.Map;
/**
* Store a batch of data as vector table.
*/
@ -28,56 +31,133 @@ public class VectorTable {
private final VectorColumn[] columns;
private final ColumnType[] columnTypes;
private final String[] fields;
private final ScanPredicate[] predicates;
private final VectorColumn meta;
private int numRows;
private final boolean onlyReadable;
private final int numRowsOfReadable;
private final boolean isRestoreTable;
public VectorTable(ColumnType[] types, String[] fields, ScanPredicate[] predicates, int capacity) {
// Create writable vector table
private VectorTable(ColumnType[] types, String[] fields, int capacity) {
this.columnTypes = types;
this.fields = fields;
this.columns = new VectorColumn[types.length];
this.predicates = predicates;
int metaSize = 1; // number of rows
for (int i = 0; i < types.length; i++) {
columns[i] = new VectorColumn(types[i], capacity);
columns[i] = VectorColumn.createWritableColumn(types[i], capacity);
metaSize += types[i].metaSize();
}
this.meta = new VectorColumn(new ColumnType("#meta", Type.BIGINT), metaSize);
this.numRows = 0;
this.isRestoreTable = false;
this.meta = VectorColumn.createWritableColumn(new ColumnType("#meta", Type.BIGINT), metaSize);
this.onlyReadable = false;
numRowsOfReadable = -1;
}
public VectorTable(ColumnType[] types, String[] fields, long metaAddress) {
// Create readable vector table
// `metaAddress` is generated by `JniConnector::generate_meta_info`
private VectorTable(ColumnType[] types, String[] fields, long metaAddress) {
long address = metaAddress;
this.columnTypes = types;
this.fields = fields;
this.columns = new VectorColumn[types.length];
this.predicates = new ScanPredicate[0];
this.numRows = (int) OffHeap.getLong(null, address);
int numRows = (int) OffHeap.getLong(null, address);
address += 8;
int metaSize = 1; // stores the number of rows + other columns meta data
for (int i = 0; i < types.length; i++) {
columns[i] = new VectorColumn(types[i], numRows, address);
columns[i] = VectorColumn.createReadableColumn(types[i], numRows, address);
metaSize += types[i].metaSize();
address += types[i].metaSize() * 8L;
}
this.meta = new VectorColumn(metaAddress, metaSize, new ColumnType("#meta", Type.BIGINT));
this.isRestoreTable = true;
this.meta = VectorColumn.createReadableColumn(metaAddress, metaSize, new ColumnType("#meta", Type.BIGINT));
this.onlyReadable = true;
numRowsOfReadable = numRows;
}
/**
 * Creates a writable table by invoking the private (types, fields, capacity) constructor.
 *
 * @param types    column types, one per column
 * @param fields   field names
 * @param capacity capacity passed to each writable column
 * @return the new writable table
 */
public static VectorTable createWritableTable(ColumnType[] types, String[] fields, int capacity) {
return new VectorTable(types, fields, capacity);
}
/**
 * Creates a writable table from a parameter map.
 *
 * <p>{@code required_fields} is split on {@code ","} to obtain the field names and
 * {@code columns_types} is split on {@code "#"} to obtain the type strings; each pair is
 * parsed with {@code ColumnType.parseType}.
 *
 * @param params   parameter map holding {@code required_fields} and {@code columns_types}
 * @param capacity capacity passed on to the created table
 * @return the new writable table
 */
public static VectorTable createWritableTable(Map<String, String> params, int capacity) {
    final String[] fieldNames = params.get("required_fields").split(",");
    final String[] typeNames = params.get("columns_types").split("#");
    final ColumnType[] parsedTypes = new ColumnType[typeNames.length];
    int idx = 0;
    for (String typeName : typeNames) {
        parsedTypes[idx] = ColumnType.parseType(fieldNames[idx], typeName);
        idx++;
    }
    return createWritableTable(parsedTypes, fieldNames, capacity);
}
/**
 * Creates a writable table from a parameter map, taking the capacity from the
 * {@code num_rows} entry.
 *
 * @param params parameter map; must also hold the entries read by the two-argument overload
 * @return the new writable table
 */
public static VectorTable createWritableTable(Map<String, String> params) {
    final int capacity = Integer.parseInt(params.get("num_rows"));
    return createWritableTable(params, capacity);
}
/**
 * Creates a readable table by invoking the private (types, fields, metaAddress) constructor.
 *
 * @param types       column types, one per column
 * @param fields      field names
 * @param metaAddress address of the meta block describing the table
 * @return the new readable table
 */
public static VectorTable createReadableTable(ColumnType[] types, String[] fields, long metaAddress) {
return new VectorTable(types, fields, metaAddress);
}
/**
 * Creates a readable table from a parameter map.
 *
 * <p>When {@code required_fields} is empty, a table with no columns is created from
 * {@code meta_address} alone. Otherwise {@code required_fields} is split on {@code ","},
 * {@code columns_types} on {@code "#"}, and each pair is parsed with {@code ColumnType.parseType}.
 *
 * @param params parameter map holding {@code required_fields}, {@code columns_types}
 *               and {@code meta_address}
 * @return the new readable table
 */
public static VectorTable createReadableTable(Map<String, String> params) {
    final String fieldList = params.get("required_fields");
    if (fieldList.isEmpty()) {
        assert params.get("columns_types").isEmpty();
        final long emptyTableMeta = Long.parseLong(params.get("meta_address"));
        return createReadableTable(new ColumnType[0], new String[0], emptyTableMeta);
    }
    final String[] fieldNames = fieldList.split(",");
    final String[] typeNames = params.get("columns_types").split("#");
    final long metaAddress = Long.parseLong(params.get("meta_address"));
    final ColumnType[] parsedTypes = new ColumnType[typeNames.length];
    int idx = 0;
    for (String typeName : typeNames) {
        parsedTypes[idx] = ColumnType.parseType(fieldNames[idx], typeName);
        idx++;
    }
    return createReadableTable(parsedTypes, fieldNames, metaAddress);
}
public void appendNativeData(int fieldId, NativeColumnValue o) {
assert (!isRestoreTable);
assert (!onlyReadable);
columns[fieldId].appendNativeValue(o);
}
public void appendData(int fieldId, ColumnValue o) {
assert (!isRestoreTable);
assert (!onlyReadable);
columns[fieldId].appendValue(o);
}
/**
 * Appends a batch of values to one column, optionally converting them first.
 *
 * @param fieldId    index of the target column
 * @param batch      the values to append
 * @param converter  applied to {@code batch} before appending; may be {@code null} for no conversion
 * @param isNullable forwarded to {@code appendObjectColumn}
 */
public void appendData(int fieldId, Object[] batch, ColumnValueConverter converter, boolean isNullable) {
    assert (!onlyReadable);
    // resolve the column before converting, keeping the index check ahead of the converter call
    final VectorColumn target = columns[fieldId];
    final Object[] values = (converter == null) ? batch : converter.convert(batch);
    target.appendObjectColumn(values, isNullable);
}
/**
 * Appends a batch of values to one column without conversion; delegates to the
 * four-argument overload with a {@code null} converter.
 *
 * @param fieldId    index of the target column
 * @param batch      the values to append
 * @param isNullable forwarded to the four-argument overload
 */
public void appendData(int fieldId, Object[] batch, boolean isNullable) {
appendData(fieldId, batch, null, isNullable);
}
/**
 * Materializes every column into an object array; each type is wrapped by its Java type,
 * for example int -> Integer and decimal -> BigDecimal.
 *
 * @param converters converters keyed by the field ID in this table. A column whose ID is present
 *                   has its values passed through the converter; use this when the type is not
 *                   defined in ColumnType.
 * @return one object array per column, or an empty 2-D array when the table has no columns
 */
public Object[][] getMaterializedData(Map<Integer, ColumnValueConverter> converters) {
    final int columnCount = columns.length;
    if (columnCount == 0) {
        return new Object[0][0];
    }
    final Object[][] materialized = new Object[columnCount][];
    int fieldId = 0;
    for (VectorColumn column : columns) {
        final Object[] raw = column.getObjectColumn(0, column.numRows());
        materialized[fieldId] = converters.containsKey(fieldId)
                ? converters.get(fieldId).convert(raw)
                : raw;
        fieldId++;
    }
    return materialized;
}
/**
 * Materializes every column without any converters; delegates to the overload taking a
 * converter map, passing an empty map.
 *
 * @return one object array per column
 */
public Object[][] getMaterializedData() {
return getMaterializedData(Collections.emptyMap());
}
/**
 * @return the internal column array itself (not a copy)
 */
public VectorColumn[] getColumns() {
return columns;
}
@ -95,22 +175,26 @@ public class VectorTable {
}
public void releaseColumn(int fieldId) {
assert (!isRestoreTable);
assert (!onlyReadable);
columns[fieldId].close();
}
public void setNumRows(int numRows) {
this.numRows = numRows;
/**
 * Returns the number of rows in this table.
 *
 * <p>A read-only table reports the row count stored in {@code numRowsOfReadable}; otherwise the
 * count is taken from the first column.
 * NOTE(review): the non-readable path indexes {@code columns[0]}, so it appears to require at
 * least one column - confirm callers never reach it with an empty table.
 *
 * @return the row count
 */
public int getNumRows() {
    return onlyReadable ? numRowsOfReadable : columns[0].numRows();
}
public int getNumRows() {
return this.numRows;
/**
 * @return the number of columns in this table
 */
public int getNumColumns() {
return columns.length;
}
public long getMetaAddress() {
if (!isRestoreTable) {
if (!onlyReadable) {
meta.reset();
meta.appendLong(numRows);
meta.appendLong(getNumRows());
for (VectorColumn c : columns) {
c.updateMeta(meta);
}
@ -119,7 +203,7 @@ public class VectorTable {
}
public void reset() {
assert (!isRestoreTable);
assert (!onlyReadable);
for (VectorColumn column : columns) {
column.reset();
}
@ -127,7 +211,7 @@ public class VectorTable {
}
public void close() {
assert (!isRestoreTable);
assert (!onlyReadable);
for (int i = 0; i < columns.length; i++) {
releaseColumn(i);
}
@ -137,7 +221,7 @@ public class VectorTable {
// for test only.
public String dump(int rowLimit) {
StringBuilder sb = new StringBuilder();
for (int i = 0; i < rowLimit && i < numRows; i++) {
for (int i = 0; i < rowLimit && i < getNumRows(); i++) {
for (int j = 0; j < columns.length; j++) {
if (j != 0) {
sb.append(", ");

View File

@ -35,10 +35,12 @@ public class JniScannerTest {
{
put("mock_rows", "128");
put("required_fields", "boolean,tinyint,smallint,int,bigint,largeint,float,double,"
+ "date,timestamp,char,varchar,string,decimalv2,decimal64,array,map,struct");
+ "date,timestamp,char,varchar,string,decimalv2,decimal64,array,map,struct,"
+ "decimal18,timestamp4,datev1,datev2,datetimev1,datetimev2");
put("columns_types", "boolean#tinyint#smallint#int#bigint#largeint#float#double#"
+ "date#timestamp#char(10)#varchar(10)#string#decimalv2(12,4)#decimal64(10,3)#"
+ "array<array<string>>#map<string,array<int>>#struct<col1:timestamp(6),col2:array<char(10)>>");
+ "array<array<string>>#map<string,array<int>>#struct<col1:timestamp(6),col2:array<char(10)>>#"
+ "decimal(18,5)#timestamp(4)#datev1#datev2#datetimev1#datetimev2(4)");
}
});
scanner.open();
@ -49,7 +51,7 @@ public class JniScannerTest {
long rows = OffHeap.getLong(null, metaAddress);
Assert.assertEquals(32, rows);
VectorTable restoreTable = new VectorTable(scanner.getTable().getColumnTypes(),
VectorTable restoreTable = VectorTable.createReadableTable(scanner.getTable().getColumnTypes(),
scanner.getTable().getFields(), metaAddress);
System.out.println(restoreTable.dump((int) rows).substring(0, 128));
// Restored table is release by the origin table.