[Bug](fix) Fix the percentile function result not matching the percentile array rewrite result (#49379)
Cherry-picked from https://github.com/apache/doris/pull/49351
This commit is contained in:
@ -26,6 +26,7 @@
|
||||
|
||||
namespace doris {
|
||||
|
||||
template <typename T>
|
||||
class Counts {
|
||||
public:
|
||||
Counts() = default;
|
||||
@ -40,7 +41,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void increment(int64_t key, uint32_t i) {
|
||||
void increment(T key, uint32_t i) {
|
||||
auto item = _counts.find(key);
|
||||
if (item != _counts.end()) {
|
||||
item->second += i;
|
||||
@ -50,8 +51,7 @@ public:
|
||||
}
|
||||
|
||||
uint32_t serialized_size() const {
|
||||
return sizeof(uint32_t) + sizeof(int64_t) * _counts.size() +
|
||||
sizeof(uint32_t) * _counts.size();
|
||||
return sizeof(uint32_t) + sizeof(T) * _counts.size() + sizeof(uint32_t) * _counts.size();
|
||||
}
|
||||
|
||||
void serialize(uint8_t* writer) const {
|
||||
@ -59,8 +59,8 @@ public:
|
||||
memcpy(writer, &size, sizeof(uint32_t));
|
||||
writer += sizeof(uint32_t);
|
||||
for (auto& cell : _counts) {
|
||||
memcpy(writer, &cell.first, sizeof(int64_t));
|
||||
writer += sizeof(int64_t);
|
||||
memcpy(writer, &cell.first, sizeof(T));
|
||||
writer += sizeof(T);
|
||||
memcpy(writer, &cell.second, sizeof(uint32_t));
|
||||
writer += sizeof(uint32_t);
|
||||
}
|
||||
@ -71,18 +71,17 @@ public:
|
||||
memcpy(&size, type_reader, sizeof(uint32_t));
|
||||
type_reader += sizeof(uint32_t);
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
int64_t key;
|
||||
T key;
|
||||
uint32_t count;
|
||||
memcpy(&key, type_reader, sizeof(int64_t));
|
||||
type_reader += sizeof(int64_t);
|
||||
memcpy(&key, type_reader, sizeof(T));
|
||||
type_reader += sizeof(T);
|
||||
memcpy(&count, type_reader, sizeof(uint32_t));
|
||||
type_reader += sizeof(uint32_t);
|
||||
_counts.emplace(std::make_pair(key, count));
|
||||
}
|
||||
}
|
||||
|
||||
double get_percentile(std::vector<std::pair<int64_t, uint32_t>>& counts,
|
||||
double position) const {
|
||||
double get_percentile(std::vector<std::pair<T, uint32_t>>& counts, double position) const {
|
||||
long lower = long(std::floor(position));
|
||||
long higher = long(std::ceil(position));
|
||||
|
||||
@ -90,7 +89,7 @@ public:
|
||||
for (; iter != counts.end() && iter->second < lower + 1; ++iter)
|
||||
;
|
||||
|
||||
int64_t lower_key = iter->first;
|
||||
T lower_key = iter->first;
|
||||
if (higher == lower) {
|
||||
return lower_key;
|
||||
}
|
||||
@ -99,7 +98,7 @@ public:
|
||||
iter++;
|
||||
}
|
||||
|
||||
int64_t higher_key = iter->first;
|
||||
T higher_key = iter->first;
|
||||
if (lower_key == higher_key) {
|
||||
return lower_key;
|
||||
}
|
||||
@ -114,9 +113,9 @@ public:
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
std::vector<std::pair<int64_t, uint32_t>> elems(_counts.begin(), _counts.end());
|
||||
std::vector<std::pair<T, uint32_t>> elems(_counts.begin(), _counts.end());
|
||||
sort(elems.begin(), elems.end(),
|
||||
[](const std::pair<int64_t, uint32_t> l, const std::pair<int64_t, uint32_t> r) {
|
||||
[](const std::pair<T, uint32_t> l, const std::pair<T, uint32_t> r) {
|
||||
return l.first < r.first;
|
||||
});
|
||||
|
||||
@ -132,7 +131,7 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<int64_t, uint32_t> _counts;
|
||||
std::unordered_map<T, uint32_t> _counts;
|
||||
};
|
||||
|
||||
} // namespace doris
|
||||
|
||||
@ -50,9 +50,10 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx(
|
||||
|
||||
void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& factory) {
|
||||
factory.register_function_both("percentile",
|
||||
creator_without_type::creator<AggregateFunctionPercentile>);
|
||||
factory.register_function_both("percentile_array",
|
||||
creator_without_type::creator<AggregateFunctionPercentileArray>);
|
||||
creator_with_numeric_type::creator<AggregateFunctionPercentile>);
|
||||
factory.register_function_both(
|
||||
"percentile_array",
|
||||
creator_with_numeric_type::creator<AggregateFunctionPercentileArray>);
|
||||
}
|
||||
|
||||
void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& factory) {
|
||||
|
||||
@ -288,8 +288,9 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct PercentileState {
|
||||
std::vector<Counts> vec_counts;
|
||||
std::vector<Counts<T>> vec_counts;
|
||||
std::vector<double> vec_quantile {-1};
|
||||
bool inited_flag = false;
|
||||
|
||||
@ -327,7 +328,7 @@ struct PercentileState {
|
||||
}
|
||||
}
|
||||
|
||||
void add(int64_t source, const PaddedPODArray<Float64>& quantiles, int arg_size) {
|
||||
void add(T source, const PaddedPODArray<Float64>& quantiles, int arg_size) {
|
||||
if (!inited_flag) {
|
||||
vec_counts.resize(arg_size);
|
||||
vec_quantile.resize(arg_size, -1);
|
||||
@ -376,11 +377,12 @@ struct PercentileState {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class AggregateFunctionPercentile final
|
||||
: public IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentile> {
|
||||
: public IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentile<T>> {
|
||||
public:
|
||||
AggregateFunctionPercentile(const DataTypes& argument_types_)
|
||||
: IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentile>(
|
||||
: IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentile<T>>(
|
||||
argument_types_) {}
|
||||
|
||||
String get_name() const override { return "percentile"; }
|
||||
@ -389,10 +391,10 @@ public:
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
|
||||
Arena*) const override {
|
||||
const auto& sources = assert_cast<const ColumnVector<Int64>&>(*columns[0]);
|
||||
const auto& sources = assert_cast<const ColumnVector<T>&>(*columns[0]);
|
||||
const auto& quantile = assert_cast<const ColumnVector<Float64>&>(*columns[1]);
|
||||
AggregateFunctionPercentile::data(place).add(sources.get_int(row_num), quantile.get_data(),
|
||||
1);
|
||||
AggregateFunctionPercentile::data(place).add(sources.get_element(row_num),
|
||||
quantile.get_data(), 1);
|
||||
}
|
||||
|
||||
void reset(AggregateDataPtr __restrict place) const override {
|
||||
@ -419,11 +421,13 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class AggregateFunctionPercentileArray final
|
||||
: public IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentileArray> {
|
||||
: public IAggregateFunctionDataHelper<PercentileState<T>,
|
||||
AggregateFunctionPercentileArray<T>> {
|
||||
public:
|
||||
AggregateFunctionPercentileArray(const DataTypes& argument_types_)
|
||||
: IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentileArray>(
|
||||
: IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentileArray<T>>(
|
||||
argument_types_) {}
|
||||
|
||||
String get_name() const override { return "percentile_array"; }
|
||||
@ -434,7 +438,7 @@ public:
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
|
||||
Arena*) const override {
|
||||
const auto& sources = assert_cast<const ColumnVector<Int64>&>(*columns[0]);
|
||||
const auto& sources = assert_cast<const ColumnVector<T>&>(*columns[0]);
|
||||
const auto& quantile_array = assert_cast<const ColumnArray&>(*columns[1]);
|
||||
const auto& offset_column_data = quantile_array.get_offsets();
|
||||
const auto& nested_column =
|
||||
@ -442,7 +446,7 @@ public:
|
||||
const auto& nested_column_data = assert_cast<const ColumnVector<Float64>&>(nested_column);
|
||||
|
||||
AggregateFunctionPercentileArray::data(place).add(
|
||||
sources.get_int(row_num), nested_column_data.get_data(),
|
||||
sources.get_element(row_num), nested_column_data.get_data(),
|
||||
offset_column_data.data()[row_num] - offset_column_data[(ssize_t)row_num - 1]);
|
||||
}
|
||||
|
||||
@ -480,4 +484,4 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace doris::vectorized
|
||||
} // namespace doris::vectorized
|
||||
|
||||
@ -27,7 +27,7 @@ namespace doris {
|
||||
class TCountsTest : public testing::Test {};
|
||||
|
||||
TEST_F(TCountsTest, TotalTest) {
|
||||
Counts counts;
|
||||
Counts<int64_t> counts;
|
||||
// 1 1 1 2 5 7 7 9 9 19
|
||||
// >>> import numpy as np
|
||||
// >>> a = np.array([1,1,1,2,5,7,7,9,9,19])
|
||||
@ -46,12 +46,12 @@ TEST_F(TCountsTest, TotalTest) {
|
||||
uint8_t* type_reader = writer;
|
||||
counts.serialize(writer);
|
||||
|
||||
Counts other;
|
||||
Counts<int64_t> other;
|
||||
other.unserialize(type_reader);
|
||||
double result1 = other.terminate(0.2);
|
||||
EXPECT_EQ(result, result1);
|
||||
|
||||
Counts other1;
|
||||
Counts<int64_t> other1;
|
||||
other1.increment(1, 1);
|
||||
other1.increment(100, 3);
|
||||
other1.increment(50, 3);
|
||||
|
||||
@ -131,7 +131,7 @@ public class BuiltinAggregateFunctions implements FunctionHelper {
|
||||
agg(OrthogonalBitmapIntersect.class, "orthogonal_bitmap_intersect"),
|
||||
agg(OrthogonalBitmapIntersectCount.class, "orthogonal_bitmap_intersect_count"),
|
||||
agg(OrthogonalBitmapUnionCount.class, "orthogonal_bitmap_union_count"),
|
||||
agg(Percentile.class, "percentile"),
|
||||
agg(Percentile.class, "percentile", "percentile_cont"),
|
||||
agg(PercentileApprox.class, "percentile_approx"),
|
||||
agg(PercentileArray.class, "percentile_array"),
|
||||
agg(QuantileUnion.class, "quantile_union"),
|
||||
|
||||
@ -1430,6 +1430,15 @@ public class FunctionSet<T> {
|
||||
"",
|
||||
false, true, false, true));
|
||||
|
||||
addBuiltin(AggregateFunction.createBuiltin("percentile_cont",
|
||||
Lists.newArrayList(Type.BIGINT, Type.DOUBLE), Type.DOUBLE, Type.VARCHAR,
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
false, true, false, true));
|
||||
|
||||
addBuiltin(AggregateFunction.createBuiltin("percentile_approx",
|
||||
Lists.<Type>newArrayList(Type.DOUBLE, Type.DOUBLE), Type.DOUBLE, Type.VARCHAR,
|
||||
"",
|
||||
|
||||
@ -25,6 +25,11 @@ import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.DoubleType;
|
||||
import org.apache.doris.nereids.types.FloatType;
|
||||
import org.apache.doris.nereids.types.IntegerType;
|
||||
import org.apache.doris.nereids.types.LargeIntType;
|
||||
import org.apache.doris.nereids.types.SmallIntType;
|
||||
import org.apache.doris.nereids.types.TinyIntType;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -38,7 +43,13 @@ public class Percentile extends NullableAggregateFunction
|
||||
implements BinaryExpression, ExplicitlyCastableSignature {
|
||||
|
||||
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
|
||||
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, DoubleType.INSTANCE)
|
||||
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE),
|
||||
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, DoubleType.INSTANCE),
|
||||
FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE, DoubleType.INSTANCE),
|
||||
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, DoubleType.INSTANCE),
|
||||
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, DoubleType.INSTANCE),
|
||||
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, DoubleType.INSTANCE),
|
||||
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, DoubleType.INSTANCE)
|
||||
);
|
||||
|
||||
/**
|
||||
|
||||
@ -26,6 +26,11 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.ArrayType;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.DoubleType;
|
||||
import org.apache.doris.nereids.types.FloatType;
|
||||
import org.apache.doris.nereids.types.IntegerType;
|
||||
import org.apache.doris.nereids.types.LargeIntType;
|
||||
import org.apache.doris.nereids.types.SmallIntType;
|
||||
import org.apache.doris.nereids.types.TinyIntType;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -40,8 +45,19 @@ public class PercentileArray extends AggregateFunction
|
||||
|
||||
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
|
||||
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
|
||||
.args(BigIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE))
|
||||
);
|
||||
.args(DoubleType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)),
|
||||
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
|
||||
.args(FloatType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)),
|
||||
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
|
||||
.args(LargeIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)),
|
||||
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
|
||||
.args(BigIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)),
|
||||
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
|
||||
.args(IntegerType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)),
|
||||
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
|
||||
.args(SmallIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)),
|
||||
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
|
||||
.args(TinyIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)));
|
||||
|
||||
/**
|
||||
* constructor with 2 arguments.
|
||||
|
||||
@ -182,6 +182,7 @@ public class GenerateFunction {
|
||||
.put("any", "any_value")
|
||||
.put("char_length", "character_length")
|
||||
.put("stddev_pop", "stddev")
|
||||
.put("percentile_cont", "percentile")
|
||||
.put("var_pop", "variance")
|
||||
.put("variance_pop", "variance")
|
||||
.put("var_samp", "variance_samp")
|
||||
|
||||
@ -384,11 +384,11 @@ sichuan [{"cbe":{},"notnull":0,"null":1,"all":1}]
|
||||
2 [123456789, 223456789, 323456789]
|
||||
2 [123456789, 223456789, 323456789]
|
||||
2 [123456789, 223456789, 323456789]
|
||||
3 [223456789, 223456789, 323456789]
|
||||
3 [223456789, 223456789, 323456789]
|
||||
3 [223456789, 223456789, 323456789]
|
||||
3 [223456789, 223456789, 323456789]
|
||||
3 [223456789, 223456789, 323456789]
|
||||
3 [223456789.6, 223456789.6, 323456789.1]
|
||||
3 [223456789.6, 223456789.6, 323456789.1]
|
||||
3 [223456789.6, 223456789.6, 323456789.1]
|
||||
3 [223456789.6, 223456789.6, 323456789.1]
|
||||
3 [223456789.6, 223456789.6, 323456789.1]
|
||||
|
||||
-- !agg_window_percentile_approx --
|
||||
1 5.234568E8
|
||||
|
||||
@ -697,3 +697,6 @@ TESTING AGAIN
|
||||
|
||||
-- !aggregate_limit_contain_null --
|
||||
16 \N
|
||||
|
||||
-- !aggregate35 --
|
||||
12.25 32.625 43.75 54.999999999999986 77.49999999999994 \N
|
||||
|
||||
@ -1,66 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
import org.codehaus.groovy.runtime.IOGroovyMethods
|
||||
|
||||
suite ("mv_percentile") {
|
||||
sql "set enable_fallback_to_original_planner = false"
|
||||
|
||||
sql """DROP TABLE IF EXISTS d_table;"""
|
||||
|
||||
sql """
|
||||
create table d_table(
|
||||
k1 int null,
|
||||
k2 int not null,
|
||||
k3 decimal(28,6) null,
|
||||
k4 varchar(100) null
|
||||
)
|
||||
duplicate key (k1,k2,k3)
|
||||
distributed BY hash(k1) buckets 3
|
||||
properties("replication_num" = "1");
|
||||
"""
|
||||
|
||||
sql "insert into d_table select 1,1,1,'a';"
|
||||
sql "insert into d_table select 2,2,2,'b';"
|
||||
sql "insert into d_table select 3,-3,null,'c';"
|
||||
|
||||
createMV("create materialized view kp as select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table group by k1,k2;")
|
||||
|
||||
sql "insert into d_table select -4,-4,-4,'d';"
|
||||
sql "insert into d_table(k4,k2) values('d',4);"
|
||||
|
||||
qt_select_star "select * from d_table order by k1;"
|
||||
|
||||
explain {
|
||||
sql("select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table group by k1,k2 order by k1,k2;")
|
||||
contains "(kp)"
|
||||
}
|
||||
qt_select_mv "select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table group by k1,k2 order by k1,k2;"
|
||||
|
||||
explain {
|
||||
sql("select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table group by grouping sets((k1),(k1,k2),()) order by 1,2;")
|
||||
contains "(kp)"
|
||||
}
|
||||
qt_select_mv "select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table group by grouping sets((k1),(k1,k2),()) order by 1,2,3;"
|
||||
|
||||
|
||||
explain {
|
||||
sql("select percentile(k3, 0.1) from d_table group by grouping sets((k1),()) order by 1;")
|
||||
contains "(kp)"
|
||||
}
|
||||
qt_select_mv "select percentile(k3, 0.1) from d_table group by grouping sets((k1),()) order by 1;"
|
||||
}
|
||||
@ -308,4 +308,44 @@ suite("aggregate") {
|
||||
qt_aggregate_limit_contain_null """
|
||||
select count(), cast(k12 as int) as t from baseall group by t limit 1;
|
||||
"""
|
||||
|
||||
// Test case for percentile function with sales data
|
||||
sql """ DROP TABLE IF EXISTS sales_data """
|
||||
sql """
|
||||
CREATE TABLE sales_data (
|
||||
product_id INT,
|
||||
sale_price DECIMAL(10, 2)
|
||||
) DUPLICATE KEY(`product_id`)
|
||||
DISTRIBUTED BY HASH(`product_id`) BUCKETS 1
|
||||
PROPERTIES (
|
||||
"replication_allocation" = "tag.location.default: 1"
|
||||
)
|
||||
"""
|
||||
|
||||
sql """
|
||||
INSERT INTO sales_data VALUES
|
||||
(1, 10.00),
|
||||
(1, 15.00),
|
||||
(1, 20.00),
|
||||
(1, 25.00),
|
||||
(1, 30.25),
|
||||
(1, 35.00),
|
||||
(1, 40.00),
|
||||
(1, 45.00),
|
||||
(1, 50.00),
|
||||
(1, 100.00)
|
||||
"""
|
||||
|
||||
qt_aggregate35 """
|
||||
SELECT
|
||||
percentile(sale_price, 0.05) as median_price_05,
|
||||
percentile(sale_price, 0.5) as median_price,
|
||||
percentile(sale_price, 0.75) as p75_price,
|
||||
percentile(sale_price, 0.90) as p90_price,
|
||||
percentile(sale_price, 0.95) as p95_price,
|
||||
percentile(null, 0.99) as p99_null
|
||||
FROM sales_data
|
||||
"""
|
||||
|
||||
sql """ DROP TABLE IF EXISTS sales_data """
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user