diff --git a/be/src/util/counts.h b/be/src/util/counts.h index fec18cedcd..5dcd14f310 100644 --- a/be/src/util/counts.h +++ b/be/src/util/counts.h @@ -26,6 +26,7 @@ namespace doris { +template class Counts { public: Counts() = default; @@ -40,7 +41,7 @@ public: } } - void increment(int64_t key, uint32_t i) { + void increment(T key, uint32_t i) { auto item = _counts.find(key); if (item != _counts.end()) { item->second += i; @@ -50,8 +51,7 @@ public: } uint32_t serialized_size() const { - return sizeof(uint32_t) + sizeof(int64_t) * _counts.size() + - sizeof(uint32_t) * _counts.size(); + return sizeof(uint32_t) + sizeof(T) * _counts.size() + sizeof(uint32_t) * _counts.size(); } void serialize(uint8_t* writer) const { @@ -59,8 +59,8 @@ public: memcpy(writer, &size, sizeof(uint32_t)); writer += sizeof(uint32_t); for (auto& cell : _counts) { - memcpy(writer, &cell.first, sizeof(int64_t)); - writer += sizeof(int64_t); + memcpy(writer, &cell.first, sizeof(T)); + writer += sizeof(T); memcpy(writer, &cell.second, sizeof(uint32_t)); writer += sizeof(uint32_t); } @@ -71,18 +71,17 @@ public: memcpy(&size, type_reader, sizeof(uint32_t)); type_reader += sizeof(uint32_t); for (uint32_t i = 0; i < size; ++i) { - int64_t key; + T key; uint32_t count; - memcpy(&key, type_reader, sizeof(int64_t)); - type_reader += sizeof(int64_t); + memcpy(&key, type_reader, sizeof(T)); + type_reader += sizeof(T); memcpy(&count, type_reader, sizeof(uint32_t)); type_reader += sizeof(uint32_t); _counts.emplace(std::make_pair(key, count)); } } - double get_percentile(std::vector>& counts, - double position) const { + double get_percentile(std::vector>& counts, double position) const { long lower = long(std::floor(position)); long higher = long(std::ceil(position)); @@ -90,7 +89,7 @@ public: for (; iter != counts.end() && iter->second < lower + 1; ++iter) ; - int64_t lower_key = iter->first; + T lower_key = iter->first; if (higher == lower) { return lower_key; } @@ -99,7 +98,7 @@ public: iter++; } - int64_t higher_key = iter->first; + T higher_key = iter->first; if (lower_key == higher_key) { return lower_key; } @@ -114,9 +113,9 @@ public: return 0.0; } - std::vector> elems(_counts.begin(), _counts.end()); + std::vector> elems(_counts.begin(), _counts.end()); sort(elems.begin(), elems.end(), - [](const std::pair l, const std::pair r) { + [](const std::pair l, const std::pair r) { return l.first < r.first; }); @@ -132,7 +131,7 @@ public: } private: - std::unordered_map _counts; + std::unordered_map _counts; }; } // namespace doris diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp index 05e36a8f72..4cbe8a0690 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp @@ -50,9 +50,10 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx( void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& factory) { factory.register_function_both("percentile", - creator_without_type::creator); - factory.register_function_both("percentile_array", - creator_without_type::creator); + creator_with_numeric_type::creator); + factory.register_function_both( + "percentile_array", + creator_with_numeric_type::creator); } void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& factory) { diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h index 0f24aef6db..87e5560466 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h +++ b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h @@ -288,8 +288,9 @@ public: } }; +template struct PercentileState { - std::vector vec_counts; + std::vector> vec_counts; std::vector vec_quantile {-1}; bool inited_flag = false; @@ -327,7 +328,7 @@ struct PercentileState { } } - void add(int64_t source, const PaddedPODArray& quantiles, int arg_size) { + void add(T source, const PaddedPODArray& quantiles, int arg_size) { if (!inited_flag) { vec_counts.resize(arg_size); vec_quantile.resize(arg_size, -1); @@ -376,11 +377,12 @@ struct PercentileState { } }; +template class AggregateFunctionPercentile final - : public IAggregateFunctionDataHelper { + : public IAggregateFunctionDataHelper, AggregateFunctionPercentile> { public: AggregateFunctionPercentile(const DataTypes& argument_types_) - : IAggregateFunctionDataHelper( + : IAggregateFunctionDataHelper, AggregateFunctionPercentile>( argument_types_) {} String get_name() const override { return "percentile"; } @@ -389,10 +391,10 @@ public: void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { - const auto& sources = assert_cast&>(*columns[0]); + const auto& sources = assert_cast&>(*columns[0]); const auto& quantile = assert_cast&>(*columns[1]); - AggregateFunctionPercentile::data(place).add(sources.get_int(row_num), quantile.get_data(), - 1); + AggregateFunctionPercentile::data(place).add(sources.get_element(row_num), + quantile.get_data(), 1); } void reset(AggregateDataPtr __restrict place) const override { @@ -419,11 +421,13 @@ public: } }; +template class AggregateFunctionPercentileArray final - : public IAggregateFunctionDataHelper { + : public IAggregateFunctionDataHelper, + AggregateFunctionPercentileArray> { public: AggregateFunctionPercentileArray(const DataTypes& argument_types_) - : IAggregateFunctionDataHelper( + : IAggregateFunctionDataHelper, AggregateFunctionPercentileArray>( argument_types_) {} String get_name() const override { return "percentile_array"; } @@ -434,7 +438,7 @@ public: void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { - const auto& sources = assert_cast&>(*columns[0]); + const auto& sources = assert_cast&>(*columns[0]); const auto& quantile_array = assert_cast(*columns[1]); const auto& offset_column_data = quantile_array.get_offsets(); const auto& nested_column = @@ -442,7 +446,7 @@ public: const auto& nested_column_data = assert_cast&>(nested_column); AggregateFunctionPercentileArray::data(place).add( - sources.get_int(row_num), nested_column_data.get_data(), + sources.get_element(row_num), nested_column_data.get_data(), offset_column_data.data()[row_num] - offset_column_data[(ssize_t)row_num - 1]); } @@ -480,4 +484,4 @@ public: } }; -} // namespace doris::vectorized \ No newline at end of file +} // namespace doris::vectorized diff --git a/be/test/util/counts_test.cpp b/be/test/util/counts_test.cpp index 908bbcefd5..42370f8057 100644 --- a/be/test/util/counts_test.cpp +++ b/be/test/util/counts_test.cpp @@ -27,7 +27,7 @@ namespace doris { class TCountsTest : public testing::Test {}; TEST_F(TCountsTest, TotalTest) { - Counts counts; + Counts counts; // 1 1 1 2 5 7 7 9 9 19 // >>> import numpy as np // >>> a = np.array([1,1,1,2,5,7,7,9,9,19]) @@ -46,12 +46,12 @@ TEST_F(TCountsTest, TotalTest) { uint8_t* type_reader = writer; counts.serialize(writer); - Counts other; + Counts other; other.unserialize(type_reader); double result1 = other.terminate(0.2); EXPECT_EQ(result, result1); - Counts other1; + Counts other1; other1.increment(1, 1); other1.increment(100, 3); other1.increment(50, 3); diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java index d8f95b2f6b..09acb21f47 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java @@ -131,7 +131,7 @@ public class BuiltinAggregateFunctions implements FunctionHelper { agg(OrthogonalBitmapIntersect.class, "orthogonal_bitmap_intersect"), agg(OrthogonalBitmapIntersectCount.class, "orthogonal_bitmap_intersect_count"), agg(OrthogonalBitmapUnionCount.class, "orthogonal_bitmap_union_count"), - agg(Percentile.class, "percentile"), + agg(Percentile.class, "percentile", "percentile_cont"), agg(PercentileApprox.class, "percentile_approx"), agg(PercentileArray.class, "percentile_array"), agg(QuantileUnion.class, "quantile_union"), diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java index 2db943993d..d943ad4f6e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -1430,6 +1430,15 @@ public class FunctionSet { "", false, true, false, true)); + addBuiltin(AggregateFunction.createBuiltin("percentile_cont", + Lists.newArrayList(Type.BIGINT, Type.DOUBLE), Type.DOUBLE, Type.VARCHAR, + "", + "", + "", + "", + "", + false, true, false, true)); + addBuiltin(AggregateFunction.createBuiltin("percentile_approx", Lists.newArrayList(Type.DOUBLE, Type.DOUBLE), Type.DOUBLE, Type.VARCHAR, "", diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Percentile.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Percentile.java index 31ab925ca6..fd3ba4890d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Percentile.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Percentile.java @@ -25,6 +25,11 @@ import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.FloatType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.SmallIntType; +import org.apache.doris.nereids.types.TinyIntType; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -38,7 +43,13 @@ public class Percentile extends NullableAggregateFunction implements BinaryExpression, ExplicitlyCastableSignature { public static final List SIGNATURES = ImmutableList.of( - FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, DoubleType.INSTANCE) + FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, DoubleType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE, DoubleType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, DoubleType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, DoubleType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, DoubleType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, DoubleType.INSTANCE) ); /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java index d4d8ed6c39..c97b617e61 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java @@ -26,6 +26,11 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.ArrayType; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.FloatType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.SmallIntType; +import org.apache.doris.nereids.types.TinyIntType; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -40,8 +45,19 @@ public class PercentileArray extends AggregateFunction public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) - .args(BigIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)) - ); + .args(DoubleType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)), + FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) + .args(FloatType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)), + FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) + .args(LargeIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)), + FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) + .args(BigIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)), + FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) + .args(IntegerType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)), + FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) + .args(SmallIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)), + FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) + .args(TinyIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE))); /** * constructor with 2 arguments. diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/GenerateFunction.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/GenerateFunction.java index 2d010df4c5..105ab00f39 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/GenerateFunction.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/GenerateFunction.java @@ -182,6 +182,7 @@ public class GenerateFunction { .put("any", "any_value") .put("char_length", "character_length") .put("stddev_pop", "stddev") + .put("percentile_cont", "percentile") .put("var_pop", "variance") .put("variance_pop", "variance") .put("var_samp", "variance_samp") diff --git a/regression-test/data/nereids_p0/sql_functions/aggregate_functions/test_aggregate_window_functions.out b/regression-test/data/nereids_p0/sql_functions/aggregate_functions/test_aggregate_window_functions.out index 6729ea26bc..03569f1aed 100644 --- a/regression-test/data/nereids_p0/sql_functions/aggregate_functions/test_aggregate_window_functions.out +++ b/regression-test/data/nereids_p0/sql_functions/aggregate_functions/test_aggregate_window_functions.out @@ -384,11 +384,11 @@ sichuan [{"cbe":{},"notnull":0,"null":1,"all":1}] 2 [123456789, 223456789, 323456789] 2 [123456789, 223456789, 323456789] 2 [123456789, 223456789, 323456789] -3 [223456789, 223456789, 323456789] -3 [223456789, 223456789, 323456789] -3 [223456789, 223456789, 323456789] -3 [223456789, 223456789, 323456789] -3 [223456789, 223456789, 323456789] +3 [223456789.6, 223456789.6, 323456789.1] +3 [223456789.6, 223456789.6, 323456789.1] +3 [223456789.6, 223456789.6, 323456789.1] +3 [223456789.6, 223456789.6, 323456789.1] +3 [223456789.6, 223456789.6, 323456789.1] -- !agg_window_percentile_approx -- 1 5.234568E8 diff --git a/regression-test/data/query_p0/aggregate/aggregate.out b/regression-test/data/query_p0/aggregate/aggregate.out index ffd3790499..f17c690ec4 100644 --- a/regression-test/data/query_p0/aggregate/aggregate.out +++ b/regression-test/data/query_p0/aggregate/aggregate.out @@ -697,3 +697,6 @@ TESTING AGAIN -- !aggregate_limit_contain_null -- 16 \N + +-- !aggregate35 -- +12.25 32.625 43.75 54.999999999999986 77.49999999999994 \N diff --git a/regression-test/suites/mv_p0/mv_percentile/mv_percentile.groovy b/regression-test/suites/mv_p0/mv_percentile/mv_percentile.groovy deleted file mode 100644 index dd6cb45330..0000000000 --- a/regression-test/suites/mv_p0/mv_percentile/mv_percentile.groovy +++ /dev/null @@ -1,66 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -import org.codehaus.groovy.runtime.IOGroovyMethods - -suite ("mv_percentile") { - sql "set enable_fallback_to_original_planner = false" - - sql """DROP TABLE IF EXISTS d_table;""" - - sql """ - create table d_table( - k1 int null, - k2 int not null, - k3 decimal(28,6) null, - k4 varchar(100) null - ) - duplicate key (k1,k2,k3) - distributed BY hash(k1) buckets 3 - properties("replication_num" = "1"); - """ - - sql "insert into d_table select 1,1,1,'a';" - sql "insert into d_table select 2,2,2,'b';" - sql "insert into d_table select 3,-3,null,'c';" - - createMV("create materialized view kp as select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table group by k1,k2;") - - sql "insert into d_table select -4,-4,-4,'d';" - sql "insert into d_table(k4,k2) values('d',4);" - - qt_select_star "select * from d_table order by k1;" - - explain { - sql("select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table group by k1,k2 order by k1,k2;") - contains "(kp)" - } - qt_select_mv "select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table group by k1,k2 order by k1,k2;" - - explain { - sql("select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table group by grouping sets((k1),(k1,k2),()) order by 1,2;") - contains "(kp)" - } - qt_select_mv "select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table group by grouping sets((k1),(k1,k2),()) order by 1,2,3;" - - - explain { - sql("select percentile(k3, 0.1) from d_table group by grouping sets((k1),()) order by 1;") - contains "(kp)" - } - qt_select_mv "select percentile(k3, 0.1) from d_table group by grouping sets((k1),()) order by 1;" -} diff --git a/regression-test/suites/query_p0/aggregate/aggregate.groovy b/regression-test/suites/query_p0/aggregate/aggregate.groovy index b611ff92b0..9836b8a2f5 100644 --- a/regression-test/suites/query_p0/aggregate/aggregate.groovy +++ b/regression-test/suites/query_p0/aggregate/aggregate.groovy @@ -308,4 +308,44 @@ suite("aggregate") { qt_aggregate_limit_contain_null """ select count(), cast(k12 as int) as t from baseall group by t limit 1; """ + + // Test case for percentile function with sales data + sql """ DROP TABLE IF EXISTS sales_data """ + sql """ + CREATE TABLE sales_data ( + product_id INT, + sale_price DECIMAL(10, 2) + ) DUPLICATE KEY(`product_id`) + DISTRIBUTED BY HASH(`product_id`) BUCKETS 1 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ) + """ + + sql """ + INSERT INTO sales_data VALUES + (1, 10.00), + (1, 15.00), + (1, 20.00), + (1, 25.00), + (1, 30.25), + (1, 35.00), + (1, 40.00), + (1, 45.00), + (1, 50.00), + (1, 100.00) + """ + + qt_aggregate35 """ + SELECT + percentile(sale_price, 0.05) as median_price_05, + percentile(sale_price, 0.5) as median_price, + percentile(sale_price, 0.75) as p75_price, + percentile(sale_price, 0.90) as p90_price, + percentile(sale_price, 0.95) as p95_price, + percentile(null, 0.99) as p99_null + FROM sales_data + """ + + sql """ DROP TABLE IF EXISTS sales_data """ }