[Bug](fix) Fix the percentile function result not matching the percentile_array rewrite result (#49379)

Cherry-picked from https://github.com/apache/doris/pull/49351
This commit is contained in:
HappenLee
2025-03-29 08:56:24 +08:00
committed by GitHub
parent 8f15e62de5
commit 4a31fc4e09
13 changed files with 126 additions and 108 deletions

View File

@ -26,6 +26,7 @@
namespace doris {
template <typename T>
class Counts {
public:
Counts() = default;
@ -40,7 +41,7 @@ public:
}
}
void increment(int64_t key, uint32_t i) {
void increment(T key, uint32_t i) {
auto item = _counts.find(key);
if (item != _counts.end()) {
item->second += i;
@ -50,8 +51,7 @@ public:
}
uint32_t serialized_size() const {
return sizeof(uint32_t) + sizeof(int64_t) * _counts.size() +
sizeof(uint32_t) * _counts.size();
return sizeof(uint32_t) + sizeof(T) * _counts.size() + sizeof(uint32_t) * _counts.size();
}
void serialize(uint8_t* writer) const {
@ -59,8 +59,8 @@ public:
memcpy(writer, &size, sizeof(uint32_t));
writer += sizeof(uint32_t);
for (auto& cell : _counts) {
memcpy(writer, &cell.first, sizeof(int64_t));
writer += sizeof(int64_t);
memcpy(writer, &cell.first, sizeof(T));
writer += sizeof(T);
memcpy(writer, &cell.second, sizeof(uint32_t));
writer += sizeof(uint32_t);
}
@ -71,18 +71,17 @@ public:
memcpy(&size, type_reader, sizeof(uint32_t));
type_reader += sizeof(uint32_t);
for (uint32_t i = 0; i < size; ++i) {
int64_t key;
T key;
uint32_t count;
memcpy(&key, type_reader, sizeof(int64_t));
type_reader += sizeof(int64_t);
memcpy(&key, type_reader, sizeof(T));
type_reader += sizeof(T);
memcpy(&count, type_reader, sizeof(uint32_t));
type_reader += sizeof(uint32_t);
_counts.emplace(std::make_pair(key, count));
}
}
double get_percentile(std::vector<std::pair<int64_t, uint32_t>>& counts,
double position) const {
double get_percentile(std::vector<std::pair<T, uint32_t>>& counts, double position) const {
long lower = long(std::floor(position));
long higher = long(std::ceil(position));
@ -90,7 +89,7 @@ public:
for (; iter != counts.end() && iter->second < lower + 1; ++iter)
;
int64_t lower_key = iter->first;
T lower_key = iter->first;
if (higher == lower) {
return lower_key;
}
@ -99,7 +98,7 @@ public:
iter++;
}
int64_t higher_key = iter->first;
T higher_key = iter->first;
if (lower_key == higher_key) {
return lower_key;
}
@ -114,9 +113,9 @@ public:
return 0.0;
}
std::vector<std::pair<int64_t, uint32_t>> elems(_counts.begin(), _counts.end());
std::vector<std::pair<T, uint32_t>> elems(_counts.begin(), _counts.end());
sort(elems.begin(), elems.end(),
[](const std::pair<int64_t, uint32_t> l, const std::pair<int64_t, uint32_t> r) {
[](const std::pair<T, uint32_t> l, const std::pair<T, uint32_t> r) {
return l.first < r.first;
});
@ -132,7 +131,7 @@ public:
}
private:
std::unordered_map<int64_t, uint32_t> _counts;
std::unordered_map<T, uint32_t> _counts;
};
} // namespace doris

View File

@ -50,9 +50,10 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx(
void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& factory) {
factory.register_function_both("percentile",
creator_without_type::creator<AggregateFunctionPercentile>);
factory.register_function_both("percentile_array",
creator_without_type::creator<AggregateFunctionPercentileArray>);
creator_with_numeric_type::creator<AggregateFunctionPercentile>);
factory.register_function_both(
"percentile_array",
creator_with_numeric_type::creator<AggregateFunctionPercentileArray>);
}
void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& factory) {

View File

@ -288,8 +288,9 @@ public:
}
};
template <typename T>
struct PercentileState {
std::vector<Counts> vec_counts;
std::vector<Counts<T>> vec_counts;
std::vector<double> vec_quantile {-1};
bool inited_flag = false;
@ -327,7 +328,7 @@ struct PercentileState {
}
}
void add(int64_t source, const PaddedPODArray<Float64>& quantiles, int arg_size) {
void add(T source, const PaddedPODArray<Float64>& quantiles, int arg_size) {
if (!inited_flag) {
vec_counts.resize(arg_size);
vec_quantile.resize(arg_size, -1);
@ -376,11 +377,12 @@ struct PercentileState {
}
};
template <typename T>
class AggregateFunctionPercentile final
: public IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentile> {
: public IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentile<T>> {
public:
AggregateFunctionPercentile(const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentile>(
: IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentile<T>>(
argument_types_) {}
String get_name() const override { return "percentile"; }
@ -389,10 +391,10 @@ public:
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
const auto& sources = assert_cast<const ColumnVector<Int64>&>(*columns[0]);
const auto& sources = assert_cast<const ColumnVector<T>&>(*columns[0]);
const auto& quantile = assert_cast<const ColumnVector<Float64>&>(*columns[1]);
AggregateFunctionPercentile::data(place).add(sources.get_int(row_num), quantile.get_data(),
1);
AggregateFunctionPercentile::data(place).add(sources.get_element(row_num),
quantile.get_data(), 1);
}
void reset(AggregateDataPtr __restrict place) const override {
@ -419,11 +421,13 @@ public:
}
};
template <typename T>
class AggregateFunctionPercentileArray final
: public IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentileArray> {
: public IAggregateFunctionDataHelper<PercentileState<T>,
AggregateFunctionPercentileArray<T>> {
public:
AggregateFunctionPercentileArray(const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentileArray>(
: IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentileArray<T>>(
argument_types_) {}
String get_name() const override { return "percentile_array"; }
@ -434,7 +438,7 @@ public:
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
const auto& sources = assert_cast<const ColumnVector<Int64>&>(*columns[0]);
const auto& sources = assert_cast<const ColumnVector<T>&>(*columns[0]);
const auto& quantile_array = assert_cast<const ColumnArray&>(*columns[1]);
const auto& offset_column_data = quantile_array.get_offsets();
const auto& nested_column =
@ -442,7 +446,7 @@ public:
const auto& nested_column_data = assert_cast<const ColumnVector<Float64>&>(nested_column);
AggregateFunctionPercentileArray::data(place).add(
sources.get_int(row_num), nested_column_data.get_data(),
sources.get_element(row_num), nested_column_data.get_data(),
offset_column_data.data()[row_num] - offset_column_data[(ssize_t)row_num - 1]);
}
@ -480,4 +484,4 @@ public:
}
};
} // namespace doris::vectorized
} // namespace doris::vectorized

View File

@ -27,7 +27,7 @@ namespace doris {
class TCountsTest : public testing::Test {};
TEST_F(TCountsTest, TotalTest) {
Counts counts;
Counts<int64_t> counts;
// 1 1 1 2 5 7 7 9 9 19
// >>> import numpy as np
// >>> a = np.array([1,1,1,2,5,7,7,9,9,19])
@ -46,12 +46,12 @@ TEST_F(TCountsTest, TotalTest) {
uint8_t* type_reader = writer;
counts.serialize(writer);
Counts other;
Counts<int64_t> other;
other.unserialize(type_reader);
double result1 = other.terminate(0.2);
EXPECT_EQ(result, result1);
Counts other1;
Counts<int64_t> other1;
other1.increment(1, 1);
other1.increment(100, 3);
other1.increment(50, 3);

View File

@ -131,7 +131,7 @@ public class BuiltinAggregateFunctions implements FunctionHelper {
agg(OrthogonalBitmapIntersect.class, "orthogonal_bitmap_intersect"),
agg(OrthogonalBitmapIntersectCount.class, "orthogonal_bitmap_intersect_count"),
agg(OrthogonalBitmapUnionCount.class, "orthogonal_bitmap_union_count"),
agg(Percentile.class, "percentile"),
agg(Percentile.class, "percentile", "percentile_cont"),
agg(PercentileApprox.class, "percentile_approx"),
agg(PercentileArray.class, "percentile_array"),
agg(QuantileUnion.class, "quantile_union"),

View File

@ -1430,6 +1430,15 @@ public class FunctionSet<T> {
"",
false, true, false, true));
addBuiltin(AggregateFunction.createBuiltin("percentile_cont",
Lists.newArrayList(Type.BIGINT, Type.DOUBLE), Type.DOUBLE, Type.VARCHAR,
"",
"",
"",
"",
"",
false, true, false, true));
addBuiltin(AggregateFunction.createBuiltin("percentile_approx",
Lists.<Type>newArrayList(Type.DOUBLE, Type.DOUBLE), Type.DOUBLE, Type.VARCHAR,
"",

View File

@ -25,6 +25,11 @@ import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.LargeIntType;
import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.TinyIntType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -38,7 +43,13 @@ public class Percentile extends NullableAggregateFunction
implements BinaryExpression, ExplicitlyCastableSignature {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, DoubleType.INSTANCE)
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, DoubleType.INSTANCE)
);
/**

View File

@ -26,6 +26,11 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.LargeIntType;
import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.TinyIntType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -40,8 +45,19 @@ public class PercentileArray extends AggregateFunction
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
.args(BigIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE))
);
.args(DoubleType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)),
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
.args(FloatType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)),
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
.args(LargeIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)),
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
.args(BigIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)),
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
.args(IntegerType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)),
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
.args(SmallIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)),
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
.args(TinyIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)));
/**
* constructor with 2 arguments.

View File

@ -182,6 +182,7 @@ public class GenerateFunction {
.put("any", "any_value")
.put("char_length", "character_length")
.put("stddev_pop", "stddev")
.put("percentile_cont", "percentile")
.put("var_pop", "variance")
.put("variance_pop", "variance")
.put("var_samp", "variance_samp")

View File

@ -384,11 +384,11 @@ sichuan [{"cbe":{},"notnull":0,"null":1,"all":1}]
2 [123456789, 223456789, 323456789]
2 [123456789, 223456789, 323456789]
2 [123456789, 223456789, 323456789]
3 [223456789, 223456789, 323456789]
3 [223456789, 223456789, 323456789]
3 [223456789, 223456789, 323456789]
3 [223456789, 223456789, 323456789]
3 [223456789, 223456789, 323456789]
3 [223456789.6, 223456789.6, 323456789.1]
3 [223456789.6, 223456789.6, 323456789.1]
3 [223456789.6, 223456789.6, 323456789.1]
3 [223456789.6, 223456789.6, 323456789.1]
3 [223456789.6, 223456789.6, 323456789.1]
-- !agg_window_percentile_approx --
1 5.234568E8

View File

@ -697,3 +697,6 @@ TESTING AGAIN
-- !aggregate_limit_contain_null --
16 \N
-- !aggregate35 --
12.25 32.625 43.75 54.999999999999986 77.49999999999994 \N

View File

@ -1,66 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
import org.codehaus.groovy.runtime.IOGroovyMethods
// Regression suite "mv_percentile": builds a materialized view that stores
// percentile(k3, 0.1) and percentile(k3, 0.9) grouped by (k1, k2), then checks
// that equivalent queries (plain GROUP BY and GROUPING SETS) report "(kp)" in
// their EXPLAIN output and records their results.
suite ("mv_percentile") {
// Disable fallback to the original planner for this session
// (presumably so the queries below must be planned by the new optimizer — confirm).
sql "set enable_fallback_to_original_planner = false"
// Start from a clean slate so reruns do not see a stale table.
sql """DROP TABLE IF EXISTS d_table;"""
// Base table: k3 is a nullable decimal(28,6) and is the column fed to percentile().
sql """
create table d_table(
k1 int null,
k2 int not null,
k3 decimal(28,6) null,
k4 varchar(100) null
)
duplicate key (k1,k2,k3)
distributed BY hash(k1) buckets 3
properties("replication_num" = "1");
"""
// Rows loaded BEFORE the materialized view exists; the third row has a NULL k3.
sql "insert into d_table select 1,1,1,'a';"
sql "insert into d_table select 2,2,2,'b';"
sql "insert into d_table select 3,-3,null,'c';"
// Materialized view "kp" holding the two percentile aggregates per (k1, k2).
createMV("create materialized view kp as select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table group by k1,k2;")
// Rows loaded AFTER the materialized view exists; the second insert only
// supplies k4 and k2, leaving the nullable k1 and k3 unspecified.
sql "insert into d_table select -4,-4,-4,'d';"
sql "insert into d_table(k4,k2) values('d',4);"
// Record the full base-table contents.
qt_select_star "select * from d_table order by k1;"
// Case 1: same grouping as the view definition; EXPLAIN output must contain "(kp)".
explain {
sql("select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table group by k1,k2 order by k1,k2;")
contains "(kp)"
}
qt_select_mv "select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table group by k1,k2 order by k1,k2;"
// Case 2: GROUPING SETS over (k1), (k1,k2) and (); EXPLAIN output must contain "(kp)".
explain {
sql("select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table group by grouping sets((k1),(k1,k2),()) order by 1,2;")
contains "(kp)"
}
// NOTE(review): the "select_mv" tag is reused for three different queries in
// this suite — confirm the expected-output file is meant to hold three entries
// under the same tag.
qt_select_mv "select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table group by grouping sets((k1),(k1,k2),()) order by 1,2,3;"
// Case 3: a single percentile over GROUPING SETS (k1) and (); EXPLAIN output must contain "(kp)".
explain {
sql("select percentile(k3, 0.1) from d_table group by grouping sets((k1),()) order by 1;")
contains "(kp)"
}
qt_select_mv "select percentile(k3, 0.1) from d_table group by grouping sets((k1),()) order by 1;"
}

View File

@ -308,4 +308,44 @@ suite("aggregate") {
qt_aggregate_limit_contain_null """
select count(), cast(k12 as int) as t from baseall group by t limit 1;
"""
// Test case for percentile function with sales data
sql """ DROP TABLE IF EXISTS sales_data """
sql """
CREATE TABLE sales_data (
product_id INT,
sale_price DECIMAL(10, 2)
) DUPLICATE KEY(`product_id`)
DISTRIBUTED BY HASH(`product_id`) BUCKETS 1
PROPERTIES (
"replication_allocation" = "tag.location.default: 1"
)
"""
sql """
INSERT INTO sales_data VALUES
(1, 10.00),
(1, 15.00),
(1, 20.00),
(1, 25.00),
(1, 30.25),
(1, 35.00),
(1, 40.00),
(1, 45.00),
(1, 50.00),
(1, 100.00)
"""
qt_aggregate35 """
SELECT
percentile(sale_price, 0.05) as median_price_05,
percentile(sale_price, 0.5) as median_price,
percentile(sale_price, 0.75) as p75_price,
percentile(sale_price, 0.90) as p90_price,
percentile(sale_price, 0.95) as p95_price,
percentile(null, 0.99) as p99_null
FROM sales_data
"""
sql """ DROP TABLE IF EXISTS sales_data """
}