[Function] Support Decimal to calculate variance and standard deviation (#4959)

This commit is contained in:
Xinyi Zou
2020-12-06 08:49:01 +08:00
committed by GitHub
parent 42dd821021
commit b1b99ae884
5 changed files with 412 additions and 32 deletions

View File

@ -2061,6 +2061,13 @@ struct KnuthVarianceState {
int64_t count;
};
// Intermediate state of the single-pass (Knuth) variance algorithm, kept in decimal
// form so that variance / standard deviation over DecimalV2 inputs does not go
// through double.
// NOTE: instances are created with `new` in decimalv2_knuth_var_init, accessed through
// a StringVal's raw pointer, and copied with memcpy in decimalv2_knuth_var_merge, so the
// member layout must not be changed casually.
struct DecimalV2KnuthVarianceState {
// Running mean of the values consumed so far.
DecimalV2Val mean;
// Running sum of squared differences from the mean (divided by the count, or
// count - 1, when the variance is finalized).
DecimalV2Val m2;
// Number of non-null values consumed so far.
int64_t count = 0;
};
// Set pop=true for population variance, false for sample variance
static double compute_knuth_variance(const KnuthVarianceState& state, bool pop) {
// Return zero for 1 tuple specified by
@ -2070,6 +2077,16 @@ static double compute_knuth_variance(const KnuthVarianceState& state, bool pop)
return state.m2 / (state.count - 1);
}
// Decimal counterpart of compute_knuth_variance: turns an accumulated Knuth state
// into a variance, using DecimalV2Value for every intermediate value.
//
// state: accumulated mean / m2 / count. Callers are expected to have handled
//        count == 0 already (the divisor would be zero for a population variance).
// pop:   true for population variance (m2 / count),
//        false for sample variance (m2 / (count - 1)).
//
// Returns zero when exactly one value has been accumulated, matching the double
// implementation. (Previously the count itself, i.e. 1, was returned in that case
// because the counter variable had already been assigned before the check.)
static DecimalV2Value decimalv2_compute_knuth_variance(const DecimalV2KnuthVarianceState& state,
                                                       bool pop) {
    if (state.count == 1) {
        // A single value has no spread: the variance is exactly zero.
        DecimalV2Value zero = DecimalV2Value();
        zero.assign_from_double(0);
        return zero;
    }

    DecimalV2Value m2 = DecimalV2Value::from_decimal_val(state.m2);

    // Population variance divides by n, sample variance by n - 1.
    DecimalV2Value divisor = DecimalV2Value();
    divisor.assign_from_double(pop ? state.count : state.count - 1);
    return m2 / divisor;
}
void AggregateFunctions::knuth_var_init(FunctionContext* ctx, StringVal* dst) {
dst->is_null = false;
// TODO(zc)
@ -2079,6 +2096,15 @@ void AggregateFunctions::knuth_var_init(FunctionContext* ctx, StringVal* dst) {
memset(dst->ptr, 0, dst->len);
}
// Creates an empty decimal variance state and hands it to the framework via `dst`.
//
// The state holds int128-backed decimals, which must be 16-byte aligned. For that
// reason the state is created with `new` (which respects the type's alignment)
// rather than placed in a raw byte buffer. Ownership stays with the aggregate: the
// matching decimalv2_*_finalize function releases it with `delete`.
void AggregateFunctions::decimalv2_knuth_var_init(FunctionContext* ctx, StringVal* dst) {
    DecimalV2KnuthVarianceState* fresh_state = new DecimalV2KnuthVarianceState;
    dst->ptr = reinterpret_cast<uint8_t*>(fresh_state);
    dst->len = sizeof(DecimalV2KnuthVarianceState);
    dst->is_null = false;
}
template <typename T>
void AggregateFunctions::knuth_var_update(FunctionContext* ctx, const T& src, StringVal* dst) {
DCHECK(!dst->is_null);
@ -2093,6 +2119,34 @@ void AggregateFunctions::knuth_var_update(FunctionContext* ctx, const T& src, St
state->count = temp;
}
// Folds one decimal input value into the running variance state held by `dst`
// (single-pass Knuth update of mean and m2). Null inputs are ignored.
void AggregateFunctions::knuth_var_update(FunctionContext* ctx, const DecimalV2Val& src,
                                          StringVal* dst) {
    DCHECK(!dst->is_null);
    DCHECK_EQ(dst->len, sizeof(DecimalV2KnuthVarianceState));
    if (src.is_null) {
        return;
    }

    DecimalV2KnuthVarianceState* acc = reinterpret_cast<DecimalV2KnuthVarianceState*>(dst->ptr);

    // Lift everything into DecimalV2Value so the arithmetic below is done in decimal.
    DecimalV2Value value = DecimalV2Value::from_decimal_val(src);
    DecimalV2Value mean = DecimalV2Value::from_decimal_val(acc->mean);
    DecimalV2Value m2 = DecimalV2Value::from_decimal_val(acc->m2);

    // Count before this value, and count once this value is included.
    DecimalV2Value count_before = DecimalV2Value();
    count_before.assign_from_double(acc->count);
    DecimalV2Value count_after = DecimalV2Value();
    count_after.assign_from_double(1 + acc->count);

    DecimalV2Value delta = value - mean;
    DecimalV2Value r = delta / count_after;
    mean += r;

    // The m2 accumulation may overflow DecimalV2. On overflow m2 ends up equal to
    // 9223372036854775807999999999, the maximum value DecimalV2Value can represent.
    // The double-based implementation does not have this problem because a double
    // stores m2 in floating-point (scientific notation) form.
    // For comparison, Spark handles decimal overflow by returning null or raising an
    // error, selectable by configuration:
    // https://cloud.tencent.com/developer/news/483615
    m2 += count_before * delta * r;

    ++acc->count;
    mean.to_decimal_val(&acc->mean);
    m2.to_decimal_val(&acc->m2);
}
void AggregateFunctions::knuth_var_merge(FunctionContext* ctx, const StringVal& src,
StringVal* dst) {
DCHECK(!dst->is_null);
@ -2112,6 +2166,33 @@ void AggregateFunctions::knuth_var_merge(FunctionContext* ctx, const StringVal&
dst_state->count = sum_count;
}
// Combines the partial variance state serialized in `src` into the state held by
// `dst` (pairwise merge of mean, m2 and count). An empty `src` leaves `dst` untouched.
void AggregateFunctions::decimalv2_knuth_var_merge(FunctionContext* ctx, const StringVal& src,
                                                   StringVal* dst) {
    // Copy the incoming state into a local object instead of reading it in place
    // through src.ptr.
    DecimalV2KnuthVarianceState incoming;
    memcpy(&incoming, src.ptr, sizeof(DecimalV2KnuthVarianceState));

    DCHECK(!dst->is_null);
    DCHECK_EQ(dst->len, sizeof(DecimalV2KnuthVarianceState));
    DecimalV2KnuthVarianceState* target = reinterpret_cast<DecimalV2KnuthVarianceState*>(dst->ptr);

    if (incoming.count == 0) {
        return;
    }

    // Decimal views of both sides.
    DecimalV2Value src_mean = DecimalV2Value::from_decimal_val(incoming.mean);
    DecimalV2Value dst_mean = DecimalV2Value::from_decimal_val(target->mean);
    DecimalV2Value src_m2 = DecimalV2Value::from_decimal_val(incoming.m2);
    DecimalV2Value dst_m2 = DecimalV2Value::from_decimal_val(target->m2);

    DecimalV2Value src_n = DecimalV2Value();
    src_n.assign_from_double(incoming.count);
    DecimalV2Value dst_n = DecimalV2Value();
    dst_n.assign_from_double(target->count);

    DecimalV2Value delta = dst_mean - src_mean;
    DecimalV2Value total_n = dst_n + src_n;

    // Weighted mean of the two partitions, and m2 corrected by the distance between
    // the two partition means.
    dst_mean = src_mean + delta * (dst_n / total_n);
    dst_m2 = (src_m2) + dst_m2 + (delta * delta) * (src_n * dst_n / total_n);

    target->count += incoming.count;
    dst_mean.to_decimal_val(&target->mean);
    dst_m2.to_decimal_val(&target->m2);
}
DoubleVal AggregateFunctions::knuth_var_finalize(FunctionContext* ctx, const StringVal& state_sv) {
DCHECK(!state_sv.is_null);
KnuthVarianceState* state = reinterpret_cast<KnuthVarianceState*>(state_sv.ptr);
@ -2121,6 +2202,19 @@ DoubleVal AggregateFunctions::knuth_var_finalize(FunctionContext* ctx, const Str
return DoubleVal(variance);
}
// Produces the sample variance (divisor count - 1) from a decimal Knuth state and
// releases the state.
//
// state_sv: StringVal whose ptr is the DecimalV2KnuthVarianceState allocated with
//           `new` in decimalv2_knuth_var_init.
// Returns NULL when fewer than two values were accumulated (count is 0 or 1),
// otherwise the sample variance as a DecimalV2Val.
//
// The state is deleted on every return path. Previously the NULL early return skipped
// the delete and leaked the state.
DecimalV2Val AggregateFunctions::decimalv2_knuth_var_finalize(FunctionContext* ctx,
                                                              const StringVal& state_sv) {
    DCHECK(!state_sv.is_null);
    DCHECK_EQ(state_sv.len, sizeof(DecimalV2KnuthVarianceState));
    DecimalV2KnuthVarianceState* state = reinterpret_cast<DecimalV2KnuthVarianceState*>(state_sv.ptr);

    if (state->count == 0 || state->count == 1) {
        // Sample variance is undefined for fewer than two values.
        delete state;
        return DecimalV2Val::null();
    }

    DecimalV2Value variance = decimalv2_compute_knuth_variance(*state, false);
    DecimalV2Val res;
    variance.to_decimal_val(&res);
    delete state;
    return res;
}
DoubleVal AggregateFunctions::knuth_var_pop_finalize(FunctionContext* ctx,
const StringVal& state_sv) {
DCHECK(!state_sv.is_null);
@ -2132,6 +2226,19 @@ DoubleVal AggregateFunctions::knuth_var_pop_finalize(FunctionContext* ctx,
return DoubleVal(variance);
}
// Produces the population variance (divisor count) from a decimal Knuth state and
// releases the state.
//
// state_sv: StringVal whose ptr is the DecimalV2KnuthVarianceState allocated with
//           `new` in decimalv2_knuth_var_init.
// Returns NULL when no value was accumulated (count is 0), otherwise the population
// variance as a DecimalV2Val.
//
// The state is deleted on every return path. Previously the NULL early return skipped
// the delete and leaked the state.
DecimalV2Val AggregateFunctions::decimalv2_knuth_var_pop_finalize(FunctionContext* ctx,
                                                                  const StringVal& state_sv) {
    DCHECK(!state_sv.is_null);
    DCHECK_EQ(state_sv.len, sizeof(DecimalV2KnuthVarianceState));
    DecimalV2KnuthVarianceState* state = reinterpret_cast<DecimalV2KnuthVarianceState*>(state_sv.ptr);

    if (state->count == 0) {
        // Nothing was aggregated: the result is NULL, but the state must still be freed.
        delete state;
        return DecimalV2Val::null();
    }

    DecimalV2Value variance = decimalv2_compute_knuth_variance(*state, true);
    DecimalV2Val res;
    variance.to_decimal_val(&res);
    delete state;
    return res;
}
DoubleVal AggregateFunctions::knuth_stddev_finalize(FunctionContext* ctx,
const StringVal& state_sv) {
DCHECK(!state_sv.is_null);
@ -2143,6 +2250,20 @@ DoubleVal AggregateFunctions::knuth_stddev_finalize(FunctionContext* ctx,
return DoubleVal(variance);
}
// Produces the sample standard deviation (square root of the sample variance) from a
// decimal Knuth state and releases the state.
//
// state_sv: StringVal whose ptr is the DecimalV2KnuthVarianceState allocated with
//           `new` in decimalv2_knuth_var_init.
// Returns NULL when fewer than two values were accumulated (count is 0 or 1),
// otherwise the sample standard deviation as a DecimalV2Val.
//
// The state is deleted on every return path. Previously the NULL early return skipped
// the delete and leaked the state.
DecimalV2Val AggregateFunctions::decimalv2_knuth_stddev_finalize(FunctionContext* ctx,
                                                                 const StringVal& state_sv) {
    DCHECK(!state_sv.is_null);
    DCHECK_EQ(state_sv.len, sizeof(DecimalV2KnuthVarianceState));
    DecimalV2KnuthVarianceState* state = reinterpret_cast<DecimalV2KnuthVarianceState*>(state_sv.ptr);

    if (state->count == 0 || state->count == 1) {
        // Sample standard deviation is undefined for fewer than two values.
        delete state;
        return DecimalV2Val::null();
    }

    DecimalV2Value variance = decimalv2_compute_knuth_variance(*state, false);
    variance = DecimalV2Value::sqrt(variance);
    DecimalV2Val res;
    variance.to_decimal_val(&res);
    delete state;
    return res;
}
DoubleVal AggregateFunctions::knuth_stddev_pop_finalize(FunctionContext* ctx,
const StringVal& state_sv) {
DCHECK(!state_sv.is_null);
@ -2154,6 +2275,20 @@ DoubleVal AggregateFunctions::knuth_stddev_pop_finalize(FunctionContext* ctx,
return DoubleVal(variance);
}
// Produces the population standard deviation (square root of the population variance)
// from a decimal Knuth state and releases the state.
//
// state_sv: StringVal whose ptr is the DecimalV2KnuthVarianceState allocated with
//           `new` in decimalv2_knuth_var_init.
// Returns NULL when no value was accumulated (count is 0), otherwise the population
// standard deviation as a DecimalV2Val.
//
// The state is deleted on every return path. Previously the NULL early return skipped
// the delete and leaked the state.
DecimalV2Val AggregateFunctions::decimalv2_knuth_stddev_pop_finalize(FunctionContext* ctx,
                                                                     const StringVal& state_sv) {
    DCHECK(!state_sv.is_null);
    DCHECK_EQ(state_sv.len, sizeof(DecimalV2KnuthVarianceState));
    DecimalV2KnuthVarianceState* state = reinterpret_cast<DecimalV2KnuthVarianceState*>(state_sv.ptr);

    if (state->count == 0) {
        // Nothing was aggregated: the result is NULL, but the state must still be freed.
        delete state;
        return DecimalV2Val::null();
    }

    DecimalV2Value variance = decimalv2_compute_knuth_variance(*state, true);
    variance = DecimalV2Value::sqrt(variance);
    DecimalV2Val res;
    variance.to_decimal_val(&res);
    delete state;
    return res;
}
struct RankState {
int64_t rank;
int64_t count;