Files
doris/be/src/exprs/quantile_function.cpp

162 lines
6.5 KiB
C++

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "exprs/quantile_function.h"

#include <memory>

#include "runtime/collection_value.h"
#include "util/quantile_state.h"
#include "util/slice.h"
#include "util/string_parser.hpp"
namespace doris {
using doris_udf::DoubleVal;
using doris_udf::StringVal;
using doris_udf::FloatVal;
// No-op initializer: nothing in this translation unit needs one-time setup.
void QuantileStateFunctions::init() {
    // Intentionally empty.
}
// Aggregate init step: marks `dst` non-NULL and points it at a freshly heap-allocated,
// empty QuantileState<double>. Note that `len` records the size of the in-memory object,
// not a serialized length. The object is presumably freed by quantile_state_serialize,
// which deletes the QuantileState it receives in src.ptr.
void QuantileStateFunctions::quantile_state_init(FunctionContext* ctx, StringVal* dst) {
    dst->is_null = false;
    dst->len = sizeof(QuantileState<double>);
    auto* fresh_state = new QuantileState<double>();
    dst->ptr = reinterpret_cast<uint8_t*>(fresh_state);
}
// Prepare hook for quantile_percent(state, percentile).
//
// Only acts for the FRAGMENT_LOCAL scope. Requires the second argument (the percentile)
// to be a constant; when that constant is non-NULL it must lie in the closed range
// [0, 1]. Violations are reported through ctx->set_error().
void QuantileStateFunctions::quantile_percent_prepare(FunctionContext* ctx,
                                                      FunctionContext::FunctionStateScope scope) {
    if (scope != FunctionContext::FRAGMENT_LOCAL) {
        return;
    }
    if (!ctx->is_arg_constant(1)) {
        std::stringstream ss;
        ss << "quantile_percent function's second arg must be constant.";
        ctx->set_error(ss.str().c_str());
        return;
    }
    const AnyVal* percentile_arg = ctx->get_constant_arg(1);
    // A constant NULL percentile has no value to range-check; reading `val` from it
    // would validate meaningless data.
    if (percentile_arg == nullptr || percentile_arg->is_null) {
        return;
    }
    float percentile_value = reinterpret_cast<const FloatVal*>(percentile_arg)->val;
    // Negated in-range test: every comparison with NaN is false, so the original
    // `> 1 || < 0` form let NaN through; this form rejects it.
    if (!(percentile_value >= 0 && percentile_value <= 1)) {
        std::stringstream error_msg;
        error_msg << "The percentile must between 0 and 1, but input is:"
                  << std::to_string(percentile_value);
        ctx->set_error(error_msg.str().c_str());
        return;
    }
}
// Prepare hook for to_quantile_state(value, compression).
//
// Only acts for the FRAGMENT_LOCAL scope. A non-constant (or constant NULL) second
// argument is accepted silently; to_quantile_state then keeps its default compression.
// A non-NULL constant compression must lie in
// [QUANTILE_STATE_COMPRESSION_MIN, QUANTILE_STATE_COMPRESSION_MAX]; otherwise an error
// is reported through ctx->set_error().
void QuantileStateFunctions::to_quantile_state_prepare(FunctionContext* ctx,
                                                       FunctionContext::FunctionStateScope scope) {
    if (scope != FunctionContext::FRAGMENT_LOCAL) {
        return;
    }
    if (!ctx->is_arg_constant(1)) {
        // No constant compression supplied: the default is used, nothing to validate.
        return;
    }
    const AnyVal* compression_arg = ctx->get_constant_arg(1);
    // A constant NULL carries no usable value, so there is nothing to range-check.
    if (compression_arg == nullptr || compression_arg->is_null) {
        return;
    }
    float compression = reinterpret_cast<const FloatVal*>(compression_arg)->val;
    // Negated in-range test so that NaN (all comparisons false) is rejected too.
    if (!(compression >= QUANTILE_STATE_COMPRESSION_MIN &&
          compression <= QUANTILE_STATE_COMPRESSION_MAX)) {
        std::stringstream error_msg;
        // '\n' yields the same text as std::endl; flushing a stringstream is pointless.
        error_msg << "The compression of to_quantile_state must between "
                  << QUANTILE_STATE_COMPRESSION_MIN << " and " << QUANTILE_STATE_COMPRESSION_MAX
                  << '\n'
                  << "but input is:" << std::to_string(compression);
        ctx->set_error(error_msg.str().c_str());
        return;
    }
}
// Encodes `value` into a StringVal whose buffer is allocated through `ctx`, sized by
// value->get_serialized_size(). The caller keeps ownership of `value`.
static StringVal serialize(FunctionContext* ctx, QuantileState<double>* value) {
    const auto encoded_size = value->get_serialized_size();
    StringVal encoded(ctx, encoded_size);
    value->serialize(encoded.ptr);
    return encoded;
}
// Scalar to_quantile_state(value [, compression]).
//
// Parses `src` as a double, builds a QuantileState<double> containing that one value,
// and returns the state serialized. A NULL `src` yields a serialized empty state.
// Compression defaults to QUANTILE_STATE_COMPRESSION_MIN and is overridden by a non-NULL
// constant second argument. On a parse failure an error is set on `ctx` and NULL is
// returned.
StringVal QuantileStateFunctions::to_quantile_state(FunctionContext* ctx, const StringVal& src) {
    QuantileState<double> quantile_state;
    quantile_state.set_compression(QUANTILE_STATE_COMPRESSION_MIN);
    const AnyVal* digest_compression = ctx->get_constant_arg(1);
    // A constant NULL carries no usable value; keep the default compression for it.
    if (digest_compression != nullptr && !digest_compression->is_null) {
        // Range [QUANTILE_STATE_COMPRESSION_MIN, QUANTILE_STATE_COMPRESSION_MAX] is
        // enforced by `to_quantile_state_prepare`.
        float compression = reinterpret_cast<const FloatVal*>(digest_compression)->val;
        quantile_state.set_compression(compression);
    }
    if (!src.is_null) {
        StringParser::ParseResult parse_result = StringParser::PARSE_SUCCESS;
        double double_value = StringParser::string_to_float<double>(
                reinterpret_cast<char*>(src.ptr), src.len, &parse_result);
        if (UNLIKELY(parse_result != StringParser::PARSE_SUCCESS)) {
            std::stringstream error_msg;
            // NOTE(review): the message mentions bigint values, but the input is parsed
            // as a double; text kept unchanged to avoid altering user-visible output.
            error_msg << "The input: " << std::string(reinterpret_cast<char*>(src.ptr), src.len)
                      << " is not valid, to_quantile_state only support bigint value from 0 to "
                         "18446744073709551615 currently";
            ctx->set_error(error_msg.str().c_str());
            return StringVal::null();
        }
        quantile_state.add_value(double_value);
    }
    return serialize(ctx, &quantile_state);
}
// Merge step: folds the quantile state in `src` into the QuantileState<double> held in
// dst->ptr. NULL inputs are ignored. A zero-length `src` is interpreted as a pointer to
// an in-memory QuantileState<double>; any other `src` is deserialized from its bytes.
void QuantileStateFunctions::quantile_union(FunctionContext* ctx, const StringVal& src,
                                            StringVal* dst) {
    if (src.is_null) {
        return;
    }
    auto* target = reinterpret_cast<QuantileState<double>*>(dst->ptr);
    if (src.len != 0) {
        QuantileState<double> decoded(Slice(src.ptr, src.len));
        target->merge(decoded);
        return;
    }
    target->merge(*reinterpret_cast<QuantileState<double>*>(src.ptr));
}
// quantile_percent(state, percentile): returns the value at `percentile` from the
// quantile state in `src`. The percentile must be a constant; its [0, 1] range is
// checked by `quantile_percent_prepare`.
//
// A zero-length `src` is interpreted as a pointer to an in-memory QuantileState<double>;
// any other `src` is deserialized. Returns NULL when `src` or the percentile is NULL.
// Sets an error and returns NULL when the percentile is not a constant.
DoubleVal QuantileStateFunctions::quantile_percent(FunctionContext* ctx, StringVal& src) {
    const AnyVal* percentile = ctx->get_constant_arg(1);
    if (percentile == nullptr) {
        std::stringstream error_msg;
        error_msg << "quantile_percent function's second argument must be constant. eg: "
                     "quantile_percent(col, 0.95)";
        ctx->set_error(error_msg.str().c_str());
        return DoubleVal::null();
    }
    // A NULL `src` must not be interpreted. With len == 0 it would otherwise be treated
    // as a QuantileState pointer below and dereferenced. A NULL percentile has no value.
    if (src.is_null || percentile->is_null) {
        return DoubleVal::null();
    }
    float percentile_value = reinterpret_cast<const FloatVal*>(percentile)->val;
    if (src.len == 0) {
        auto* quantile_state = reinterpret_cast<QuantileState<double>*>(src.ptr);
        return {static_cast<double>(quantile_state->get_value_by_percentile(percentile_value))};
    }
    QuantileState<double> quantile_state(Slice(src.ptr, src.len));
    return {static_cast<double>(quantile_state.get_value_by_percentile(percentile_value))};
}
// Serialize step of the aggregate. Encodes the in-memory QuantileState<double> pointed to
// by src.ptr (presumably the object allocated by quantile_state_init) into a new
// StringVal, and frees that in-memory object. A NULL `src` is returned unchanged.
StringVal QuantileStateFunctions::quantile_state_serialize(FunctionContext* ctx,
                                                           const StringVal& src) {
    if (src.is_null) {
        return src;
    }
    // Take ownership so the heap object is released on every exit path, including when
    // allocation or serialization throws (the previous manual delete leaked then).
    std::unique_ptr<QuantileState<double>> state(
            reinterpret_cast<QuantileState<double>*>(src.ptr));
    return serialize(ctx, state.get());
}
} // namespace doris