[Function] Support Decimal to calculate variance and standard deviation (#4959)
This commit is contained in:
@ -2061,6 +2061,13 @@ struct KnuthVarianceState {
|
||||
int64_t count;
|
||||
};
|
||||
|
||||
// Use Decimal to store the intermediate results of the variance algorithm
|
||||
struct DecimalV2KnuthVarianceState {
|
||||
DecimalV2Val mean;
|
||||
DecimalV2Val m2;
|
||||
int64_t count = 0;
|
||||
};
|
||||
|
||||
// Set pop=true for population variance, false for sample variance
|
||||
static double compute_knuth_variance(const KnuthVarianceState& state, bool pop) {
|
||||
// Return zero for 1 tuple specified by
|
||||
@ -2070,6 +2077,16 @@ static double compute_knuth_variance(const KnuthVarianceState& state, bool pop)
|
||||
return state.m2 / (state.count - 1);
|
||||
}
|
||||
|
||||
// The algorithm is the same as above, using decimal as the intermediate variable
|
||||
static DecimalV2Value decimalv2_compute_knuth_variance(const DecimalV2KnuthVarianceState& state, bool pop) {
|
||||
DecimalV2Value new_count = DecimalV2Value();
|
||||
new_count.assign_from_double(state.count);
|
||||
if (state.count == 1) return new_count;
|
||||
DecimalV2Value new_m2 = DecimalV2Value::from_decimal_val(state.m2);
|
||||
if (pop) return new_m2 / new_count;
|
||||
else return new_m2 / new_count.assign_from_double(state.count - 1);
|
||||
}
|
||||
|
||||
void AggregateFunctions::knuth_var_init(FunctionContext* ctx, StringVal* dst) {
|
||||
dst->is_null = false;
|
||||
// TODO(zc)
|
||||
@ -2079,6 +2096,15 @@ void AggregateFunctions::knuth_var_init(FunctionContext* ctx, StringVal* dst) {
|
||||
memset(dst->ptr, 0, dst->len);
|
||||
}
|
||||
|
||||
// Sets up the intermediate state used by the DecimalV2 variance / stddev aggregates.
void AggregateFunctions::decimalv2_knuth_var_init(FunctionContext* ctx, StringVal* dst) {
    dst->is_null = false;
    dst->len = sizeof(DecimalV2KnuthVarianceState);
    // The state holds int128 values, which must be 16-byte aligned, so the object
    // is created through its constructor with `new` instead of being placed in
    // separately allocated memory. The finalize functions delete it.
    DecimalV2KnuthVarianceState* fresh_state = new DecimalV2KnuthVarianceState;
    dst->ptr = reinterpret_cast<uint8_t*>(fresh_state);
}
|
||||
|
||||
template <typename T>
|
||||
void AggregateFunctions::knuth_var_update(FunctionContext* ctx, const T& src, StringVal* dst) {
|
||||
DCHECK(!dst->is_null);
|
||||
@ -2093,6 +2119,34 @@ void AggregateFunctions::knuth_var_update(FunctionContext* ctx, const T& src, St
|
||||
state->count = temp;
|
||||
}
|
||||
|
||||
// Folds one DecimalV2 input value into the running Knuth variance state stored in dst.
// A null input leaves the state untouched.
void AggregateFunctions::knuth_var_update(FunctionContext* ctx, const DecimalV2Val& src, StringVal* dst) {
    DCHECK(!dst->is_null);
    DCHECK_EQ(dst->len, sizeof(DecimalV2KnuthVarianceState));
    if (src.is_null) {
        return;
    }
    auto* acc = reinterpret_cast<DecimalV2KnuthVarianceState*>(dst->ptr);

    // Lift the input and the stored accumulators into DecimalV2Value for arithmetic.
    DecimalV2Value value(DecimalV2Value::from_decimal_val(src));
    DecimalV2Value running_mean(DecimalV2Value::from_decimal_val(acc->mean));
    DecimalV2Value running_m2(DecimalV2Value::from_decimal_val(acc->m2));

    // Count before this value, and count once this value is included.
    DecimalV2Value prev_count = DecimalV2Value();
    prev_count.assign_from_double(acc->count);
    DecimalV2Value next_count = DecimalV2Value();
    next_count.assign_from_double(1 + acc->count);

    DecimalV2Value diff = value - running_mean;
    DecimalV2Value step = diff / next_count;
    running_mean += step;

    // The accumulation below may overflow the decimal. On overflow m2 becomes
    // 9223372036854775807999999999, the largest value DecimalV2Value can represent.
    // The double-based variant does not hit this limit, because a double keeps its
    // value in floating-point (scientific-notation) form.
    // Spark handles decimal overflow by returning null or raising an error,
    // selectable through a parameter.
    // Reference for Spark's decimal handling: https://cloud.tencent.com/developer/news/483615
    running_m2 += prev_count * diff * step;

    ++acc->count;
    running_mean.to_decimal_val(&acc->mean);
    running_m2.to_decimal_val(&acc->m2);
}
|
||||
|
||||
void AggregateFunctions::knuth_var_merge(FunctionContext* ctx, const StringVal& src,
|
||||
StringVal* dst) {
|
||||
DCHECK(!dst->is_null);
|
||||
@ -2112,6 +2166,33 @@ void AggregateFunctions::knuth_var_merge(FunctionContext* ctx, const StringVal&
|
||||
dst_state->count = sum_count;
|
||||
}
|
||||
|
||||
// Combines the partial variance state in src into the state held by dst.
// A src state with a count of 0 leaves dst unchanged.
void AggregateFunctions::decimalv2_knuth_var_merge(FunctionContext* ctx, const StringVal& src,
                                                   StringVal* dst) {
    // Copy the incoming bytes into a local object before reading its fields.
    DecimalV2KnuthVarianceState incoming;
    memcpy(&incoming, src.ptr, sizeof(DecimalV2KnuthVarianceState));
    DCHECK(!dst->is_null);
    DCHECK_EQ(dst->len, sizeof(DecimalV2KnuthVarianceState));
    auto* target = reinterpret_cast<DecimalV2KnuthVarianceState*>(dst->ptr);
    if (incoming.count == 0) {
        return;
    }

    // Lift both states into DecimalV2Value for arithmetic.
    DecimalV2Value src_mean(DecimalV2Value::from_decimal_val(incoming.mean));
    DecimalV2Value dst_mean(DecimalV2Value::from_decimal_val(target->mean));
    DecimalV2Value src_m2(DecimalV2Value::from_decimal_val(incoming.m2));
    DecimalV2Value dst_m2(DecimalV2Value::from_decimal_val(target->m2));
    DecimalV2Value src_count = DecimalV2Value();
    src_count.assign_from_double(incoming.count);
    DecimalV2Value dst_count = DecimalV2Value();
    dst_count.assign_from_double(target->count);

    DecimalV2Value mean_gap = dst_mean - src_mean;
    DecimalV2Value total_count = dst_count + src_count;

    // Merged mean: start from the src mean and move toward the dst mean in
    // proportion to dst's share of the combined count.
    DecimalV2Value dst_share = dst_count / total_count;
    DecimalV2Value merged_mean = src_mean + mean_gap * dst_share;

    // Merged m2: both partial m2 values plus a term for the gap between the means.
    DecimalV2Value gap_squared = mean_gap * mean_gap;
    DecimalV2Value count_factor = src_count * dst_count / total_count;
    DecimalV2Value merged_m2 = src_m2 + dst_m2 + gap_squared * count_factor;

    target->count += incoming.count;
    merged_mean.to_decimal_val(&target->mean);
    merged_m2.to_decimal_val(&target->m2);
}
|
||||
|
||||
DoubleVal AggregateFunctions::knuth_var_finalize(FunctionContext* ctx, const StringVal& state_sv) {
|
||||
DCHECK(!state_sv.is_null);
|
||||
KnuthVarianceState* state = reinterpret_cast<KnuthVarianceState*>(state_sv.ptr);
|
||||
@ -2121,6 +2202,19 @@ DoubleVal AggregateFunctions::knuth_var_finalize(FunctionContext* ctx, const Str
|
||||
return DoubleVal(variance);
|
||||
}
|
||||
|
||||
// Produces the sample variance (m2 / (count - 1)) from the accumulated decimal state
// and releases that state.
// Returns NULL when 0 or 1 values were accumulated.
DecimalV2Val AggregateFunctions::decimalv2_knuth_var_finalize(FunctionContext* ctx,
                                                              const StringVal& state_sv) {
    DCHECK(!state_sv.is_null);
    DCHECK_EQ(state_sv.len, sizeof(DecimalV2KnuthVarianceState));
    DecimalV2KnuthVarianceState* state = reinterpret_cast<DecimalV2KnuthVarianceState*>(state_sv.ptr);
    if (state->count == 0 || state->count == 1) {
        // The state was heap-allocated by the init function and must be freed on
        // the NULL path as well, otherwise it leaks.
        delete state;
        return DecimalV2Val::null();
    }
    DecimalV2Value variance = decimalv2_compute_knuth_variance(*state, false);
    DecimalV2Val res;
    variance.to_decimal_val(&res);
    delete state;
    return res;
}
|
||||
|
||||
DoubleVal AggregateFunctions::knuth_var_pop_finalize(FunctionContext* ctx,
|
||||
const StringVal& state_sv) {
|
||||
DCHECK(!state_sv.is_null);
|
||||
@ -2132,6 +2226,19 @@ DoubleVal AggregateFunctions::knuth_var_pop_finalize(FunctionContext* ctx,
|
||||
return DoubleVal(variance);
|
||||
}
|
||||
|
||||
// Produces the population variance (m2 / count) from the accumulated decimal state
// and releases that state.
// Returns NULL when no values were accumulated.
DecimalV2Val AggregateFunctions::decimalv2_knuth_var_pop_finalize(FunctionContext* ctx,
                                                                  const StringVal& state_sv) {
    DCHECK(!state_sv.is_null);
    DCHECK_EQ(state_sv.len, sizeof(DecimalV2KnuthVarianceState));
    DecimalV2KnuthVarianceState* state = reinterpret_cast<DecimalV2KnuthVarianceState*>(state_sv.ptr);
    if (state->count == 0) {
        // The state was heap-allocated by the init function and must be freed on
        // the NULL path as well, otherwise it leaks.
        delete state;
        return DecimalV2Val::null();
    }
    DecimalV2Value variance = decimalv2_compute_knuth_variance(*state, true);
    DecimalV2Val res;
    variance.to_decimal_val(&res);
    delete state;
    return res;
}
|
||||
|
||||
DoubleVal AggregateFunctions::knuth_stddev_finalize(FunctionContext* ctx,
|
||||
const StringVal& state_sv) {
|
||||
DCHECK(!state_sv.is_null);
|
||||
@ -2143,6 +2250,20 @@ DoubleVal AggregateFunctions::knuth_stddev_finalize(FunctionContext* ctx,
|
||||
return DoubleVal(variance);
|
||||
}
|
||||
|
||||
// Produces the sample standard deviation (square root of m2 / (count - 1)) from the
// accumulated decimal state and releases that state.
// Returns NULL when 0 or 1 values were accumulated.
DecimalV2Val AggregateFunctions::decimalv2_knuth_stddev_finalize(FunctionContext* ctx,
                                                                 const StringVal& state_sv) {
    DCHECK(!state_sv.is_null);
    DCHECK_EQ(state_sv.len, sizeof(DecimalV2KnuthVarianceState));
    DecimalV2KnuthVarianceState* state = reinterpret_cast<DecimalV2KnuthVarianceState*>(state_sv.ptr);
    if (state->count == 0 || state->count == 1) {
        // The state was heap-allocated by the init function and must be freed on
        // the NULL path as well, otherwise it leaks.
        delete state;
        return DecimalV2Val::null();
    }
    DecimalV2Value variance = decimalv2_compute_knuth_variance(*state, false);
    variance = DecimalV2Value::sqrt(variance);
    DecimalV2Val res;
    variance.to_decimal_val(&res);
    delete state;
    return res;
}
|
||||
|
||||
DoubleVal AggregateFunctions::knuth_stddev_pop_finalize(FunctionContext* ctx,
|
||||
const StringVal& state_sv) {
|
||||
DCHECK(!state_sv.is_null);
|
||||
@ -2154,6 +2275,20 @@ DoubleVal AggregateFunctions::knuth_stddev_pop_finalize(FunctionContext* ctx,
|
||||
return DoubleVal(variance);
|
||||
}
|
||||
|
||||
// Produces the population standard deviation (square root of m2 / count) from the
// accumulated decimal state and releases that state.
// Returns NULL when no values were accumulated.
DecimalV2Val AggregateFunctions::decimalv2_knuth_stddev_pop_finalize(FunctionContext* ctx,
                                                                     const StringVal& state_sv) {
    DCHECK(!state_sv.is_null);
    DCHECK_EQ(state_sv.len, sizeof(DecimalV2KnuthVarianceState));
    DecimalV2KnuthVarianceState* state = reinterpret_cast<DecimalV2KnuthVarianceState*>(state_sv.ptr);
    if (state->count == 0) {
        // The state was heap-allocated by the init function and must be freed on
        // the NULL path as well, otherwise it leaks.
        delete state;
        return DecimalV2Val::null();
    }
    DecimalV2Value variance = decimalv2_compute_knuth_variance(*state, true);
    variance = DecimalV2Value::sqrt(variance);
    DecimalV2Val res;
    variance.to_decimal_val(&res);
    delete state;
    return res;
}
|
||||
|
||||
struct RankState {
|
||||
int64_t rank;
|
||||
int64_t count;
|
||||
|
||||
Reference in New Issue
Block a user