[vectorized](udaf) support udaf function work with window function (#19962)
This commit is contained in:
@ -47,6 +47,7 @@ const char* UDAF_EXECUTOR_ADD_SIGNATURE = "(ZJJ)V";
|
||||
const char* UDAF_EXECUTOR_SERIALIZE_SIGNATURE = "(J)[B";
|
||||
const char* UDAF_EXECUTOR_MERGE_SIGNATURE = "(J[B)V";
|
||||
const char* UDAF_EXECUTOR_RESULT_SIGNATURE = "(JJ)Z";
|
||||
const char* UDAF_EXECUTOR_RESET_SIGNATURE = "(J)V";
|
||||
// The JNI method signatures above use the form: "(argument-types)return-type"
|
||||
// https://www.iitk.ac.in/esc101/05Aug/tutorial/native1.1/implementing/method.html
|
||||
|
||||
@ -219,6 +220,13 @@ public:
|
||||
return JniUtil::GetJniExceptionMsg(env);
|
||||
}
|
||||
|
||||
// Calls the Java executor's reset method (executor_reset_id) for the aggregate
// state identified by `place`.
//
// The JNI environment is fetched first; RETURN_NOT_OK_STATUS_WITH_WARN handles a
// failure of that step. After the call, the result of
// JniUtil::GetJniExceptionMsg is returned to the caller.
Status reset(int64_t place) {
    JNIEnv* jni_env = nullptr;
    RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&jni_env), "Java-Udaf reset function");

    // Non-virtual dispatch on the executor class, passing the place address as the only argument.
    jni_env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_reset_id, place);

    return JniUtil::GetJniExceptionMsg(jni_env);
}
|
||||
|
||||
// Fills serialize_data from `buf` by delegating to read_binary.
void read(BufferReadable& buf) { read_binary(serialize_data, buf); }
|
||||
|
||||
Status destroy() {
|
||||
@ -375,6 +383,7 @@ private:
|
||||
|
||||
RETURN_IF_ERROR(register_id("<init>", UDAF_EXECUTOR_CTOR_SIGNATURE, executor_ctor_id));
|
||||
RETURN_IF_ERROR(register_id("add", UDAF_EXECUTOR_ADD_SIGNATURE, executor_add_id));
|
||||
RETURN_IF_ERROR(register_id("reset", UDAF_EXECUTOR_RESET_SIGNATURE, executor_reset_id));
|
||||
RETURN_IF_ERROR(register_id("close", UDAF_EXECUTOR_CLOSE_SIGNATURE, executor_close_id));
|
||||
RETURN_IF_ERROR(register_id("merge", UDAF_EXECUTOR_MERGE_SIGNATURE, executor_merge_id));
|
||||
RETURN_IF_ERROR(
|
||||
@ -397,6 +406,7 @@ private:
|
||||
jmethodID executor_merge_id;
|
||||
jmethodID executor_serialize_id;
|
||||
jmethodID executor_result_id;
|
||||
jmethodID executor_reset_id;
|
||||
jmethodID executor_close_id;
|
||||
jmethodID executor_destroy_id;
|
||||
|
||||
@ -502,11 +512,26 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: the reset function should also be implemented in the struct data
|
||||
void reset(AggregateDataPtr /*place*/) const override {
|
||||
LOG(WARNING) << " shouldn't going reset function, there maybe some error about function "
|
||||
<< _fn.name.function_name;
|
||||
throw doris::Exception(ErrorCode::INTERNAL_ERROR, "shouldn't going reset function");
|
||||
void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
|
||||
int64_t frame_end, AggregateDataPtr place, const IColumn** columns,
|
||||
Arena* arena) const override {
|
||||
frame_start = std::max<int64_t>(frame_start, partition_start);
|
||||
frame_end = std::min<int64_t>(frame_end, partition_end);
|
||||
int64_t places_address[1];
|
||||
places_address[0] = reinterpret_cast<int64_t>(place);
|
||||
Status st =
|
||||
this->data(_exec_place)
|
||||
.add(places_address, true, columns, frame_start, frame_end, argument_types);
|
||||
if (UNLIKELY(st != Status::OK())) {
|
||||
throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Resets the aggregate state stored for `place` by forwarding its address to the
// executor's reset(). Throws doris::Exception(INTERNAL_ERROR) carrying the status
// text when the executor does not return Status::OK().
void reset(AggregateDataPtr place) const override {
    const auto place_address = reinterpret_cast<int64_t>(place);
    Status reset_status = this->data(_exec_place).reset(place_address);
    if (UNLIKELY(reset_status != Status::OK())) {
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, reset_status.to_string());
    }
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
|
||||
|
||||
@ -130,6 +130,13 @@ public class SimpleDemo {
|
||||
/* here could do some destroy work if needed */
|
||||
}
|
||||
|
||||
/*Not Required*/
|
||||
public void reset(State state) {
|
||||
/*Implement this method if you want this UDAF to work with window functions.*/
|
||||
/*It must restore the state to its initial value; it is called after each window frame is calculated.*/
|
||||
state.sum = 0;
|
||||
}
|
||||
|
||||
/*required*/
|
||||
//first argument is State, then other types your input
|
||||
public void add(State state, Integer val) throws Exception {
|
||||
|
||||
@ -130,6 +130,13 @@ public class SimpleDemo {
|
||||
/* here could do some destroy work if needed */
|
||||
}
|
||||
|
||||
/*Not Required*/
|
||||
public void reset(State state) {
|
||||
/*Implement this method if you want this UDAF to work with window functions.*/
|
||||
/*It must restore the state to its initial value; it is called after each window frame is calculated.*/
|
||||
state.sum = 0;
|
||||
}
|
||||
|
||||
/*required*/
|
||||
//first argument is State, then other types your input
|
||||
public void add(State state, Integer val) throws Exception {
|
||||
|
||||
@ -46,6 +46,7 @@ public abstract class BaseExecutor {
|
||||
public static final String UDAF_CREATE_FUNCTION = "create";
|
||||
public static final String UDAF_DESTROY_FUNCTION = "destroy";
|
||||
public static final String UDAF_ADD_FUNCTION = "add";
|
||||
public static final String UDAF_RESET_FUNCTION = "reset";
|
||||
public static final String UDAF_SERIALIZE_FUNCTION = "serialize";
|
||||
public static final String UDAF_DESERIALIZE_FUNCTION = "deserialize";
|
||||
public static final String UDAF_MERGE_FUNCTION = "merge";
|
||||
|
||||
@ -71,10 +71,21 @@ public class UdafExecutor extends BaseExecutor {
|
||||
try {
|
||||
long idx = rowStart;
|
||||
do {
|
||||
Long curPlace = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, inputPlacesPtr) + 8L * idx);
|
||||
Long curPlace = null;
|
||||
if (isSinglePlace) {
|
||||
curPlace = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, inputPlacesPtr));
|
||||
} else {
|
||||
curPlace = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, inputPlacesPtr) + 8L * idx);
|
||||
}
|
||||
Object[] inputArgs = new Object[argTypes.length + 1];
|
||||
stateObjMap.putIfAbsent(curPlace, createAggState());
|
||||
inputArgs[0] = stateObjMap.get(curPlace);
|
||||
Object state = stateObjMap.get(curPlace);
|
||||
if (state != null) {
|
||||
inputArgs[0] = state;
|
||||
} else {
|
||||
Object newState = createAggState();
|
||||
stateObjMap.put(curPlace, newState);
|
||||
inputArgs[0] = newState;
|
||||
}
|
||||
do {
|
||||
Object[] inputObjects = allocateInputObjects(idx, 1);
|
||||
for (int i = 0; i < argTypes.length; ++i) {
|
||||
@ -134,6 +145,23 @@ public class UdafExecutor extends BaseExecutor {
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* invoke reset function and reset the state to init.
|
||||
*/
|
||||
public void reset(long place) throws UdfRuntimeException {
|
||||
try {
|
||||
Object[] args = new Object[1];
|
||||
args[0] = stateObjMap.get((Long) place);
|
||||
if (args[0] == null) {
|
||||
return;
|
||||
}
|
||||
allMethods.get(UDAF_RESET_FUNCTION).invoke(udf, args);
|
||||
} catch (Exception e) {
|
||||
LOG.warn("invoke reset function meet some error: " + e.getCause().toString());
|
||||
throw new UdfRuntimeException("UDAF failed to reset: ", e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* invoke the merge function; deserialization is also done here.
|
||||
* here call deserialize first, and call merge.
|
||||
@ -147,8 +175,14 @@ public class UdafExecutor extends BaseExecutor {
|
||||
allMethods.get(UDAF_DESERIALIZE_FUNCTION).invoke(udf, args);
|
||||
args[1] = args[0];
|
||||
Long curPlace = place;
|
||||
stateObjMap.putIfAbsent(curPlace, createAggState());
|
||||
args[0] = stateObjMap.get(curPlace);
|
||||
Object state = stateObjMap.get(curPlace);
|
||||
if (state != null) {
|
||||
args[0] = state;
|
||||
} else {
|
||||
Object newState = createAggState();
|
||||
stateObjMap.put(curPlace, newState);
|
||||
args[0] = newState;
|
||||
}
|
||||
allMethods.get(UDAF_MERGE_FUNCTION).invoke(udf, args);
|
||||
} catch (Exception e) {
|
||||
LOG.warn("invoke merge function meet some error: " + e.getCause().toString());
|
||||
@ -226,6 +260,7 @@ public class UdafExecutor extends BaseExecutor {
|
||||
case UDAF_CREATE_FUNCTION:
|
||||
case UDAF_MERGE_FUNCTION:
|
||||
case UDAF_SERIALIZE_FUNCTION:
|
||||
case UDAF_RESET_FUNCTION:
|
||||
case UDAF_DESERIALIZE_FUNCTION: {
|
||||
allMethods.put(methods[idx].getName(), methods[idx]);
|
||||
break;
|
||||
|
||||
@ -31,3 +31,36 @@
|
||||
2 6 6
|
||||
9 9 9
|
||||
|
||||
-- !select5 --
|
||||
1
|
||||
2
|
||||
0
|
||||
1
|
||||
2
|
||||
0
|
||||
1
|
||||
2
|
||||
9
|
||||
|
||||
-- !select6 --
|
||||
1
|
||||
2
|
||||
0
|
||||
1
|
||||
2
|
||||
0
|
||||
1
|
||||
2
|
||||
9
|
||||
|
||||
-- !select7 --
|
||||
1
|
||||
2
|
||||
0
|
||||
1
|
||||
2
|
||||
0
|
||||
1
|
||||
2
|
||||
9
|
||||
|
||||
|
||||
@ -31,6 +31,10 @@ public class MySumInt {
|
||||
// Empty body: performs no cleanup for the given state.
public void destroy(State state) {
|
||||
}
|
||||
|
||||
// Sets the state's counter back to zero.
public void reset(State state) {
|
||||
state.counter = 0;
|
||||
}
|
||||
|
||||
public void add(State state, Integer val) {
|
||||
if (val == null) return;
|
||||
state.counter += val;
|
||||
|
||||
@ -72,7 +72,9 @@ suite("test_javaudaf_mysum_int") {
|
||||
|
||||
qt_select4 """ select user_id, udaf_my_sum_int(user_id), sum(user_id) from ${tableName} group by user_id order by user_id; """
|
||||
|
||||
|
||||
qt_select5 """ select udaf_my_sum_int(user_id) over(partition by char_col) from test_javaudaf_mysum_int order by char_col; """
|
||||
qt_select6 """ select udaf_my_sum_int(user_id) over(partition by char_col order by string_col) from test_javaudaf_mysum_int order by char_col; """
|
||||
qt_select7 """ select udaf_my_sum_int(user_id) over(partition by char_col order by string_col rows between 1 preceding and 1 following ) from test_javaudaf_mysum_int order by char_col; """
|
||||
} finally {
|
||||
try_sql("DROP FUNCTION IF EXISTS udaf_my_sum_int(int);")
|
||||
try_sql("DROP TABLE IF EXISTS ${tableName}")
|
||||
|
||||
Reference in New Issue
Block a user