[refactor](jni) unified jni framework for java udf (#25302)

Use the unified jni framework to refactor java udf.
The unified jni framework takes VectorTable as the container to transform data between c++ and java, and hide the details of data format conversion.
In addition, the unified framework supports complex and nested types.
The performance of basic types remains consistent, with a 30% improvement in string types and an order of magnitude improvement in complex types.
This commit is contained in:
Ashin Gau
2023-10-18 09:27:54 +08:00
committed by GitHub
parent 26e332c608
commit 47689fd452
23 changed files with 2153 additions and 742 deletions

View File

@ -68,7 +68,7 @@ Status AvroJNIReader::init_fetch_table_reader(
for (auto& desc : _file_slot_descs) {
std::string field = desc->col_name();
column_names.emplace_back(field);
std::string type = JniConnector::get_hive_type(desc->type());
std::string type = JniConnector::get_jni_type(desc->type());
if (index == 0) {
required_fields << field;
columns_types << type;

View File

@ -44,7 +44,7 @@ MockJniReader::MockJniReader(const std::vector<SlotDescriptor*>& file_slot_descs
int index = 0;
for (auto& desc : _file_slot_descs) {
std::string field = desc->col_name();
std::string type = JniConnector::get_hive_type(desc->type());
std::string type = JniConnector::get_jni_type(desc->type());
column_names.emplace_back(field);
if (index == 0) {
required_fields << field;

View File

@ -49,7 +49,7 @@ MaxComputeJniReader::MaxComputeJniReader(const MaxComputeTableDescriptor* mc_des
int index = 0;
for (auto& desc : _file_slot_descs) {
std::string field = desc->col_name();
std::string type = JniConnector::get_hive_type(desc->type());
std::string type = JniConnector::get_jni_type(desc->type());
column_names.emplace_back(field);
if (index == 0) {
required_fields << field;

View File

@ -44,18 +44,26 @@ class RuntimeProfile;
namespace doris::vectorized {
#define FOR_LOGICAL_NUMERIC_TYPES(M) \
M(TypeIndex::Int8, Int8) \
M(TypeIndex::UInt8, UInt8) \
M(TypeIndex::Int16, Int16) \
M(TypeIndex::UInt16, UInt16) \
M(TypeIndex::Int32, Int32) \
M(TypeIndex::UInt32, UInt32) \
M(TypeIndex::Int64, Int64) \
M(TypeIndex::UInt64, UInt64) \
M(TypeIndex::Int128, Int128) \
M(TypeIndex::Float32, Float32) \
M(TypeIndex::Float64, Float64)
#define FOR_FIXED_LENGTH_TYPES(M) \
M(TypeIndex::Int8, ColumnVector<Int8>, Int8) \
M(TypeIndex::UInt8, ColumnVector<UInt8>, UInt8) \
M(TypeIndex::Int16, ColumnVector<Int16>, Int16) \
M(TypeIndex::UInt16, ColumnVector<UInt16>, UInt16) \
M(TypeIndex::Int32, ColumnVector<Int32>, Int32) \
M(TypeIndex::UInt32, ColumnVector<UInt32>, UInt32) \
M(TypeIndex::Int64, ColumnVector<Int64>, Int64) \
M(TypeIndex::UInt64, ColumnVector<UInt64>, UInt64) \
M(TypeIndex::Int128, ColumnVector<Int128>, Int128) \
M(TypeIndex::Float32, ColumnVector<Float32>, Float32) \
M(TypeIndex::Float64, ColumnVector<Float64>, Float64) \
M(TypeIndex::Decimal128, ColumnDecimal<Decimal<Int128>>, Int128) \
M(TypeIndex::Decimal128I, ColumnDecimal<Decimal<Int128>>, Int128) \
M(TypeIndex::Decimal32, ColumnDecimal<Decimal<Int32>>, Int32) \
M(TypeIndex::Decimal64, ColumnDecimal<Decimal<Int64>>, Int64) \
M(TypeIndex::Date, ColumnVector<Int64>, Int64) \
M(TypeIndex::DateV2, ColumnVector<UInt32>, UInt32) \
M(TypeIndex::DateTime, ColumnVector<Int64>, Int64) \
M(TypeIndex::DateTimeV2, ColumnVector<UInt64>, UInt64)
JniConnector::~JniConnector() {
Status st = close();
@ -121,7 +129,7 @@ Status JniConnector::get_nex_block(Block* block, size_t* read_rows, bool* eof) {
return Status::OK();
}
_set_meta(meta_address);
long num_rows = _next_meta_as_long();
long num_rows = _table_meta.next_meta_as_long();
if (num_rows == 0) {
*read_rows = 0;
*eof = true;
@ -239,15 +247,53 @@ Status JniConnector::_init_jni_scanner(JNIEnv* env, int batch_size) {
return Status::OK();
}
Status JniConnector::fill_block(Block* block, const ColumnNumbers& arguments, long table_address) {
if (table_address == 0) {
return Status::OK();
}
TableMetaAddress table_meta(table_address);
long num_rows = table_meta.next_meta_as_long();
if (num_rows == 0) {
return Status::OK();
}
for (size_t i : arguments) {
if (block->get_by_position(i).column == nullptr) {
auto return_type = block->get_data_type(i);
bool result_nullable = return_type->is_nullable();
ColumnUInt8::MutablePtr null_col = nullptr;
if (result_nullable) {
return_type = remove_nullable(return_type);
null_col = ColumnUInt8::create();
}
auto res_col = return_type->create_column();
if (result_nullable) {
block->replace_by_position(
i, ColumnNullable::create(std::move(res_col), std::move(null_col)));
} else {
block->replace_by_position(i, std::move(res_col));
}
} else if (is_column_const(*(block->get_by_position(i).column))) {
auto doris_column = block->get_by_position(i).column->convert_to_full_column_if_const();
bool is_nullable = block->get_by_position(i).type->is_nullable();
block->replace_by_position(i, is_nullable ? make_nullable(doris_column) : doris_column);
}
auto& column_with_type_and_name = block->get_by_position(i);
auto& column_ptr = column_with_type_and_name.column;
auto& column_type = column_with_type_and_name.type;
RETURN_IF_ERROR(_fill_column(table_meta, column_ptr, column_type, num_rows));
}
return Status::OK();
}
Status JniConnector::_fill_block(Block* block, size_t num_rows) {
SCOPED_TIMER(_fill_block_time);
JNIEnv* env = nullptr;
RETURN_IF_ERROR(JniUtil::GetJNIEnv(&env));
for (int i = 0; i < _column_names.size(); ++i) {
auto& column_with_type_and_name = block->get_by_name(_column_names[i]);
auto& column_ptr = column_with_type_and_name.column;
auto& column_type = column_with_type_and_name.type;
RETURN_IF_ERROR(_fill_column(column_ptr, column_type, num_rows));
JNIEnv* env = nullptr;
RETURN_IF_ERROR(JniUtil::GetJNIEnv(&env));
RETURN_IF_ERROR(_fill_column(_table_meta, column_ptr, column_type, num_rows));
// Column is not released when _fill_column failed. It will be released when releasing table.
env->CallVoidMethod(_jni_scanner_obj, _jni_scanner_release_column, i);
RETURN_ERROR_IF_EXC(env);
@ -255,10 +301,10 @@ Status JniConnector::_fill_block(Block* block, size_t num_rows) {
return Status::OK();
}
Status JniConnector::_fill_column(ColumnPtr& doris_column, DataTypePtr& data_type,
size_t num_rows) {
Status JniConnector::_fill_column(TableMetaAddress& address, ColumnPtr& doris_column,
DataTypePtr& data_type, size_t num_rows) {
TypeIndex logical_type = remove_nullable(data_type)->get_type_id();
void* null_map_ptr = _next_meta_as_ptr();
void* null_map_ptr = address.next_meta_as_ptr();
if (null_map_ptr == nullptr) {
// org.apache.doris.common.jni.vec.ColumnType.Type#UNSUPPORTED will set column address as 0
return Status::InternalError("Unsupported type {} in java side", getTypeName(logical_type));
@ -277,39 +323,22 @@ Status JniConnector::_fill_column(ColumnPtr& doris_column, DataTypePtr& data_typ
}
// Date and DateTime are deprecated and not supported.
switch (logical_type) {
#define DISPATCH(NUMERIC_TYPE, CPP_NUMERIC_TYPE) \
case NUMERIC_TYPE: \
return _fill_numeric_column<CPP_NUMERIC_TYPE>( \
data_column, reinterpret_cast<CPP_NUMERIC_TYPE*>(_next_meta_as_ptr()), num_rows);
FOR_LOGICAL_NUMERIC_TYPES(DISPATCH)
#define DISPATCH(TYPE_INDEX, COLUMN_TYPE, CPP_TYPE) \
case TYPE_INDEX: \
return _fill_fixed_length_column<COLUMN_TYPE, CPP_TYPE>( \
data_column, reinterpret_cast<CPP_TYPE*>(address.next_meta_as_ptr()), num_rows);
FOR_FIXED_LENGTH_TYPES(DISPATCH)
#undef DISPATCH
case TypeIndex::Decimal128:
[[fallthrough]];
case TypeIndex::Decimal128I:
return _fill_decimal_column<Int128>(
data_column, reinterpret_cast<Int128*>(_next_meta_as_ptr()), num_rows);
case TypeIndex::Decimal32:
return _fill_decimal_column<Int32>(data_column,
reinterpret_cast<Int32*>(_next_meta_as_ptr()), num_rows);
case TypeIndex::Decimal64:
return _fill_decimal_column<Int64>(data_column,
reinterpret_cast<Int64*>(_next_meta_as_ptr()), num_rows);
case TypeIndex::DateV2:
return _decode_time_column<UInt32>(
data_column, reinterpret_cast<UInt32*>(_next_meta_as_ptr()), num_rows);
case TypeIndex::DateTimeV2:
return _decode_time_column<UInt64>(
data_column, reinterpret_cast<UInt64*>(_next_meta_as_ptr()), num_rows);
case TypeIndex::String:
[[fallthrough]];
case TypeIndex::FixedString:
return _fill_string_column(data_column, num_rows);
return _fill_string_column(address, data_column, num_rows);
case TypeIndex::Array:
return _fill_array_column(data_column, data_type, num_rows);
return _fill_array_column(address, data_column, data_type, num_rows);
case TypeIndex::Map:
return _fill_map_column(data_column, data_type, num_rows);
return _fill_map_column(address, data_column, data_type, num_rows);
case TypeIndex::Struct:
return _fill_struct_column(data_column, data_type, num_rows);
return _fill_struct_column(address, data_column, data_type, num_rows);
default:
return Status::InvalidArgument("Unsupported type {} in jni scanner",
getTypeName(logical_type));
@ -317,68 +346,8 @@ Status JniConnector::_fill_column(ColumnPtr& doris_column, DataTypePtr& data_typ
return Status::OK();
}
Status JniConnector::_fill_array_column(MutableColumnPtr& doris_column, DataTypePtr& data_type,
size_t num_rows) {
ColumnPtr& element_column = static_cast<ColumnArray&>(*doris_column).get_data_ptr();
DataTypePtr& element_type = const_cast<DataTypePtr&>(
(reinterpret_cast<const DataTypeArray*>(remove_nullable(data_type).get()))
->get_nested_type());
ColumnArray::Offsets64& offsets_data = static_cast<ColumnArray&>(*doris_column).get_offsets();
int64* offsets = reinterpret_cast<int64*>(_next_meta_as_ptr());
size_t origin_size = offsets_data.size();
offsets_data.resize(origin_size + num_rows);
size_t start_offset = offsets_data[origin_size - 1];
for (size_t i = 0; i < num_rows; ++i) {
offsets_data[origin_size + i] = offsets[i] + start_offset;
}
// offsets[num_rows - 1] == offsets_data[origin_size + num_rows - 1] - start_offset
// but num_row equals 0 when there are all empty arrays
return _fill_column(element_column, element_type,
offsets_data[origin_size + num_rows - 1] - start_offset);
}
Status JniConnector::_fill_map_column(MutableColumnPtr& doris_column, DataTypePtr& data_type,
size_t num_rows) {
auto& map = static_cast<ColumnMap&>(*doris_column);
DataTypePtr& key_type = const_cast<DataTypePtr&>(
reinterpret_cast<const DataTypeMap*>(remove_nullable(data_type).get())->get_key_type());
DataTypePtr& value_type = const_cast<DataTypePtr&>(
reinterpret_cast<const DataTypeMap*>(remove_nullable(data_type).get())
->get_value_type());
ColumnPtr& key_column = map.get_keys_ptr();
ColumnPtr& value_column = map.get_values_ptr();
ColumnArray::Offsets64& map_offsets = map.get_offsets();
int64* offsets = reinterpret_cast<int64*>(_next_meta_as_ptr());
size_t origin_size = map_offsets.size();
map_offsets.resize(origin_size + num_rows);
size_t start_offset = map_offsets[origin_size - 1];
for (size_t i = 0; i < num_rows; ++i) {
map_offsets[origin_size + i] = offsets[i] + start_offset;
}
RETURN_IF_ERROR(_fill_column(key_column, key_type,
map_offsets[origin_size + num_rows - 1] - start_offset));
return _fill_column(value_column, value_type,
map_offsets[origin_size + num_rows - 1] - start_offset);
}
Status JniConnector::_fill_struct_column(MutableColumnPtr& doris_column, DataTypePtr& data_type,
Status JniConnector::_fill_string_column(TableMetaAddress& address, MutableColumnPtr& doris_column,
size_t num_rows) {
auto& doris_struct = static_cast<ColumnStruct&>(*doris_column);
const DataTypeStruct* doris_struct_type =
reinterpret_cast<const DataTypeStruct*>(remove_nullable(data_type).get());
for (int i = 0; i < doris_struct.tuple_size(); ++i) {
ColumnPtr& struct_field = doris_struct.get_column_ptr(i);
DataTypePtr& field_type = const_cast<DataTypePtr&>(doris_struct_type->get_element(i));
RETURN_IF_ERROR(_fill_column(struct_field, field_type, num_rows));
}
return Status::OK();
}
Status JniConnector::_fill_string_column(MutableColumnPtr& doris_column, size_t num_rows) {
if (num_rows == 0) {
return Status::OK();
}
@ -386,8 +355,8 @@ Status JniConnector::_fill_string_column(MutableColumnPtr& doris_column, size_t
ColumnString::Chars& string_chars = const_cast<ColumnString::Chars&>(string_col.get_chars());
ColumnString::Offsets& string_offsets =
const_cast<ColumnString::Offsets&>(string_col.get_offsets());
int* offsets = reinterpret_cast<int*>(_next_meta_as_ptr());
char* chars = reinterpret_cast<char*>(_next_meta_as_ptr());
int* offsets = reinterpret_cast<int*>(address.next_meta_as_ptr());
char* chars = reinterpret_cast<char*>(address.next_meta_as_ptr());
size_t origin_chars_size = string_chars.size();
string_chars.resize(origin_chars_size + offsets[num_rows - 1]);
@ -402,6 +371,67 @@ Status JniConnector::_fill_string_column(MutableColumnPtr& doris_column, size_t
return Status::OK();
}
Status JniConnector::_fill_array_column(TableMetaAddress& address, MutableColumnPtr& doris_column,
DataTypePtr& data_type, size_t num_rows) {
ColumnPtr& element_column = static_cast<ColumnArray&>(*doris_column).get_data_ptr();
DataTypePtr& element_type = const_cast<DataTypePtr&>(
(reinterpret_cast<const DataTypeArray*>(remove_nullable(data_type).get()))
->get_nested_type());
ColumnArray::Offsets64& offsets_data = static_cast<ColumnArray&>(*doris_column).get_offsets();
int64* offsets = reinterpret_cast<int64*>(address.next_meta_as_ptr());
size_t origin_size = offsets_data.size();
offsets_data.resize(origin_size + num_rows);
size_t start_offset = offsets_data[origin_size - 1];
for (size_t i = 0; i < num_rows; ++i) {
offsets_data[origin_size + i] = offsets[i] + start_offset;
}
// offsets[num_rows - 1] == offsets_data[origin_size + num_rows - 1] - start_offset
// but num_row equals 0 when there are all empty arrays
return _fill_column(address, element_column, element_type,
offsets_data[origin_size + num_rows - 1] - start_offset);
}
Status JniConnector::_fill_map_column(TableMetaAddress& address, MutableColumnPtr& doris_column,
DataTypePtr& data_type, size_t num_rows) {
auto& map = static_cast<ColumnMap&>(*doris_column);
DataTypePtr& key_type = const_cast<DataTypePtr&>(
reinterpret_cast<const DataTypeMap*>(remove_nullable(data_type).get())->get_key_type());
DataTypePtr& value_type = const_cast<DataTypePtr&>(
reinterpret_cast<const DataTypeMap*>(remove_nullable(data_type).get())
->get_value_type());
ColumnPtr& key_column = map.get_keys_ptr();
ColumnPtr& value_column = map.get_values_ptr();
ColumnArray::Offsets64& map_offsets = map.get_offsets();
int64* offsets = reinterpret_cast<int64*>(address.next_meta_as_ptr());
size_t origin_size = map_offsets.size();
map_offsets.resize(origin_size + num_rows);
size_t start_offset = map_offsets[origin_size - 1];
for (size_t i = 0; i < num_rows; ++i) {
map_offsets[origin_size + i] = offsets[i] + start_offset;
}
RETURN_IF_ERROR(_fill_column(address, key_column, key_type,
map_offsets[origin_size + num_rows - 1] - start_offset));
return _fill_column(address, value_column, value_type,
map_offsets[origin_size + num_rows - 1] - start_offset);
}
Status JniConnector::_fill_struct_column(TableMetaAddress& address, MutableColumnPtr& doris_column,
DataTypePtr& data_type, size_t num_rows) {
auto& doris_struct = static_cast<ColumnStruct&>(*doris_column);
const DataTypeStruct* doris_struct_type =
reinterpret_cast<const DataTypeStruct*>(remove_nullable(data_type).get());
for (int i = 0; i < doris_struct.tuple_size(); ++i) {
ColumnPtr& struct_field = doris_struct.get_column_ptr(i);
DataTypePtr& field_type = const_cast<DataTypePtr&>(doris_struct_type->get_element(i));
RETURN_IF_ERROR(_fill_column(address, struct_field, field_type, num_rows));
}
return Status::OK();
}
void JniConnector::_generate_predicates(
std::unordered_map<std::string, ColumnValueRangeType>* colname_to_value_range) {
if (colname_to_value_range == nullptr) {
@ -414,7 +444,93 @@ void JniConnector::_generate_predicates(
}
}
std::string JniConnector::get_hive_type(const TypeDescriptor& desc) {
std::string JniConnector::get_jni_type(const DataTypePtr& data_type) {
DataTypePtr type = remove_nullable(data_type);
std::ostringstream buffer;
switch (type->get_type_as_primitive_type()) {
case TYPE_BOOLEAN:
return "boolean";
case TYPE_TINYINT:
return "tinyint";
case TYPE_SMALLINT:
return "smallint";
case TYPE_INT:
return "int";
case TYPE_BIGINT:
return "bigint";
case TYPE_LARGEINT:
return "largeint";
case TYPE_FLOAT:
return "float";
case TYPE_DOUBLE:
return "double";
case TYPE_VARCHAR:
[[fallthrough]];
case TYPE_CHAR:
[[fallthrough]];
case TYPE_STRING:
return "string";
case TYPE_DATE:
return "datev1";
case TYPE_DATEV2:
return "datev2";
case TYPE_DATETIME:
[[fallthrough]];
case TYPE_TIME:
return "datetimev1";
case TYPE_DATETIMEV2:
[[fallthrough]];
case TYPE_TIMEV2:
// can ignore precision of timestamp in jni
return "datetimev2";
case TYPE_BINARY:
return "binary";
case TYPE_DECIMALV2: {
buffer << "decimalv2(" << DecimalV2Value::PRECISION << "," << DecimalV2Value::SCALE << ")";
return buffer.str();
}
case TYPE_DECIMAL32: {
buffer << "decimal32(" << type->get_precision() << "," << type->get_scale() << ")";
return buffer.str();
}
case TYPE_DECIMAL64: {
buffer << "decimal64(" << type->get_precision() << "," << type->get_scale() << ")";
return buffer.str();
}
case TYPE_DECIMAL128I: {
buffer << "decimal128(" << type->get_precision() << "," << type->get_scale() << ")";
return buffer.str();
}
case TYPE_STRUCT: {
const DataTypeStruct* struct_type = reinterpret_cast<const DataTypeStruct*>(type.get());
buffer << "struct<";
for (int i = 0; i < struct_type->get_elements().size(); ++i) {
if (i != 0) {
buffer << ",";
}
buffer << struct_type->get_element_names()[i] << ":"
<< get_jni_type(struct_type->get_element(i));
}
buffer << ">";
return buffer.str();
}
case TYPE_ARRAY: {
const DataTypeArray* array_type = reinterpret_cast<const DataTypeArray*>(type.get());
buffer << "array<" << get_jni_type(array_type->get_nested_type()) << ">";
return buffer.str();
}
case TYPE_MAP: {
const DataTypeMap* map_type = reinterpret_cast<const DataTypeMap*>(type.get());
buffer << "map<" << get_jni_type(map_type->get_key_type()) << ","
<< get_jni_type(map_type->get_value_type()) << ">";
return buffer.str();
}
default:
return "unsupported";
}
}
std::string JniConnector::get_jni_type(const TypeDescriptor& desc) {
std::ostringstream buffer;
switch (desc.type) {
case TYPE_BOOLEAN:
@ -438,17 +554,18 @@ std::string JniConnector::get_hive_type(const TypeDescriptor& desc) {
return buffer.str();
}
case TYPE_DATE:
[[fallthrough]];
return "datev1";
case TYPE_DATEV2:
return "date";
return "datev2";
case TYPE_DATETIME:
[[fallthrough]];
case TYPE_TIME:
return "datetimev1";
case TYPE_DATETIMEV2:
[[fallthrough]];
case TYPE_TIME:
[[fallthrough]];
case TYPE_TIMEV2:
return "timestamp";
// can ignore precision of timestamp in jni
return "datetimev2";
case TYPE_BINARY:
return "binary";
case TYPE_CHAR: {
@ -479,18 +596,18 @@ std::string JniConnector::get_hive_type(const TypeDescriptor& desc) {
if (i != 0) {
buffer << ",";
}
buffer << desc.field_names[i] << ":" << get_hive_type(desc.children[i]);
buffer << desc.field_names[i] << ":" << get_jni_type(desc.children[i]);
}
buffer << ">";
return buffer.str();
}
case TYPE_ARRAY: {
buffer << "array<" << get_hive_type(desc.children[0]) << ">";
buffer << "array<" << get_jni_type(desc.children[0]) << ">";
return buffer.str();
}
case TYPE_MAP: {
buffer << "map<" << get_hive_type(desc.children[0]) << ","
<< get_hive_type(desc.children[1]) << ">";
buffer << "map<" << get_jni_type(desc.children[0]) << "," << get_jni_type(desc.children[1])
<< ">";
return buffer.str();
}
default:
@ -498,8 +615,8 @@ std::string JniConnector::get_hive_type(const TypeDescriptor& desc) {
}
}
void JniConnector::_fill_column_meta(ColumnPtr& doris_column, DataTypePtr& data_type,
std::vector<long>& meta_data) {
Status JniConnector::_fill_column_meta(ColumnPtr& doris_column, DataTypePtr& data_type,
std::vector<long>& meta_data) {
TypeIndex logical_type = remove_nullable(data_type)->get_type_id();
// insert null map address
MutableColumnPtr data_column;
@ -514,35 +631,13 @@ void JniConnector::_fill_column_meta(ColumnPtr& doris_column, DataTypePtr& data_
data_column = doris_column->assume_mutable();
}
switch (logical_type) {
#define DISPATCH(NUMERIC_TYPE, CPP_NUMERIC_TYPE) \
case NUMERIC_TYPE: { \
meta_data.emplace_back(_get_numeric_data_address<CPP_NUMERIC_TYPE>(data_column)); \
break; \
#define DISPATCH(TYPE_INDEX, COLUMN_TYPE, CPP_TYPE) \
case TYPE_INDEX: { \
meta_data.emplace_back(_get_fixed_length_column_address<COLUMN_TYPE>(data_column)); \
break; \
}
FOR_LOGICAL_NUMERIC_TYPES(DISPATCH)
FOR_FIXED_LENGTH_TYPES(DISPATCH)
#undef DISPATCH
case TypeIndex::Decimal128:
[[fallthrough]];
case TypeIndex::Decimal128I: {
meta_data.emplace_back(_get_decimal_data_address<Int128>(data_column));
break;
}
case TypeIndex::Decimal32: {
meta_data.emplace_back(_get_decimal_data_address<Int32>(data_column));
break;
}
case TypeIndex::Decimal64: {
meta_data.emplace_back(_get_decimal_data_address<Int64>(data_column));
break;
}
case TypeIndex::DateV2: {
meta_data.emplace_back(_get_time_data_address<UInt32>(data_column));
break;
}
case TypeIndex::DateTimeV2: {
meta_data.emplace_back(_get_time_data_address<UInt64>(data_column));
break;
}
case TypeIndex::String:
[[fallthrough]];
case TypeIndex::FixedString: {
@ -558,7 +653,7 @@ void JniConnector::_fill_column_meta(ColumnPtr& doris_column, DataTypePtr& data_
DataTypePtr& element_type = const_cast<DataTypePtr&>(
(reinterpret_cast<const DataTypeArray*>(remove_nullable(data_type).get()))
->get_nested_type());
_fill_column_meta(element_column, element_type, meta_data);
RETURN_IF_ERROR(_fill_column_meta(element_column, element_type, meta_data));
break;
}
case TypeIndex::Struct: {
@ -568,7 +663,7 @@ void JniConnector::_fill_column_meta(ColumnPtr& doris_column, DataTypePtr& data_
for (int i = 0; i < doris_struct.tuple_size(); ++i) {
ColumnPtr& struct_field = doris_struct.get_column_ptr(i);
DataTypePtr& field_type = const_cast<DataTypePtr&>(doris_struct_type->get_element(i));
_fill_column_meta(struct_field, field_type, meta_data);
RETURN_IF_ERROR(_fill_column_meta(struct_field, field_type, meta_data));
}
break;
}
@ -583,27 +678,81 @@ void JniConnector::_fill_column_meta(ColumnPtr& doris_column, DataTypePtr& data_
ColumnPtr& key_column = map.get_keys_ptr();
ColumnPtr& value_column = map.get_values_ptr();
meta_data.emplace_back((long)map.get_offsets().data());
_fill_column_meta(key_column, key_type, meta_data);
_fill_column_meta(value_column, value_type, meta_data);
RETURN_IF_ERROR(_fill_column_meta(key_column, key_type, meta_data));
RETURN_IF_ERROR(_fill_column_meta(value_column, value_type, meta_data));
break;
}
default:
return;
return Status::InternalError("Unsupported type: {}", getTypeName(logical_type));
}
return Status::OK();
}
Status JniConnector::generate_meta_info(Block* block, std::unique_ptr<long[]>& meta) {
Status JniConnector::to_java_table(Block* block, std::unique_ptr<long[]>& meta) {
ColumnNumbers arguments;
for (size_t i = 0; i < block->columns(); ++i) {
arguments.emplace_back(i);
}
return to_java_table(block, block->rows(), arguments, meta);
}
Status JniConnector::to_java_table(Block* block, size_t num_rows, const ColumnNumbers& arguments,
std::unique_ptr<long[]>& meta) {
std::vector<long> meta_data;
// insert number of rows
meta_data.emplace_back(block->rows());
for (int i = 0; i < block->columns(); ++i) {
meta_data.emplace_back(num_rows);
for (size_t i : arguments) {
if (is_column_const(*(block->get_by_position(i).column))) {
auto doris_column = block->get_by_position(i).column->convert_to_full_column_if_const();
bool is_nullable = block->get_by_position(i).type->is_nullable();
block->replace_by_position(i, is_nullable ? make_nullable(doris_column) : doris_column);
}
auto& column_with_type_and_name = block->get_by_position(i);
_fill_column_meta(column_with_type_and_name.column, column_with_type_and_name.type,
meta_data);
RETURN_IF_ERROR(_fill_column_meta(column_with_type_and_name.column,
column_with_type_and_name.type, meta_data));
}
meta.reset(new long[meta_data.size()]);
memcpy(meta.get(), &meta_data[0], meta_data.size() * 8);
return Status::OK();
}
std::pair<std::string, std::string> JniConnector::parse_table_schema(Block* block,
const ColumnNumbers& arguments,
bool ignore_column_name) {
// prepare table schema
std::ostringstream required_fields;
std::ostringstream columns_types;
for (int i = 0; i < arguments.size(); ++i) {
// column name maybe empty or has special characters
// std::string field = block->get_by_position(i).name;
std::string type = JniConnector::get_jni_type(block->get_by_position(arguments[i]).type);
if (i == 0) {
if (ignore_column_name) {
required_fields << "_col_" << arguments[i];
} else {
required_fields << block->get_by_position(arguments[i]).name;
}
columns_types << type;
} else {
if (ignore_column_name) {
required_fields << ","
<< "_col_" << arguments[i];
} else {
required_fields << "," << block->get_by_position(arguments[i]).name;
}
columns_types << "#" << type;
}
}
return std::make_pair(required_fields.str(), columns_types.str());
}
std::pair<std::string, std::string> JniConnector::parse_table_schema(Block* block) {
ColumnNumbers arguments;
for (size_t i = 0; i < block->columns(); ++i) {
arguments.emplace_back(i);
}
return parse_table_schema(block, arguments, true);
}
} // namespace doris::vectorized

View File

@ -60,6 +60,32 @@ namespace doris::vectorized {
*/
class JniConnector {
public:
class TableMetaAddress {
private:
long* _meta_ptr;
int _meta_index;
public:
TableMetaAddress() {
_meta_ptr = nullptr;
_meta_index = 0;
}
TableMetaAddress(long meta_addr) {
_meta_ptr = static_cast<long*>(reinterpret_cast<void*>(meta_addr));
_meta_index = 0;
}
void set_meta(long meta_addr) {
_meta_ptr = static_cast<long*>(reinterpret_cast<void*>(meta_addr));
_meta_index = 0;
}
long next_meta_as_long() { return _meta_ptr[_meta_index++]; }
void* next_meta_as_ptr() { return reinterpret_cast<void*>(_meta_ptr[_meta_index++]); }
};
/**
* The predicates that can be pushed down to java side.
* Reference to java class org.apache.doris.common.jni.vec.ScanPredicate
@ -220,11 +246,7 @@ public:
/**
* Call java side function JniScanner.getTableSchema.
*
* The schema information are stored as a string.
* Use # between column names and column types.
*
* like: col_name1,col_name2,col_name3#col_type1,col_type2.col_type3
*
* The schema information are stored as json format
*/
Status get_table_schema(std::string& table_schema_str);
@ -233,12 +255,25 @@ public:
*/
Status close();
static std::string get_jni_type(const DataTypePtr& data_type);
/**
* Map PrimitiveType to hive type.
*/
static std::string get_hive_type(const TypeDescriptor& desc);
static std::string get_jni_type(const TypeDescriptor& desc);
static Status generate_meta_info(Block* block, std::unique_ptr<long[]>& meta);
static Status to_java_table(Block* block, size_t num_rows, const ColumnNumbers& arguments,
std::unique_ptr<long[]>& meta);
static Status to_java_table(Block* block, std::unique_ptr<long[]>& meta);
static std::pair<std::string, std::string> parse_table_schema(Block* block,
const ColumnNumbers& arguments,
bool ignore_column_name = true);
static std::pair<std::string, std::string> parse_table_schema(Block* block);
static Status fill_block(Block* block, const ColumnNumbers& arguments, long table_address);
private:
std::string _connector_name;
@ -268,8 +303,7 @@ private:
jmethodID _jni_scanner_release_table;
jmethodID _jni_scanner_get_statistics;
long* _meta_ptr;
int _meta_index;
TableMetaAddress _table_meta;
int _predicates_length = 0;
std::unique_ptr<char[]> _predicates = nullptr;
@ -277,88 +311,45 @@ private:
/**
* Set the address of meta information, which is returned by org.apache.doris.common.jni.JniScanner#getNextBatchMeta
*/
void _set_meta(long meta_addr) {
_meta_ptr = static_cast<long*>(reinterpret_cast<void*>(meta_addr));
_meta_index = 0;
}
/**
* Get the number of rows in next batch.
*/
long _next_meta_as_long() { return _meta_ptr[_meta_index++]; }
/**
* Get the next column address
*/
void* _next_meta_as_ptr() { return reinterpret_cast<void*>(_meta_ptr[_meta_index++]); }
void _set_meta(long meta_addr) { _table_meta.set_meta(meta_addr); }
Status _init_jni_scanner(JNIEnv* env, int batch_size);
Status _fill_block(Block* block, size_t num_rows);
Status _fill_column(ColumnPtr& doris_column, DataTypePtr& data_type, size_t num_rows);
static Status _fill_column(TableMetaAddress& address, ColumnPtr& doris_column,
DataTypePtr& data_type, size_t num_rows);
Status _fill_map_column(MutableColumnPtr& doris_column, DataTypePtr& data_type,
size_t num_rows);
static Status _fill_string_column(TableMetaAddress& address, MutableColumnPtr& doris_column,
size_t num_rows);
Status _fill_array_column(MutableColumnPtr& doris_column, DataTypePtr& data_type,
size_t num_rows);
static Status _fill_map_column(TableMetaAddress& address, MutableColumnPtr& doris_column,
DataTypePtr& data_type, size_t num_rows);
Status _fill_struct_column(MutableColumnPtr& doris_column, DataTypePtr& data_type,
size_t num_rows);
static Status _fill_array_column(TableMetaAddress& address, MutableColumnPtr& doris_column,
DataTypePtr& data_type, size_t num_rows);
static void _fill_column_meta(ColumnPtr& doris_column, DataTypePtr& data_type,
std::vector<long>& meta_data);
static Status _fill_struct_column(TableMetaAddress& address, MutableColumnPtr& doris_column,
DataTypePtr& data_type, size_t num_rows);
template <typename CppType>
Status _fill_numeric_column(MutableColumnPtr& doris_column, CppType* ptr, size_t num_rows) {
auto& column_data = static_cast<ColumnVector<CppType>&>(*doris_column).get_data();
static Status _fill_column_meta(ColumnPtr& doris_column, DataTypePtr& data_type,
std::vector<long>& meta_data);
template <typename COLUMN_TYPE, typename CPP_TYPE>
static Status _fill_fixed_length_column(MutableColumnPtr& doris_column, CPP_TYPE* ptr,
size_t num_rows) {
auto& column_data = static_cast<COLUMN_TYPE&>(*doris_column).get_data();
size_t origin_size = column_data.size();
column_data.resize(origin_size + num_rows);
memcpy(column_data.data() + origin_size, ptr, sizeof(CppType) * num_rows);
memcpy(column_data.data() + origin_size, ptr, sizeof(CPP_TYPE) * num_rows);
return Status::OK();
}
template <typename CppType>
static long _get_numeric_data_address(MutableColumnPtr& doris_column) {
return (long)static_cast<ColumnVector<CppType>&>(*doris_column).get_data().data();
template <typename COLUMN_TYPE>
static long _get_fixed_length_column_address(MutableColumnPtr& doris_column) {
return (long)static_cast<COLUMN_TYPE&>(*doris_column).get_data().data();
}
template <typename DecimalPrimitiveType>
Status _fill_decimal_column(MutableColumnPtr& doris_column, DecimalPrimitiveType* ptr,
size_t num_rows) {
auto& column_data =
static_cast<ColumnDecimal<Decimal<DecimalPrimitiveType>>&>(*doris_column)
.get_data();
size_t origin_size = column_data.size();
column_data.resize(origin_size + num_rows);
memcpy(column_data.data() + origin_size, ptr, sizeof(DecimalPrimitiveType) * num_rows);
return Status::OK();
}
template <typename DecimalPrimitiveType>
static long _get_decimal_data_address(MutableColumnPtr& doris_column) {
return (long)static_cast<ColumnDecimal<Decimal<DecimalPrimitiveType>>&>(*doris_column)
.get_data()
.data();
}
template <typename CppType>
Status _decode_time_column(MutableColumnPtr& doris_column, CppType* ptr, size_t num_rows) {
auto& column_data = static_cast<ColumnVector<CppType>&>(*doris_column).get_data();
size_t origin_size = column_data.size();
column_data.resize(origin_size + num_rows);
memcpy(column_data.data() + origin_size, ptr, sizeof(CppType) * num_rows);
return Status::OK();
}
template <typename CppType>
static long _get_time_data_address(MutableColumnPtr& doris_column) {
return (long)static_cast<ColumnVector<CppType>&>(*doris_column).get_data().data();
}
Status _fill_string_column(MutableColumnPtr& doris_column, size_t num_rows);
void _generate_predicates(
std::unordered_map<std::string, ColumnValueRangeType>* colname_to_value_range);

View File

@ -995,48 +995,17 @@ Status JdbcConnector::exec_stmt_write(Block* block, const VExprContextSPtrs& out
JNIEnv* env = nullptr;
RETURN_IF_ERROR(JniUtil::GetJNIEnv(&env));
// prepare table schema
std::ostringstream required_fields;
std::ostringstream columns_types;
for (int i = 0; i < block->columns(); ++i) {
// column name maybe empty or has special characters
// std::string field = block->get_by_position(i).name;
std::string type = JniConnector::get_hive_type(output_vexpr_ctxs[i]->root()->type());
if (i == 0) {
required_fields << "_col" << i;
columns_types << type;
} else {
required_fields << ","
<< "_col" << i;
columns_types << "#" << type;
}
}
// prepare table meta information
std::unique_ptr<long[]> meta_data;
RETURN_IF_ERROR(JniConnector::generate_meta_info(block, meta_data));
RETURN_IF_ERROR(JniConnector::to_java_table(block, meta_data));
long meta_address = (long)meta_data.get();
auto table_schema = JniConnector::parse_table_schema(block);
// prepare constructor parameters
std::map<String, String> write_params = {{"meta_address", std::to_string(meta_address)},
{"required_fields", required_fields.str()},
{"columns_types", columns_types.str()},
{"write_sql", "/* todo */"}};
jclass hashmap_class = env->FindClass("java/util/HashMap");
jmethodID hashmap_constructor = env->GetMethodID(hashmap_class, "<init>", "(I)V");
jobject hashmap_object =
env->NewObject(hashmap_class, hashmap_constructor, write_params.size());
jmethodID hashmap_put = env->GetMethodID(
hashmap_class, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
RETURN_ERROR_IF_EXC(env);
for (const auto& it : write_params) {
jstring key = env->NewStringUTF(it.first.c_str());
jstring value = env->NewStringUTF(it.second.c_str());
env->CallObjectMethod(hashmap_object, hashmap_put, key, value);
env->DeleteLocalRef(key);
env->DeleteLocalRef(value);
}
env->DeleteLocalRef(hashmap_class);
{"required_fields", table_schema.first},
{"columns_types", table_schema.second}};
jobject hashmap_object = JniUtil::convert_to_java_map(env, write_params);
env->CallNonvirtualIntMethod(_executor_obj, _executor_clazz, _executor_stmt_write_id,
hashmap_object);
env->DeleteLocalRef(hashmap_object);

View File

@ -42,10 +42,11 @@
#include "vec/core/block.h"
#include "vec/data_types/data_type_array.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/exec/jni_connector.h"
const char* EXECUTOR_CLASS = "org/apache/doris/udf/UdfExecutor";
const char* EXECUTOR_CTOR_SIGNATURE = "([B)V";
const char* EXECUTOR_EVALUATE_SIGNATURE = "()V";
const char* EXECUTOR_EVALUATE_SIGNATURE = "(Ljava/util/Map;Ljava/util/Map;)J";
const char* EXECUTOR_CLOSE_SIGNATURE = "()V";
namespace doris::vectorized {
@ -65,21 +66,8 @@ Status JavaFunctionCall::open(FunctionContext* context, FunctionContext::Functio
jni_env->executor_ctor_id =
env->GetMethodID(jni_env->executor_cl, "<init>", EXECUTOR_CTOR_SIGNATURE);
RETURN_ERROR_IF_EXC(env);
jni_env->executor_evaluate_id = env->GetMethodID(
jni_env->executor_cl, "evaluate", "(I[Ljava/lang/Object;)[Ljava/lang/Object;");
jni_env->executor_convert_basic_argument_id = env->GetMethodID(
jni_env->executor_cl, "convertBasicArguments", "(IZIJJJ)[Ljava/lang/Object;");
jni_env->executor_convert_array_argument_id = env->GetMethodID(
jni_env->executor_cl, "convertArrayArguments", "(IZIJJJJJ)[Ljava/lang/Object;");
jni_env->executor_convert_map_argument_id = env->GetMethodID(
jni_env->executor_cl, "convertMapArguments", "(IZIJJJJJJJJ)[Ljava/lang/Object;");
jni_env->executor_result_basic_batch_id = env->GetMethodID(
jni_env->executor_cl, "copyBatchBasicResult", "(ZI[Ljava/lang/Object;JJJ)V");
jni_env->executor_result_array_batch_id = env->GetMethodID(
jni_env->executor_cl, "copyBatchArrayResult", "(ZI[Ljava/lang/Object;JJJJJ)V");
jni_env->executor_result_map_batch_id = env->GetMethodID(
jni_env->executor_cl, "copyBatchMapResult", "(ZI[Ljava/lang/Object;JJJJJJJJ)V");
jni_env->executor_evaluate_id =
env->GetMethodID(jni_env->executor_cl, "evaluate", EXECUTOR_EVALUATE_SIGNATURE);
jni_env->executor_close_id =
env->GetMethodID(jni_env->executor_cl, "close", EXECUTOR_CLOSE_SIGNATURE);
RETURN_ERROR_IF_EXC(env);
@ -132,288 +120,29 @@ Status JavaFunctionCall::execute_impl(FunctionContext* context, Block& block,
context->get_function_state(FunctionContext::THREAD_LOCAL));
JniEnv* jni_env =
reinterpret_cast<JniEnv*>(context->get_function_state(FunctionContext::FRAGMENT_LOCAL));
int arg_size = arguments.size();
ColumnPtr data_cols[arg_size];
ColumnPtr null_cols[arg_size];
jclass obj_class = env->FindClass("[Ljava/lang/Object;");
jclass arraylist_class = env->FindClass("Ljava/util/ArrayList;");
jclass hashmap_class = env->FindClass("Ljava/util/HashMap;");
jobjectArray arg_objects = env->NewObjectArray(arg_size, obj_class, nullptr);
int64_t nullmap_address = 0;
for (size_t arg_idx = 0; arg_idx < arg_size; ++arg_idx) {
bool arg_column_nullable = false;
// get argument column and type
ColumnWithTypeAndName& column = block.get_by_position(arguments[arg_idx]);
auto column_type = column.type;
data_cols[arg_idx] = column.column->convert_to_full_column_if_const();
// check type
DCHECK(_argument_types[arg_idx]->equals(*column_type))
<< " input column's type is " + column_type->get_name()
<< " does not equal to required type " << _argument_types[arg_idx]->get_name();
// get argument null map and nested column
if (auto* nullable = check_and_get_column<const ColumnNullable>(*data_cols[arg_idx])) {
arg_column_nullable = true;
column_type = remove_nullable(column_type);
null_cols[arg_idx] = nullable->get_null_map_column_ptr();
data_cols[arg_idx] = nullable->get_nested_column_ptr();
nullmap_address = reinterpret_cast<int64_t>(
check_and_get_column<ColumnVector<UInt8>>(null_cols[arg_idx])
->get_data()
.data());
}
// convert argument column data into java type
jobjectArray arr_obj = nullptr;
if (data_cols[arg_idx]->is_numeric() || data_cols[arg_idx]->is_column_decimal()) {
arr_obj = (jobjectArray)env->CallNonvirtualObjectMethod(
jni_ctx->executor, jni_env->executor_cl,
jni_env->executor_convert_basic_argument_id, arg_idx, arg_column_nullable,
num_rows, nullmap_address,
reinterpret_cast<int64_t>(data_cols[arg_idx]->get_raw_data().data), 0);
} else if (data_cols[arg_idx]->is_column_string()) {
const ColumnString* str_col =
assert_cast<const ColumnString*>(data_cols[arg_idx].get());
arr_obj = (jobjectArray)env->CallNonvirtualObjectMethod(
jni_ctx->executor, jni_env->executor_cl,
jni_env->executor_convert_basic_argument_id, arg_idx, arg_column_nullable,
num_rows, nullmap_address,
reinterpret_cast<int64_t>(str_col->get_chars().data()),
reinterpret_cast<int64_t>(str_col->get_offsets().data()));
} else if (data_cols[arg_idx]->is_column_array()) {
const ColumnArray* array_col =
assert_cast<const ColumnArray*>(data_cols[arg_idx].get());
const ColumnNullable& array_nested_nullable =
assert_cast<const ColumnNullable&>(array_col->get_data());
auto data_column_null_map = array_nested_nullable.get_null_map_column_ptr();
auto data_column = array_nested_nullable.get_nested_column_ptr();
auto offset_address =
reinterpret_cast<int64_t>(array_col->get_offsets_column().get_raw_data().data);
auto nested_nullmap_address = reinterpret_cast<int64_t>(
check_and_get_column<ColumnVector<UInt8>>(data_column_null_map)
->get_data()
.data());
int64_t nested_data_address = 0, nested_offset_address = 0;
// array type need pass address: [nullmap_address], offset_address, nested_nullmap_address, nested_data_address/nested_char_address,nested_offset_address
if (data_column->is_column_string()) {
const ColumnString* col = assert_cast<const ColumnString*>(data_column.get());
nested_data_address = reinterpret_cast<int64_t>(col->get_chars().data());
nested_offset_address = reinterpret_cast<int64_t>(col->get_offsets().data());
} else {
nested_data_address = reinterpret_cast<int64_t>(data_column->get_raw_data().data);
}
arr_obj = (jobjectArray)env->CallNonvirtualObjectMethod(
jni_ctx->executor, jni_env->executor_cl,
jni_env->executor_convert_array_argument_id, arg_idx, arg_column_nullable,
num_rows, nullmap_address, offset_address, nested_nullmap_address,
nested_data_address, nested_offset_address);
} else if (data_cols[arg_idx]->is_column_map()) {
const ColumnMap* map_col = assert_cast<const ColumnMap*>(data_cols[arg_idx].get());
auto offset_address =
reinterpret_cast<int64_t>(map_col->get_offsets_column().get_raw_data().data);
const ColumnNullable& map_key_column_nullable =
assert_cast<const ColumnNullable&>(map_col->get_keys());
auto key_data_column_null_map = map_key_column_nullable.get_null_map_column_ptr();
auto key_data_column = map_key_column_nullable.get_nested_column_ptr();
auto key_nested_nullmap_address = reinterpret_cast<int64_t>(
check_and_get_column<ColumnVector<UInt8>>(key_data_column_null_map)
->get_data()
.data());
int64_t key_nested_data_address = 0, key_nested_offset_address = 0;
if (key_data_column->is_column_string()) {
const ColumnString* col = assert_cast<const ColumnString*>(key_data_column.get());
key_nested_data_address = reinterpret_cast<int64_t>(col->get_chars().data());
key_nested_offset_address = reinterpret_cast<int64_t>(col->get_offsets().data());
} else {
key_nested_data_address =
reinterpret_cast<int64_t>(key_data_column->get_raw_data().data);
}
const ColumnNullable& map_value_column_nullable =
assert_cast<const ColumnNullable&>(map_col->get_values());
auto value_data_column_null_map = map_value_column_nullable.get_null_map_column_ptr();
auto value_data_column = map_value_column_nullable.get_nested_column_ptr();
auto value_nested_nullmap_address = reinterpret_cast<int64_t>(
check_and_get_column<ColumnVector<UInt8>>(value_data_column_null_map)
->get_data()
.data());
int64_t value_nested_data_address = 0, value_nested_offset_address = 0;
if (value_data_column->is_column_string()) {
const ColumnString* col = assert_cast<const ColumnString*>(value_data_column.get());
value_nested_data_address = reinterpret_cast<int64_t>(col->get_chars().data());
value_nested_offset_address = reinterpret_cast<int64_t>(col->get_offsets().data());
} else {
value_nested_data_address =
reinterpret_cast<int64_t>(value_data_column->get_raw_data().data);
}
arr_obj = (jobjectArray)env->CallNonvirtualObjectMethod(
jni_ctx->executor, jni_env->executor_cl,
jni_env->executor_convert_map_argument_id, arg_idx, arg_column_nullable,
num_rows, nullmap_address, offset_address, key_nested_nullmap_address,
key_nested_data_address, key_nested_offset_address,
value_nested_nullmap_address, value_nested_data_address,
value_nested_offset_address);
} else {
return Status::InvalidArgument(
strings::Substitute("Java UDF doesn't support type $0 now !",
_argument_types[arg_idx]->get_name()));
}
env->SetObjectArrayElement(arg_objects, arg_idx, arr_obj);
env->DeleteLocalRef(arr_obj);
}
std::unique_ptr<long[]> input_table;
RETURN_IF_ERROR(JniConnector::to_java_table(&block, num_rows, arguments, input_table));
auto input_table_schema = JniConnector::parse_table_schema(&block, arguments, true);
std::map<String, String> input_params = {
{"meta_address", std::to_string((long)input_table.get())},
{"required_fields", input_table_schema.first},
{"columns_types", input_table_schema.second}};
jobject input_map = JniUtil::convert_to_java_map(env, input_params);
auto output_table_schema = JniConnector::parse_table_schema(&block, {result}, true);
std::string output_nullable =
block.get_by_position(result).type->is_nullable() ? "true" : "false";
std::map<String, String> output_params = {{"is_nullable", output_nullable},
{"required_fields", output_table_schema.first},
{"columns_types", output_table_schema.second}};
jobject output_map = JniUtil::convert_to_java_map(env, output_params);
long output_address = env->CallLongMethod(jni_ctx->executor, jni_env->executor_evaluate_id,
input_map, output_map);
env->DeleteLocalRef(input_map);
env->DeleteLocalRef(output_map);
RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env));
// evaluate with argument object
jobjectArray result_obj = (jobjectArray)env->CallNonvirtualObjectMethod(
jni_ctx->executor, jni_env->executor_cl, jni_env->executor_evaluate_id, num_rows,
arg_objects);
env->DeleteLocalRef(arg_objects);
RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env));
auto return_type = block.get_data_type(result);
bool result_nullable = return_type->is_nullable();
ColumnUInt8::MutablePtr null_col = nullptr;
if (result_nullable) {
return_type = remove_nullable(return_type);
null_col = ColumnUInt8::create(num_rows, 0);
memset(null_col->get_data().data(), 0, num_rows);
nullmap_address = reinterpret_cast<int64_t>(null_col->get_data().data());
}
auto res_col = return_type->create_column();
res_col->resize(num_rows);
//could resize for column firstly, copy batch result into column
if (res_col->is_numeric() || res_col->is_column_decimal()) {
env->CallNonvirtualVoidMethod(jni_ctx->executor, jni_env->executor_cl,
jni_env->executor_result_basic_batch_id, result_nullable,
num_rows, result_obj, nullmap_address,
reinterpret_cast<int64_t>(res_col->get_raw_data().data), 0);
} else if (res_col->is_column_string()) {
const ColumnString* str_col = assert_cast<const ColumnString*>(res_col.get());
ColumnString::Chars& chars = const_cast<ColumnString::Chars&>(str_col->get_chars());
ColumnString::Offsets& offsets = const_cast<ColumnString::Offsets&>(str_col->get_offsets());
env->CallNonvirtualVoidMethod(
jni_ctx->executor, jni_env->executor_cl, jni_env->executor_result_basic_batch_id,
result_nullable, num_rows, result_obj, nullmap_address,
reinterpret_cast<int64_t>(&chars), reinterpret_cast<int64_t>(offsets.data()));
} else if (res_col->is_column_array()) {
ColumnArray* array_col = assert_cast<ColumnArray*>(res_col.get());
ColumnNullable& array_nested_nullable = assert_cast<ColumnNullable&>(array_col->get_data());
auto data_column_null_map = array_nested_nullable.get_null_map_column_ptr();
auto data_column = array_nested_nullable.get_nested_column_ptr();
auto& offset_column = array_col->get_offsets_column();
auto offset_address = reinterpret_cast<int64_t>(offset_column.get_raw_data().data);
auto& null_map_data =
assert_cast<ColumnVector<UInt8>*>(data_column_null_map.get())->get_data();
auto nested_nullmap_address = reinterpret_cast<int64_t>(null_map_data.data());
jmethodID list_size = env->GetMethodID(arraylist_class, "size", "()I");
int element_size = 0; // get all element size in num_rows of array column
for (int i = 0; i < num_rows; ++i) {
jobject obj = env->GetObjectArrayElement(result_obj, i);
if (obj == nullptr) {
continue;
}
element_size = element_size + env->CallIntMethod(obj, list_size);
env->DeleteLocalRef(obj);
}
array_nested_nullable.resize(element_size);
memset(null_map_data.data(), 0, element_size);
int64_t nested_data_address = 0, nested_offset_address = 0;
// array type need pass address: [nullmap_address], offset_address, nested_nullmap_address, nested_data_address/nested_char_address,nested_offset_address
if (data_column->is_column_string()) {
ColumnString* str_col = assert_cast<ColumnString*>(data_column.get());
ColumnString::Chars& chars = assert_cast<ColumnString::Chars&>(str_col->get_chars());
ColumnString::Offsets& offsets =
assert_cast<ColumnString::Offsets&>(str_col->get_offsets());
nested_data_address = reinterpret_cast<int64_t>(&chars);
nested_offset_address = reinterpret_cast<int64_t>(offsets.data());
} else {
nested_data_address = reinterpret_cast<int64_t>(data_column->get_raw_data().data);
}
env->CallNonvirtualVoidMethod(
jni_ctx->executor, jni_env->executor_cl, jni_env->executor_result_array_batch_id,
result_nullable, num_rows, result_obj, nullmap_address, offset_address,
nested_nullmap_address, nested_data_address, nested_offset_address);
} else if (res_col->is_column_map()) {
ColumnMap* map_col = assert_cast<ColumnMap*>(res_col.get());
auto& offset_column = map_col->get_offsets_column();
auto offset_address = reinterpret_cast<int64_t>(offset_column.get_raw_data().data);
ColumnNullable& map_key_column_nullable = assert_cast<ColumnNullable&>(map_col->get_keys());
auto key_data_column_null_map = map_key_column_nullable.get_null_map_column_ptr();
auto key_data_column = map_key_column_nullable.get_nested_column_ptr();
auto& key_null_map_data =
assert_cast<ColumnVector<UInt8>*>(key_data_column_null_map.get())->get_data();
auto key_nested_nullmap_address = reinterpret_cast<int64_t>(key_null_map_data.data());
ColumnNullable& map_value_column_nullable =
assert_cast<ColumnNullable&>(map_col->get_values());
auto value_data_column_null_map = map_value_column_nullable.get_null_map_column_ptr();
auto value_data_column = map_value_column_nullable.get_nested_column_ptr();
auto& value_null_map_data =
assert_cast<ColumnVector<UInt8>*>(value_data_column_null_map.get())->get_data();
auto value_nested_nullmap_address = reinterpret_cast<int64_t>(value_null_map_data.data());
jmethodID map_size = env->GetMethodID(hashmap_class, "size", "()I");
int element_size = 0; // get all element size in num_rows of map column
for (int i = 0; i < num_rows; ++i) {
jobject obj = env->GetObjectArrayElement(result_obj, i);
if (obj == nullptr) {
continue;
}
element_size = element_size + env->CallIntMethod(obj, map_size);
env->DeleteLocalRef(obj);
}
map_key_column_nullable.resize(element_size);
memset(key_null_map_data.data(), 0, element_size);
map_value_column_nullable.resize(element_size);
memset(value_null_map_data.data(), 0, element_size);
int64_t key_nested_data_address = 0, key_nested_offset_address = 0;
if (key_data_column->is_column_string()) {
ColumnString* str_col = assert_cast<ColumnString*>(key_data_column.get());
ColumnString::Chars& chars = assert_cast<ColumnString::Chars&>(str_col->get_chars());
ColumnString::Offsets& offsets =
assert_cast<ColumnString::Offsets&>(str_col->get_offsets());
key_nested_data_address = reinterpret_cast<int64_t>(&chars);
key_nested_offset_address = reinterpret_cast<int64_t>(offsets.data());
} else {
key_nested_data_address =
reinterpret_cast<int64_t>(key_data_column->get_raw_data().data);
}
int64_t value_nested_data_address = 0, value_nested_offset_address = 0;
if (value_data_column->is_column_string()) {
ColumnString* str_col = assert_cast<ColumnString*>(value_data_column.get());
ColumnString::Chars& chars = assert_cast<ColumnString::Chars&>(str_col->get_chars());
ColumnString::Offsets& offsets =
assert_cast<ColumnString::Offsets&>(str_col->get_offsets());
value_nested_data_address = reinterpret_cast<int64_t>(&chars);
value_nested_offset_address = reinterpret_cast<int64_t>(offsets.data());
} else {
value_nested_data_address =
reinterpret_cast<int64_t>(value_data_column->get_raw_data().data);
}
env->CallNonvirtualVoidMethod(jni_ctx->executor, jni_env->executor_cl,
jni_env->executor_result_map_batch_id, result_nullable,
num_rows, result_obj, nullmap_address, offset_address,
key_nested_nullmap_address, key_nested_data_address,
key_nested_offset_address, value_nested_nullmap_address,
value_nested_data_address, value_nested_offset_address);
} else {
return Status::InvalidArgument(strings::Substitute(
"Java UDF doesn't support return type $0 now !", return_type->get_name()));
}
env->DeleteLocalRef(result_obj);
env->DeleteLocalRef(obj_class);
env->DeleteLocalRef(arraylist_class);
env->DeleteLocalRef(hashmap_class);
if (result_nullable) {
block.replace_by_position(result,
ColumnNullable::create(std::move(res_col), std::move(null_col)));
} else {
block.replace_by_position(result, std::move(res_col));
}
return JniUtil::GetJniExceptionMsg(env);
return JniConnector::fill_block(&block, {result}, output_address);
}
Status JavaFunctionCall::close(FunctionContext* context,

View File

@ -114,25 +114,12 @@ private:
const DataTypes _argument_types;
const DataTypePtr _return_type;
struct IntermediateState {
size_t buffer_size;
size_t row_idx;
IntermediateState() : buffer_size(0), row_idx(0) {}
};
struct JniEnv {
/// Global class reference to the UdfExecutor Java class and related method IDs. Set in
/// Init(). These have the lifetime of the process (i.e. 'executor_cl_' is never freed).
jclass executor_cl;
jmethodID executor_ctor_id;
jmethodID executor_evaluate_id;
jmethodID executor_convert_basic_argument_id;
jmethodID executor_convert_array_argument_id;
jmethodID executor_convert_map_argument_id;
jmethodID executor_result_basic_batch_id;
jmethodID executor_result_array_batch_id;
jmethodID executor_result_map_batch_id;
jmethodID executor_close_id;
};

View File

@ -81,7 +81,7 @@ public abstract class JniScanner {
public long getNextBatchMeta() throws IOException {
if (vectorTable == null) {
vectorTable = new VectorTable(types, fields, predicates, batchSize);
vectorTable = VectorTable.createWritableTable(types, fields, batchSize);
}
int numRows;
try {
@ -107,7 +107,7 @@ public abstract class JniScanner {
}
private long getMetaAddress(int numRows) {
vectorTable.setNumRows(numRows);
assert (numRows == vectorTable.getNumRows());
return vectorTable.getMetaAddress();
}

View File

@ -34,7 +34,7 @@ public class OffHeap {
private static boolean IS_TESTING = false;
private static final Unsafe UNSAFE;
public static final Unsafe UNSAFE;
public static final int BOOLEAN_ARRAY_OFFSET;
@ -78,6 +78,12 @@ public class OffHeap {
return UNSAFE.getInt(object, offset);
}
public static int[] getInt(Object object, long offset, int length) {
int[] result = new int[length];
UNSAFE.copyMemory(object, offset, result, INT_ARRAY_OFFSET, (long) length * Integer.BYTES);
return result;
}
public static void putInt(Object object, long offset, int value) {
UNSAFE.putInt(object, offset, value);
}
@ -86,6 +92,12 @@ public class OffHeap {
return UNSAFE.getBoolean(object, offset);
}
public static boolean[] getBoolean(Object object, long offset, int length) {
boolean[] result = new boolean[length];
UNSAFE.copyMemory(object, offset, result, BOOLEAN_ARRAY_OFFSET, length);
return result;
}
public static void putBoolean(Object object, long offset, boolean value) {
UNSAFE.putBoolean(object, offset, value);
}
@ -94,6 +106,12 @@ public class OffHeap {
return UNSAFE.getByte(object, offset);
}
public static byte[] getByte(Object object, long offset, int length) {
byte[] result = new byte[length];
UNSAFE.copyMemory(object, offset, result, BYTE_ARRAY_OFFSET, length * Byte.BYTES);
return result;
}
public static void putByte(Object object, long offset, byte value) {
UNSAFE.putByte(object, offset, value);
}
@ -102,6 +120,12 @@ public class OffHeap {
return UNSAFE.getShort(object, offset);
}
public static short[] getShort(Object object, long offset, int length) {
short[] result = new short[length];
UNSAFE.copyMemory(object, offset, result, SHORT_ARRAY_OFFSET, (long) length * Short.BYTES);
return result;
}
public static void putShort(Object object, long offset, short value) {
UNSAFE.putShort(object, offset, value);
}
@ -110,6 +134,12 @@ public class OffHeap {
return UNSAFE.getLong(object, offset);
}
public static long[] getLong(Object object, long offset, int length) {
long[] result = new long[length];
UNSAFE.copyMemory(object, offset, result, LONG_ARRAY_OFFSET, (long) length * Long.BYTES);
return result;
}
public static void putLong(Object object, long offset, long value) {
UNSAFE.putLong(object, offset, value);
}
@ -118,6 +148,12 @@ public class OffHeap {
return UNSAFE.getFloat(object, offset);
}
public static float[] getFloat(Object object, long offset, int length) {
float[] result = new float[length];
UNSAFE.copyMemory(object, offset, result, FLOAT_ARRAY_OFFSET, (long) length * Float.BYTES);
return result;
}
public static void putFloat(Object object, long offset, float value) {
UNSAFE.putFloat(object, offset, value);
}
@ -126,6 +162,12 @@ public class OffHeap {
return UNSAFE.getDouble(object, offset);
}
public static double[] getDouble(Object object, long offset, int length) {
double[] result = new double[length];
UNSAFE.copyMemory(object, offset, result, DOUBLE_ARRAY_OFFSET, (long) length * Double.BYTES);
return result;
}
public static void putDouble(Object object, long offset, double value) {
UNSAFE.putDouble(object, offset, value);
}

View File

@ -40,25 +40,21 @@ public class TypeNativeBytes {
}
public static byte[] getBigIntegerBytes(BigInteger v) {
byte[] bytes = v.toByteArray();
// If the BigInteger is not negative and the first byte is 0, remove the first byte
if (v.signum() >= 0 && bytes[0] == 0) {
bytes = Arrays.copyOfRange(bytes, 1, bytes.length);
byte[] bytes = convertByteOrder(v.toByteArray());
// here value is 16 bytes, so if result data greater than the maximum of 16
// bytes, it will return a wrong num to backend;
byte[] value = new byte[16];
// check data is negative
if (v.signum() == -1) {
Arrays.fill(value, (byte) -1);
}
// Convert the byte order if necessary
return convertByteOrder(bytes);
System.arraycopy(bytes, 0, value, 0, Math.min(bytes.length, value.length));
return value;
}
public static BigInteger getBigInteger(byte[] bytes) {
// Convert the byte order back if necessary
byte[] originalBytes = convertByteOrder(bytes);
// If the first byte has the sign bit set, add a 0 byte at the start
if ((originalBytes[0] & 0x80) != 0) {
byte[] extendedBytes = new byte[originalBytes.length + 1];
extendedBytes[0] = 0;
System.arraycopy(originalBytes, 0, extendedBytes, 1, originalBytes.length);
originalBytes = extendedBytes;
}
return new BigInteger(originalBytes);
}
@ -80,6 +76,22 @@ public class TypeNativeBytes {
return new BigDecimal(value, scale);
}
public static long convertToDateTime(int year, int month, int day, int hour, int minute, int second,
boolean isDate) {
long time = 0;
time = time + year;
time = (time << 8) + month;
time = (time << 8) + day;
time = (time << 8) + hour;
time = (time << 8) + minute;
time = (time << 12) + second;
int type = isDate ? 2 : 3;
time = (time << 3) + type;
//this bit is int neg = 0;
time = (time << 1);
return time;
}
public static int convertToDateV2(int year, int month, int day) {
return (int) (day | (long) month << 5 | (long) year << 9);
}
@ -95,20 +107,128 @@ public class TypeNativeBytes {
| (long) day << 37 | (long) month << 42 | (long) year << 46;
}
public static LocalDate convertToJavaDate(int date) {
public static LocalDate convertToJavaDateV1(long date) {
int year = (int) (date >> 48);
int yearMonth = (int) (date >> 40);
int yearMonthDay = (int) (date >> 32);
int month = (yearMonth & 0XFF);
int day = (yearMonthDay & 0XFF);
try {
return LocalDate.of(year, month, day);
} catch (DateTimeException e) {
return null;
}
}
public static Object convertToJavaDateV1(long date, Class clz) {
int year = (int) (date >> 48);
int yearMonth = (int) (date >> 40);
int yearMonthDay = (int) (date >> 32);
int month = (yearMonth & 0XFF);
int day = (yearMonthDay & 0XFF);
try {
if (LocalDate.class.equals(clz)) {
return LocalDate.of(year, month, day);
} else if (java.util.Date.class.equals(clz)) {
return new java.util.Date(year - 1900, month - 1, day);
} else if (org.joda.time.LocalDate.class.equals(clz)) {
return new org.joda.time.LocalDate(year, month, day);
} else {
return null;
}
} catch (Exception e) {
return null;
}
}
public static LocalDate convertToJavaDateV2(int date) {
int year = date >> 9;
int month = (date >> 5) & 0XF;
int day = date & 0X1F;
LocalDate value;
try {
value = LocalDate.of(year, month, day);
return LocalDate.of(year, month, day);
} catch (DateTimeException e) {
value = LocalDate.MAX;
return null;
}
return value;
}
public static LocalDateTime convertToJavaDateTime(long time) {
public static Object convertToJavaDateV2(int date, Class clz) {
int year = date >> 9;
int month = (date >> 5) & 0XF;
int day = date & 0X1F;
try {
if (LocalDate.class.equals(clz)) {
return LocalDate.of(year, month, day);
} else if (java.util.Date.class.equals(clz)) {
return new java.util.Date(year - 1900, month - 1, day);
} else if (org.joda.time.LocalDate.class.equals(clz)) {
return new org.joda.time.LocalDate(year, month, day);
} else {
return null;
}
} catch (Exception e) {
return null;
}
}
public static LocalDateTime convertToJavaDateTimeV1(long time) {
int year = (int) (time >> 48);
int yearMonth = (int) (time >> 40);
int yearMonthDay = (int) (time >> 32);
int month = (yearMonth & 0XFF);
int day = (yearMonthDay & 0XFF);
int hourMinuteSecond = (int) (time % (1 << 31));
int minuteTypeNeg = (hourMinuteSecond % (1 << 16));
int hour = (hourMinuteSecond >> 24);
int minute = ((hourMinuteSecond >> 16) & 0XFF);
int second = (minuteTypeNeg >> 4);
//here don't need those bits are type = ((minus_type_neg >> 1) & 0x7);
try {
return LocalDateTime.of(year, month, day, hour, minute, second);
} catch (DateTimeException e) {
return null;
}
}
public static Object convertToJavaDateTimeV1(long time, Class clz) {
int year = (int) (time >> 48);
int yearMonth = (int) (time >> 40);
int yearMonthDay = (int) (time >> 32);
int month = (yearMonth & 0XFF);
int day = (yearMonthDay & 0XFF);
int hourMinuteSecond = (int) (time % (1 << 31));
int minuteTypeNeg = (hourMinuteSecond % (1 << 16));
int hour = (hourMinuteSecond >> 24);
int minute = ((hourMinuteSecond >> 16) & 0XFF);
int second = (minuteTypeNeg >> 4);
//here don't need those bits are type = ((minus_type_neg >> 1) & 0x7);
try {
if (LocalDateTime.class.equals(clz)) {
return LocalDateTime.of(year, month, day, hour, minute, second);
} else if (org.joda.time.DateTime.class.equals(clz)) {
return new org.joda.time.DateTime(year, month, day, hour, minute, second);
} else if (org.joda.time.LocalDateTime.class.equals(clz)) {
return new org.joda.time.LocalDateTime(year, month, day, hour, minute, second);
} else {
return null;
}
} catch (Exception e) {
return null;
}
}
public static LocalDateTime convertToJavaDateTimeV2(long time) {
int year = (int) (time >> 46);
int yearMonth = (int) (time >> 42);
int yearMonthDay = (int) (time >> 37);
@ -121,12 +241,38 @@ public class TypeNativeBytes {
int second = (int) ((time >> 20) & 0X3F);
int microsecond = (int) (time & 0XFFFFF);
LocalDateTime value;
try {
value = LocalDateTime.of(year, month, day, hour, minute, second, microsecond * 1000);
return LocalDateTime.of(year, month, day, hour, minute, second, microsecond * 1000);
} catch (DateTimeException e) {
value = LocalDateTime.MAX;
return null;
}
}
public static Object convertToJavaDateTimeV2(long time, Class clz) {
int year = (int) (time >> 46);
int yearMonth = (int) (time >> 42);
int yearMonthDay = (int) (time >> 37);
int month = (yearMonth & 0XF);
int day = (yearMonthDay & 0X1F);
int hour = (int) ((time >> 32) & 0X1F);
int minute = (int) ((time >> 26) & 0X3F);
int second = (int) ((time >> 20) & 0X3F);
int microsecond = (int) (time & 0XFFFFF);
try {
if (LocalDateTime.class.equals(clz)) {
return LocalDateTime.of(year, month, day, hour, minute, second, microsecond * 1000);
} else if (org.joda.time.DateTime.class.equals(clz)) {
return new org.joda.time.DateTime(year, month, day, hour, minute, second, microsecond / 1000);
} else if (org.joda.time.LocalDateTime.class.equals(clz)) {
return new org.joda.time.LocalDateTime(year, month, day, hour, minute, second, microsecond / 1000);
} else {
return null;
}
} catch (Exception e) {
return null;
}
return value;
}
}

View File

@ -367,11 +367,11 @@ public class UdfUtils {
0, 0, true);
} else if (java.util.Date.class.equals(clz)) {
java.util.Date date = (java.util.Date) obj;
return convertToDateTime(date.getYear() + 1900, date.getMonth(), date.getDay(), 0,
return convertToDateTime(date.getYear() + 1900, date.getMonth() + 1, date.getDay(), 0,
0, 0, true);
} else if (org.joda.time.LocalDate.class.equals(clz)) {
org.joda.time.LocalDate date = (org.joda.time.LocalDate) obj;
return convertToDateTime(date.getYear(), date.getDayOfMonth(), date.getDayOfMonth(), 0,
return convertToDateTime(date.getYear(), date.getMonthOfYear(), date.getDayOfMonth(), 0,
0, 0, true);
} else {
return 0;

View File

@ -45,7 +45,9 @@ public class ColumnType {
LARGEINT(16),
FLOAT(4),
DOUBLE(8),
DATE(8),
DATEV2(4),
DATETIME(8),
DATETIMEV2(8),
CHAR(-1),
VARCHAR(-1),
@ -161,6 +163,14 @@ public class ColumnType {
return type == Type.STRUCT;
}
public boolean isDateV2() {
return type == Type.DATEV2;
}
public boolean isDateTimeV2() {
return type == Type.DATETIMEV2;
}
public Type getType() {
return type;
}
@ -264,9 +274,16 @@ public class ColumnType {
case "double":
type = Type.DOUBLE;
break;
case "datev1":
type = Type.DATE;
break;
case "date":
case "datev2":
type = Type.DATEV2;
break;
case "datetimev1":
type = Type.DATETIME;
break;
case "binary":
case "bytes":
type = Type.BINARY;
@ -275,7 +292,9 @@ public class ColumnType {
type = Type.STRING;
break;
default:
if (lowerCaseType.startsWith("timestamp")) {
if (lowerCaseType.startsWith("timestamp")
|| lowerCaseType.startsWith("datetime")
|| lowerCaseType.startsWith("datetimev2")) {
type = Type.DATETIMEV2;
precision = 6; // default
Matcher match = digitPattern.matcher(lowerCaseType);

View File

@ -0,0 +1,25 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.common.jni.vec;
/**
* Convert the column values if the type is not defined in ColumnType
*/
public interface ColumnValueConverter {
Object[] convert(Object[] column);
}

View File

@ -21,6 +21,9 @@ package org.apache.doris.common.jni.vec;
import org.apache.doris.common.jni.utils.OffHeap;
import org.apache.doris.common.jni.vec.ColumnType.Type;
import java.util.Collections;
import java.util.Map;
/**
* Store a batch of data as vector table.
*/
@ -28,56 +31,133 @@ public class VectorTable {
private final VectorColumn[] columns;
private final ColumnType[] columnTypes;
private final String[] fields;
private final ScanPredicate[] predicates;
private final VectorColumn meta;
private int numRows;
private final boolean onlyReadable;
private final int numRowsOfReadable;
private final boolean isRestoreTable;
public VectorTable(ColumnType[] types, String[] fields, ScanPredicate[] predicates, int capacity) {
// Create writable vector table
private VectorTable(ColumnType[] types, String[] fields, int capacity) {
this.columnTypes = types;
this.fields = fields;
this.columns = new VectorColumn[types.length];
this.predicates = predicates;
int metaSize = 1; // number of rows
for (int i = 0; i < types.length; i++) {
columns[i] = new VectorColumn(types[i], capacity);
columns[i] = VectorColumn.createWritableColumn(types[i], capacity);
metaSize += types[i].metaSize();
}
this.meta = new VectorColumn(new ColumnType("#meta", Type.BIGINT), metaSize);
this.numRows = 0;
this.isRestoreTable = false;
this.meta = VectorColumn.createWritableColumn(new ColumnType("#meta", Type.BIGINT), metaSize);
this.onlyReadable = false;
numRowsOfReadable = -1;
}
public VectorTable(ColumnType[] types, String[] fields, long metaAddress) {
// Create readable vector table
// `metaAddress` is generated by `JniConnector::generate_meta_info`
private VectorTable(ColumnType[] types, String[] fields, long metaAddress) {
long address = metaAddress;
this.columnTypes = types;
this.fields = fields;
this.columns = new VectorColumn[types.length];
this.predicates = new ScanPredicate[0];
this.numRows = (int) OffHeap.getLong(null, address);
int numRows = (int) OffHeap.getLong(null, address);
address += 8;
int metaSize = 1; // stores the number of rows + other columns meta data
for (int i = 0; i < types.length; i++) {
columns[i] = new VectorColumn(types[i], numRows, address);
columns[i] = VectorColumn.createReadableColumn(types[i], numRows, address);
metaSize += types[i].metaSize();
address += types[i].metaSize() * 8L;
}
this.meta = new VectorColumn(metaAddress, metaSize, new ColumnType("#meta", Type.BIGINT));
this.isRestoreTable = true;
this.meta = VectorColumn.createReadableColumn(metaAddress, metaSize, new ColumnType("#meta", Type.BIGINT));
this.onlyReadable = true;
numRowsOfReadable = numRows;
}
public static VectorTable createWritableTable(ColumnType[] types, String[] fields, int capacity) {
return new VectorTable(types, fields, capacity);
}
public static VectorTable createWritableTable(Map<String, String> params, int capacity) {
String[] requiredFields = params.get("required_fields").split(",");
String[] types = params.get("columns_types").split("#");
ColumnType[] columnTypes = new ColumnType[types.length];
for (int i = 0; i < types.length; i++) {
columnTypes[i] = ColumnType.parseType(requiredFields[i], types[i]);
}
return createWritableTable(columnTypes, requiredFields, capacity);
}
public static VectorTable createWritableTable(Map<String, String> params) {
return createWritableTable(params, Integer.parseInt(params.get("num_rows")));
}
public static VectorTable createReadableTable(ColumnType[] types, String[] fields, long metaAddress) {
return new VectorTable(types, fields, metaAddress);
}
public static VectorTable createReadableTable(Map<String, String> params) {
if (params.get("required_fields").isEmpty()) {
assert params.get("columns_types").isEmpty();
return createReadableTable(new ColumnType[0], new String[0], Long.parseLong(params.get("meta_address")));
}
String[] requiredFields = params.get("required_fields").split(",");
String[] types = params.get("columns_types").split("#");
long metaAddress = Long.parseLong(params.get("meta_address"));
// Get sql string from configuration map
ColumnType[] columnTypes = new ColumnType[types.length];
for (int i = 0; i < types.length; i++) {
columnTypes[i] = ColumnType.parseType(requiredFields[i], types[i]);
}
return createReadableTable(columnTypes, requiredFields, metaAddress);
}
public void appendNativeData(int fieldId, NativeColumnValue o) {
assert (!isRestoreTable);
assert (!onlyReadable);
columns[fieldId].appendNativeValue(o);
}
public void appendData(int fieldId, ColumnValue o) {
assert (!isRestoreTable);
assert (!onlyReadable);
columns[fieldId].appendValue(o);
}
public void appendData(int fieldId, Object[] batch, ColumnValueConverter converter, boolean isNullable) {
assert (!onlyReadable);
if (converter != null) {
columns[fieldId].appendObjectColumn(converter.convert(batch), isNullable);
} else {
columns[fieldId].appendObjectColumn(batch, isNullable);
}
}
public void appendData(int fieldId, Object[] batch, boolean isNullable) {
appendData(fieldId, batch, null, isNullable);
}
/**
* Get materialized data, each type is wrapped by its Java type. For example: int -> Integer, decimal -> BigDecimal
*
* @param converters A map of converters. Convert the column values if the type is not defined in ColumnType.
* The map key is the field ID in VectorTable.
*/
public Object[][] getMaterializedData(Map<Integer, ColumnValueConverter> converters) {
if (columns.length == 0) {
return new Object[0][0];
}
Object[][] data = new Object[columns.length][];
for (int j = 0; j < columns.length; ++j) {
Object[] columnData = columns[j].getObjectColumn(0, columns[j].numRows());
if (converters.containsKey(j)) {
data[j] = converters.get(j).convert(columnData);
} else {
data[j] = columnData;
}
}
return data;
}
public Object[][] getMaterializedData() {
return getMaterializedData(Collections.emptyMap());
}
public VectorColumn[] getColumns() {
return columns;
}
@ -95,22 +175,26 @@ public class VectorTable {
}
public void releaseColumn(int fieldId) {
assert (!isRestoreTable);
assert (!onlyReadable);
columns[fieldId].close();
}
public void setNumRows(int numRows) {
this.numRows = numRows;
public int getNumRows() {
if (onlyReadable) {
return numRowsOfReadable;
} else {
return columns[0].numRows();
}
}
public int getNumRows() {
return this.numRows;
public int getNumColumns() {
return columns.length;
}
public long getMetaAddress() {
if (!isRestoreTable) {
if (!onlyReadable) {
meta.reset();
meta.appendLong(numRows);
meta.appendLong(getNumRows());
for (VectorColumn c : columns) {
c.updateMeta(meta);
}
@ -119,7 +203,7 @@ public class VectorTable {
}
public void reset() {
assert (!isRestoreTable);
assert (!onlyReadable);
for (VectorColumn column : columns) {
column.reset();
}
@ -127,7 +211,7 @@ public class VectorTable {
}
public void close() {
assert (!isRestoreTable);
assert (!onlyReadable);
for (int i = 0; i < columns.length; i++) {
releaseColumn(i);
}
@ -137,7 +221,7 @@ public class VectorTable {
// for test only.
public String dump(int rowLimit) {
StringBuilder sb = new StringBuilder();
for (int i = 0; i < rowLimit && i < numRows; i++) {
for (int i = 0; i < rowLimit && i < getNumRows(); i++) {
for (int j = 0; j < columns.length; j++) {
if (j != 0) {
sb.append(", ");

View File

@ -35,10 +35,12 @@ public class JniScannerTest {
{
put("mock_rows", "128");
put("required_fields", "boolean,tinyint,smallint,int,bigint,largeint,float,double,"
+ "date,timestamp,char,varchar,string,decimalv2,decimal64,array,map,struct");
+ "date,timestamp,char,varchar,string,decimalv2,decimal64,array,map,struct,"
+ "decimal18,timestamp4,datev1,datev2,datetimev1,datetimev2");
put("columns_types", "boolean#tinyint#smallint#int#bigint#largeint#float#double#"
+ "date#timestamp#char(10)#varchar(10)#string#decimalv2(12,4)#decimal64(10,3)#"
+ "array<array<string>>#map<string,array<int>>#struct<col1:timestamp(6),col2:array<char(10)>>");
+ "array<array<string>>#map<string,array<int>>#struct<col1:timestamp(6),col2:array<char(10)>>#"
+ "decimal(18,5)#timestamp(4)#datev1#datev2#datetimev1#datetimev2(4)");
}
});
scanner.open();
@ -49,7 +51,7 @@ public class JniScannerTest {
long rows = OffHeap.getLong(null, metaAddress);
Assert.assertEquals(32, rows);
VectorTable restoreTable = new VectorTable(scanner.getTable().getColumnTypes(),
VectorTable restoreTable = VectorTable.createReadableTable(scanner.getTable().getColumnTypes(),
scanner.getTable().getFields(), metaAddress);
System.out.println(restoreTable.dump((int) rows).substring(0, 128));
// Restored table is release by the origin table.

View File

@ -17,17 +17,18 @@
package org.apache.doris.udf;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Pair;
import org.apache.doris.common.exception.UdfRuntimeException;
import org.apache.doris.common.jni.utils.JavaUdfDataType;
import org.apache.doris.common.jni.utils.UdfUtils;
import org.apache.doris.common.jni.vec.ColumnValueConverter;
import org.apache.doris.common.jni.vec.VectorTable;
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
import org.apache.doris.thrift.TPrimitiveType;
import com.esotericsoftware.reflectasm.MethodAccess;
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.apache.log4j.Logger;
@ -35,7 +36,11 @@ import java.lang.reflect.Array;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
public class UdfExecutor extends BaseExecutor {
// private static final java.util.logging.Logger LOG =
@ -46,6 +51,8 @@ public class UdfExecutor extends BaseExecutor {
private int evaluateIndex;
private VectorTable outputTable = null;
/**
* Create a UdfExecutor, using parameters from a serialized thrift object. Used by
* the backend.
@ -59,98 +66,195 @@ public class UdfExecutor extends BaseExecutor {
*/
@Override
public void close() {
// inputTable is released by c++, only release outputTable
if (outputTable != null) {
outputTable.close();
}
// We are now un-usable (because the class loader has been
// closed), so null out method_ and classLoader_.
method = null;
super.close();
}
public Object[] convertBasicArguments(int argIdx, boolean isNullable, int numRows, long nullMapAddr,
long columnAddr, long strOffsetAddr) {
return convertBasicArg(true, argIdx, isNullable, 0, numRows, nullMapAddr, columnAddr, strOffsetAddr);
private ColumnValueConverter getInputConverter(TPrimitiveType primitiveType, Class clz) {
switch (primitiveType) {
case DATE:
case DATEV2: {
if (java.util.Date.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new java.util.Date[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
LocalDate v = (LocalDate) columnData[i];
result[i] = new java.util.Date(v.getYear() - 1900, v.getMonthValue() - 1,
v.getDayOfMonth());
}
}
return result;
};
} else if (org.joda.time.LocalDate.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new org.joda.time.LocalDate[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
LocalDate v = (LocalDate) columnData[i];
result[i] = new org.joda.time.LocalDate(v.getYear(), v.getMonthValue(),
v.getDayOfMonth());
}
}
return result;
};
} else if (!LocalDate.class.equals(clz)) {
throw new RuntimeException("Unsupported date type: " + clz.getCanonicalName());
}
break;
}
case DATETIME:
case DATETIMEV2: {
if (org.joda.time.DateTime.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new org.joda.time.DateTime[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
LocalDateTime v = (LocalDateTime) columnData[i];
result[i] = new org.joda.time.DateTime(v.getYear(), v.getMonthValue(),
v.getDayOfMonth(), v.getHour(),
v.getMinute(), v.getSecond(), v.getNano() / 1000000);
}
}
return result;
};
} else if (org.joda.time.LocalDateTime.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new org.joda.time.LocalDateTime[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
LocalDateTime v = (LocalDateTime) columnData[i];
result[i] = new org.joda.time.LocalDateTime(v.getYear(), v.getMonthValue(),
v.getDayOfMonth(), v.getHour(),
v.getMinute(), v.getSecond(), v.getNano() / 1000000);
}
}
return result;
};
} else if (!LocalDateTime.class.equals(clz)) {
throw new RuntimeException("Unsupported date type: " + clz.getCanonicalName());
}
break;
}
default:
break;
}
return null;
}
public Object[] convertArrayArguments(int argIdx, boolean isNullable, int numRows, long nullMapAddr,
long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) {
return convertArrayArg(argIdx, isNullable, 0, numRows, nullMapAddr, offsetsAddr, nestedNullMapAddr, dataAddr,
strOffsetAddr);
private Map<Integer, ColumnValueConverter> getInputConverters(int numColumns) {
Map<Integer, ColumnValueConverter> converters = new HashMap<>();
for (int j = 0; j < numColumns; ++j) {
ColumnValueConverter converter = getInputConverter(argTypes[j].getPrimitiveType(), argClass[j]);
if (converter != null) {
converters.put(j, converter);
}
}
return converters;
}
public Object[] convertMapArguments(int argIdx, boolean isNullable, int numRows, long nullMapAddr,
long offsetsAddr, long keyNestedNullMapAddr, long keyDataAddr, long keyStrOffsetAddr,
long valueNestedNullMapAddr, long valueDataAddr, long valueStrOffsetAddr) {
PrimitiveType keyType = argTypes[argIdx].getKeyType().getPrimitiveType();
PrimitiveType valueType = argTypes[argIdx].getValueType().getPrimitiveType();
Object[] keyCol = convertMapArg(keyType, argIdx, isNullable, 0, numRows, nullMapAddr, offsetsAddr,
keyNestedNullMapAddr, keyDataAddr,
keyStrOffsetAddr, argTypes[argIdx].getKeyScale());
Object[] valueCol = convertMapArg(valueType, argIdx, isNullable, 0, numRows, nullMapAddr, offsetsAddr,
valueNestedNullMapAddr, valueDataAddr,
valueStrOffsetAddr, argTypes[argIdx].getValueScale());
return buildHashMap(keyType, valueType, keyCol, valueCol);
private ColumnValueConverter getOutputConverter() {
Class clz = method.getReturnType();
switch (retType.getPrimitiveType()) {
case DATE:
case DATEV2: {
if (java.util.Date.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new LocalDate[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
java.util.Date v = (java.util.Date) columnData[i];
result[i] = LocalDate.of(v.getYear() + 1900, v.getMonth() + 1, v.getDate());
}
}
return result;
};
} else if (org.joda.time.LocalDate.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new LocalDate[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
org.joda.time.LocalDate v = (org.joda.time.LocalDate) columnData[i];
result[i] = LocalDate.of(v.getYear(), v.getMonthOfYear(), v.getDayOfMonth());
}
}
return result;
};
} else if (!LocalDate.class.equals(clz)) {
throw new RuntimeException("Unsupported date type: " + clz.getCanonicalName());
}
break;
}
case DATETIME:
case DATETIMEV2: {
if (org.joda.time.DateTime.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new LocalDateTime[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
org.joda.time.DateTime v = (org.joda.time.DateTime) columnData[i];
result[i] = LocalDateTime.of(v.getYear(), v.getMonthOfYear(), v.getDayOfMonth(),
v.getHourOfDay(),
v.getMinuteOfHour(), v.getSecondOfMinute(), v.getMillisOfSecond() * 1000000);
}
}
return result;
};
} else if (org.joda.time.LocalDateTime.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new LocalDateTime[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
org.joda.time.LocalDateTime v = (org.joda.time.LocalDateTime) columnData[i];
result[i] = LocalDateTime.of(v.getYear(), v.getMonthOfYear(), v.getDayOfMonth(),
v.getHourOfDay(),
v.getMinuteOfHour(), v.getSecondOfMinute(), v.getMillisOfSecond() * 1000000);
}
}
return result;
};
} else if (!LocalDateTime.class.equals(clz)) {
throw new RuntimeException("Unsupported date type: " + clz.getCanonicalName());
}
break;
}
default:
break;
}
return null;
}
/**
* Evaluates the UDF with 'args' as the input to the UDF.
*/
public Object[] evaluate(int numRows, Object[] column) throws UdfRuntimeException {
public long evaluate(Map<String, String> inputParams, Map<String, String> outputParams) throws UdfRuntimeException {
try {
VectorTable inputTable = VectorTable.createReadableTable(inputParams);
int numRows = inputTable.getNumRows();
int numColumns = inputTable.getNumColumns();
Object[] result = (Object[]) Array.newInstance(method.getReturnType(), numRows);
Object[][] inputs = (Object[][]) column;
Object[] parameters = new Object[inputs.length];
Object[][] inputs = inputTable.getMaterializedData(getInputConverters(numColumns));
Object[] parameters = new Object[numColumns];
for (int i = 0; i < numRows; ++i) {
for (int j = 0; j < column.length; ++j) {
for (int j = 0; j < numColumns; ++j) {
parameters[j] = inputs[j][i];
}
result[i] = methodAccess.invoke(udf, evaluateIndex, parameters);
}
return result;
} catch (Exception e) {
LOG.info("evaluate exception: " + debugString());
LOG.info("evaluate(int numRows, Object[] column) Exception: " + e.toString());
throw new UdfRuntimeException("UDF failed to evaluate", e);
}
}
public void copyBatchBasicResult(boolean isNullable, int numRows, Object[] result, long nullMapAddr,
long resColumnAddr, long strOffsetAddr) {
copyBatchBasicResultImpl(isNullable, numRows, result, nullMapAddr, resColumnAddr, strOffsetAddr, getMethod());
}
public void copyBatchArrayResult(boolean isNullable, int numRows, Object[] result, long nullMapAddr,
long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) {
Preconditions.checkState(result.length == numRows,
"copyBatchArrayResult result size should equal;");
copyBatchArrayResultImpl(isNullable, numRows, result, nullMapAddr, offsetsAddr, nestedNullMapAddr, dataAddr,
strOffsetAddr, retType.getItemType().getPrimitiveType(), retType.getScale());
}
public void copyBatchMapResult(boolean isNullable, int numRows, Object[] result, long nullMapAddr,
long offsetsAddr, long keyNsestedNullMapAddr, long keyDataAddr, long keyStrOffsetAddr,
long valueNsestedNullMapAddr, long valueDataAddr, long valueStrOffsetAddr) {
Preconditions.checkState(result.length == numRows,
"copyBatchMapResult result size should equal;");
PrimitiveType keyType = retType.getKeyType().getPrimitiveType();
PrimitiveType valueType = retType.getValueType().getPrimitiveType();
Object[] keyCol = new Object[result.length];
Object[] valueCol = new Object[result.length];
buildArrayListFromHashMap(result, keyType, valueType, keyCol, valueCol);
copyBatchArrayResultImpl(isNullable, numRows, valueCol, nullMapAddr, offsetsAddr, valueNsestedNullMapAddr,
valueDataAddr,
valueStrOffsetAddr, valueType, retType.getKeyScale());
copyBatchArrayResultImpl(isNullable, numRows, keyCol, nullMapAddr, offsetsAddr, keyNsestedNullMapAddr,
keyDataAddr,
keyStrOffsetAddr, keyType, retType.getValueScale());
}
/**
* Evaluates the UDF with 'args' as the input to the UDF.
*/
private Object evaluate(Object... args) throws UdfRuntimeException {
try {
return method.invoke(udf, args);
if (outputTable != null) {
outputTable.close();
}
boolean isNullable = Boolean.parseBoolean(outputParams.getOrDefault("is_nullable", "true"));
outputTable = VectorTable.createWritableTable(outputParams, numRows);
outputTable.appendData(0, result, getOutputConverter(), isNullable);
return outputTable.getMetaAddress();
} catch (Exception e) {
LOG.warn("evaluate exception: " + debugString(), e);
throw new UdfRuntimeException("UDF failed to evaluate", e);
}
}
@ -254,5 +358,4 @@ public class UdfExecutor extends BaseExecutor {
throw new UdfRuntimeException("Unable to call create UDF instance.", e);
}
}
}

View File

@ -173,16 +173,7 @@ public class JdbcExecutor {
}
public int write(Map<String, String> params) throws UdfRuntimeException {
String[] requiredFields = params.get("required_fields").split(",");
String[] types = params.get("columns_types").split("#");
long metaAddress = Long.parseLong(params.get("meta_address"));
// Get sql string from configuration map
ColumnType[] columnTypes = new ColumnType[types.length];
for (int i = 0; i < types.length; i++) {
columnTypes[i] = ColumnType.parseType(requiredFields[i], types[i]);
}
VectorTable batchTable = new VectorTable(columnTypes, requiredFields, metaAddress);
// todo: insert the batch table by PreparedStatement
VectorTable batchTable = VectorTable.createReadableTable(params);
// Can't release or close batchTable, it's released by c++
try {
insert(batchTable);

View File

@ -0,0 +1,13 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !java_udf_all_types --
1 true 1 2 1 3 4 3.3300 7.77 3.1415 2023-10-18 2023-10-18 2023-10-18 2023-10-11T10:11:11.234 2023-10-11T10:11:11.234 2023-10-11T10:11:11.234 row1 [NULL, "nested1"] {"k1":NULL, "k2":1}
2 false 2 4 2 6 8 1.6650 3.885 1.57075 2023-10-19 2023-10-19 2023-10-19 2023-10-12T10:12:11.234 2023-10-12T10:12:11.234 2023-10-12T10:12:11.234 row2 [NULL, "nested2"] {"k3":2, "k2":NULL}
3 true 3 6 3 9 12 1.1100 2.59 1.0471666667 2023-10-20 2023-10-20 2023-10-20 2023-10-13T10:13:11.234 2023-10-13T10:13:11.234 2023-10-13T10:13:11.234 row3 [NULL, "nested3"] {"k3":NULL, "k4":3}
4 false 4 8 4 12 16 0.8325 1.9425 0.785375 2023-10-21 2023-10-21 2023-10-21 2023-10-14T10:14:11.234 2023-10-14T10:14:11.234 2023-10-14T10:14:11.234 row4 [NULL, "nested4"] {"k4":NULL, "k5":4}
5 true 5 10 5 15 20 0.6660 1.554 0.6283 2023-10-22 2023-10-22 2023-10-22 2023-10-15T10:15:11.234 2023-10-15T10:15:11.234 2023-10-15T10:15:11.234 row5 [NULL, "nested5"] {"k5":NULL, "k6":5}
6 false 6 12 6 18 24 0.5550 1.295 0.5235833333 2023-10-23 2023-10-23 2023-10-23 2023-10-16T10:16:11.234 2023-10-16T10:16:11.234 2023-10-16T10:16:11.234 row6 [NULL, "nested6"] {"k7":6, "k6":NULL}
7 true 7 14 7 21 28 0.4757 1.11 0.4487857143 2023-10-24 2023-10-24 2023-10-24 2023-10-17T10:17:11.234 2023-10-17T10:17:11.234 2023-10-17T10:17:11.234 row7 [NULL, "nested7"] {"k7":NULL, "k8":7}
8 false 8 16 8 24 32 0.4163 0.97125 0.3926875 2023-10-25 2023-10-25 2023-10-25 2023-10-18T10:18:11.234 2023-10-18T10:18:11.234 2023-10-18T10:18:11.234 row8 [NULL, "nested8"] {"k8":NULL, "k9":8}
9 true 9 18 9 27 36 0.3700 0.86333334 0.3490555556 2023-10-26 2023-10-26 2023-10-26 2023-10-19T10:19:11.234 2023-10-19T10:19:11.234 2023-10-19T10:19:11.234 row9 [NULL, "nested9"] {"k9":NULL, "k10":9}
10 false \N 20 10 30 40 \N 0.777 0.31415 \N \N \N 2023-10-20T10:10:11.234 2023-10-20T10:10:11.234 2023-10-20T10:10:11.234 \N [NULL, "nested10"] {"k11":10, "k10":NULL}

View File

@ -0,0 +1,137 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.udf;
import org.apache.hadoop.hive.ql.exec.UDF;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.HashMap;
public class Echo {
public static class EchoBoolean extends UDF {
public Boolean evaluate(Boolean value) {
return value;
}
}
public static class EchoByte extends UDF {
public Byte evaluate(Byte value) {
return value;
}
}
public static class EchoShort extends UDF {
public Short evaluate(Short value) {
return value;
}
}
public static class EchoInt extends UDF {
public Integer evaluate(Integer value) {
return value;
}
}
public static class EchoLong extends UDF {
public Long evaluate(Long value) {
return value;
}
}
public static class EchoLargeInt extends UDF {
public BigInteger evaluate(BigInteger value) {
return value;
}
}
public static class EchoFloat extends UDF {
public Float evaluate(Float value) {
return value;
}
}
public static class EchoDouble extends UDF {
public Double evaluate(Double value) {
return value;
}
}
public static class EchoDecimal extends UDF {
public BigDecimal evaluate(BigDecimal value) {
return value;
}
}
public static class EchoDate extends UDF {
public LocalDate evaluate(LocalDate value) {
return value;
}
}
public static class EchoDate2 extends UDF {
public java.util.Date evaluate(java.util.Date value) {
return value;
}
}
public static class EchoDate3 extends UDF {
public org.joda.time.LocalDate evaluate(org.joda.time.LocalDate value) {
return value;
}
}
public static class EchoDateTime extends UDF {
public LocalDateTime evaluate(LocalDateTime value) {
return value;
}
}
public static class EchoDateTime2 extends UDF {
public org.joda.time.DateTime evaluate(org.joda.time.DateTime value) {
return value;
}
}
public static class EchoDateTime3 extends UDF {
public org.joda.time.LocalDateTime evaluate(org.joda.time.LocalDateTime value) {
return value;
}
}
public static class EchoString extends UDF {
public String evaluate(String value) {
return value;
}
}
public static class EchoList extends UDF {
public ArrayList<String> evaluate(ArrayList<String> value) {
return value;
}
}
public static class EchoMap extends UDF {
public HashMap<String, Integer> evaluate(HashMap<String, Integer> value) {
return value;
}
}
}

View File

@ -24,10 +24,14 @@ import java.util.*;
public class MapStrStrTest extends UDF {
public String evaluate(HashMap<String, String> hashMap) {
StringBuffer sb = new StringBuffer();
Set<String> sortSet = new TreeSet<String>();
for (Map.Entry<String, String> entry : hashMap.entrySet()) {
String key = entry.getKey();
String value = entry.getValue();
sb.append((key + value));
sortSet.add(key + value);
}
for (String item : sortSet) {
sb.append(item);
}
String ans = sb.toString();
return ans;

View File

@ -0,0 +1,236 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
import org.codehaus.groovy.runtime.IOGroovyMethods
import java.nio.charset.StandardCharsets
import java.nio.file.Files
import java.nio.file.Paths
suite("test_javaudf_all_types") {
def tableName = "test_javaudf_all_types"
def jarPath = """${context.file.parent}/jars/java-udf-case-jar-with-dependencies.jar"""
log.info("Jar path: ${jarPath}".toString())
try {
sql """ DROP TABLE IF EXISTS ${tableName} """
sql """
CREATE TABLE IF NOT EXISTS ${tableName} (
int_col int,
boolean_col boolean,
tinyint_col tinyint,
smallint_col smallint,
bigint_col bigint,
largeint_col largeint,
decimal_col decimal(15, 4),
float_col float,
double_col double,
date_col date,
datetime_col datetime(6),
string_col string,
array_col array<string>,
map_col map<string, int>
)
DISTRIBUTED BY HASH(int_col) PROPERTIES("replication_num" = "1");
"""
StringBuilder sb = new StringBuilder()
int i = 1
for (; i < 10; i++) {
sb.append("""
(${i},${i%2},${i},${i}*2,${i}*3,${i}*4,${3.33/i},${7.77/i},${3.1415/i},"2023-10-${i+17}","2023-10-${i+10} 10:1${i}:11.234","row${i}",array(null, "nested${i}"),{"k${i}":null,"k${i+1}":${i}}),
""")
}
sb.append("""
(${i},${i%2},null,${i}*2,${i}*3,${i}*4,null,${7.77/i},${3.1415/i},null,"2023-10-${i+10} 10:${i}:11.234",null,array(null, "nested${i}"),{"k${i}":null,"k${i+1}":${i}})
""")
sql """ INSERT INTO ${tableName} VALUES
${sb.toString()}
"""
File path = new File(jarPath)
if (!path.exists()) {
throw new IllegalStateException("""${jarPath} doesn't exist! """)
}
sql """DROP FUNCTION IF EXISTS echo_boolean(boolean);"""
sql """CREATE FUNCTION echo_boolean(boolean) RETURNS boolean PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.Echo\$EchoBoolean",
"type"="JAVA_UDF"
);"""
sql """DROP FUNCTION IF EXISTS echo_tinyint(tinyint);"""
sql """CREATE FUNCTION echo_tinyint(tinyint) RETURNS tinyint PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.Echo\$EchoByte",
"type"="JAVA_UDF"
);"""
sql """DROP FUNCTION IF EXISTS echo_short(smallint);"""
sql """CREATE FUNCTION echo_short(smallint) RETURNS smallint PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.Echo\$EchoShort",
"type"="JAVA_UDF"
);"""
sql """DROP FUNCTION IF EXISTS echo_int(int);"""
sql """CREATE FUNCTION echo_int(int) RETURNS int PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.Echo\$EchoInt",
"type"="JAVA_UDF"
);"""
sql """DROP FUNCTION IF EXISTS echo_long(bigint);"""
sql """CREATE FUNCTION echo_long(bigint) RETURNS bigint PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.Echo\$EchoLong",
"type"="JAVA_UDF"
);"""
sql """DROP FUNCTION IF EXISTS echo_largeint(largeint);"""
sql """CREATE FUNCTION echo_largeint(largeint) RETURNS largeint PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.Echo\$EchoLargeInt",
"type"="JAVA_UDF"
);"""
sql """DROP FUNCTION IF EXISTS echo_decimal(decimal(15, 4));"""
sql """CREATE FUNCTION echo_decimal(decimal(15, 4)) RETURNS decimal(15, 4) PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.Echo\$EchoDecimal",
"type"="JAVA_UDF"
);"""
sql """DROP FUNCTION IF EXISTS echo_float(float);"""
sql """CREATE FUNCTION echo_float(float) RETURNS float PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.Echo\$EchoFloat",
"type"="JAVA_UDF"
);"""
sql """DROP FUNCTION IF EXISTS echo_double(double);"""
sql """CREATE FUNCTION echo_double(double) RETURNS double PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.Echo\$EchoDouble",
"type"="JAVA_UDF"
);"""
sql """DROP FUNCTION IF EXISTS echo_date(date);"""
sql """CREATE FUNCTION echo_date(date) RETURNS date PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.Echo\$EchoDate",
"type"="JAVA_UDF"
);"""
sql """DROP FUNCTION IF EXISTS echo_date2(date);"""
sql """CREATE FUNCTION echo_date2(date) RETURNS date PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.Echo\$EchoDate2",
"type"="JAVA_UDF"
);"""
sql """DROP FUNCTION IF EXISTS echo_date3(date);"""
sql """CREATE FUNCTION echo_date3(date) RETURNS date PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.Echo\$EchoDate3",
"type"="JAVA_UDF"
);"""
sql """DROP FUNCTION IF EXISTS echo_datetime(datetime(6));"""
sql """CREATE FUNCTION echo_datetime(datetime(6)) RETURNS datetime(6) PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.Echo\$EchoDateTime",
"type"="JAVA_UDF"
);"""
sql """DROP FUNCTION IF EXISTS echo_datetime2(datetime(6));"""
sql """CREATE FUNCTION echo_datetime2(datetime(6)) RETURNS datetime(6) PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.Echo\$EchoDateTime2",
"type"="JAVA_UDF"
);"""
sql """DROP FUNCTION IF EXISTS echo_datetime3(datetime(6));"""
sql """CREATE FUNCTION echo_datetime3(datetime(6)) RETURNS datetime(6) PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.Echo\$EchoDateTime3",
"type"="JAVA_UDF"
);"""
sql """DROP FUNCTION IF EXISTS echo_string(string);"""
sql """CREATE FUNCTION echo_string(string) RETURNS string PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.Echo\$EchoString",
"type"="JAVA_UDF"
);"""
sql """DROP FUNCTION IF EXISTS echo_list(array<string>);"""
sql """CREATE FUNCTION echo_list(array<string>) RETURNS array<string> PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.Echo\$EchoList",
"type"="JAVA_UDF"
);"""
sql """DROP FUNCTION IF EXISTS echo_map(map<string,int>);"""
sql """CREATE FUNCTION echo_map(map<string,int>) RETURNS map<string,int> PROPERTIES (
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.Echo\$EchoMap",
"type"="JAVA_UDF"
);"""
qt_java_udf_all_types """select
int_col,
echo_boolean(boolean_col),
echo_tinyint(tinyint_col),
echo_short(smallint_col),
echo_int(int_col),
echo_long(bigint_col),
echo_largeint(largeint_col),
echo_decimal(decimal_col),
echo_float(float_col),
echo_double(double_col),
echo_date(date_col),
echo_date2(date_col),
echo_date3(date_col),
echo_datetime(datetime_col),
echo_datetime2(datetime_col),
echo_datetime3(datetime_col),
echo_string(string_col),
echo_list(array_col),
echo_map(map_col)
from ${tableName} order by int_col;"""
} finally {
try_sql """DROP FUNCTION IF EXISTS echo_boolean(boolean);"""
try_sql """DROP FUNCTION IF EXISTS echo_tinyint(tinyint);"""
try_sql """DROP FUNCTION IF EXISTS echo_short(smallint);"""
try_sql """DROP FUNCTION IF EXISTS echo_int(int);"""
try_sql """DROP FUNCTION IF EXISTS echo_long(bigint);"""
try_sql """DROP FUNCTION IF EXISTS echo_largeint(largeint);"""
try_sql """DROP FUNCTION IF EXISTS echo_decimal(decimal(15, 4));"""
try_sql """DROP FUNCTION IF EXISTS echo_float(float);"""
try_sql """DROP FUNCTION IF EXISTS echo_double(double);"""
try_sql """DROP FUNCTION IF EXISTS echo_date(date);"""
try_sql """DROP FUNCTION IF EXISTS echo_date2(date);"""
try_sql """DROP FUNCTION IF EXISTS echo_date3(date);"""
try_sql """DROP FUNCTION IF EXISTS echo_datetime(datetime(6));"""
try_sql """DROP FUNCTION IF EXISTS echo_datetime2(datetime(6));"""
try_sql """DROP FUNCTION IF EXISTS echo_datetime3(datetime(6));"""
try_sql """DROP FUNCTION IF EXISTS echo_string(string);"""
try_sql """DROP FUNCTION IF EXISTS echo_list(array<string>);"""
try_sql """DROP FUNCTION IF EXISTS echo_map(map<string,int>);"""
try_sql("""DROP TABLE IF EXISTS ${tableName};""")
}
}