From 132ff6c6ded533dbf598122c0f81a66d4fabf4f3 Mon Sep 17 00:00:00 2001 From: zhangstar333 <87313068+zhangstar333@users.noreply.github.com> Date: Fri, 5 Jan 2024 18:06:16 +0800 Subject: [PATCH] [opt](Nereids) add float type signature for sum aggregate function (#29503) * [opt](Nereids) add float type signature for sum aggregate function --- .../trees/expressions/functions/agg/MultiDistinctSum.java | 5 +++++ .../doris/nereids/trees/expressions/functions/agg/Sum.java | 2 +- .../suites/nereids_function_p0/agg_function/agg.groovy | 6 ++++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java index b440d91a31..212140fed9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java @@ -67,6 +67,11 @@ public class MultiDistinctSum extends NullableAggregateFunction implements Unary return new Sum(getArgument(0)).getSignatures(); } + @Override + public FunctionSignature searchSignature(List signatures) { + return new Sum(getArgument(0)).searchSignature(signatures); + } + @Override public NullableAggregateFunction withAlwaysNullable(boolean alwaysNullable) { return new MultiDistinctSum(distinct, alwaysNullable, children.get(0)); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java index 8799203936..f0dbd83958 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java @@ -114,7 +114,7 @@ public class Sum extends NullableAggregateFunction @Override public FunctionSignature searchSignature(List signatures) { if (getArgument(0).getDataType() instanceof FloatType) { - return FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE); + return FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE); } return ExplicitlyCastableSignature.super.searchSignature(signatures); } diff --git a/regression-test/suites/nereids_function_p0/agg_function/agg.groovy b/regression-test/suites/nereids_function_p0/agg_function/agg.groovy index 1a12fd0383..81c84ad32d 100644 --- a/regression-test/suites/nereids_function_p0/agg_function/agg.groovy +++ b/regression-test/suites/nereids_function_p0/agg_function/agg.groovy @@ -2188,6 +2188,12 @@ suite("nereids_agg_fn") { qt_sql_sum_BigInt_agg_phase_4_notnull ''' select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT, TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/ count(distinct id), sum(kbint) from fn_test''' + //not cast float to double + explain { + sql("select sum(kfloat) from fn_test;") + contains "partial_sum(kfloat" + } + qt_sql_sum_Double_gb ''' select sum(kdbl) from fn_test group by kbool order by kbool''' qt_sql_sum_Double '''