From 9d4405db6fa4846f02f83ec31faed301b2297374 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 13 Nov 2024 11:52:14 +0800 Subject: [PATCH] branch-2.1: [Fix](Nereids) fix floor/round/ceil/truncate functions type compute precision problem (#43782) Cherry-picked from #43422 Co-authored-by: LiBinfeng --- .../executable/NumericArithmetic.java | 34 ++++++++----------- .../fold_constant_numeric_arithmatic.groovy | 25 ++++++++++++++ 2 files changed, 39 insertions(+), 20 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java index 325e676fc0..a9acfeb2d6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java @@ -694,12 +694,17 @@ public class NumericArithmetic { return input; } + private static Expression castDecimalV3Literal(DecimalV3Literal literal, int precision) { + return new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(precision, literal.getValue().scale()), + literal.getValue()); + } + /** * round */ @ExecFunction(name = "round") public static Expression round(DecimalV3Literal first) { - return first.round(0); + return castDecimalV3Literal(first.round(0), first.getValue().precision()); } /** @@ -707,7 +712,7 @@ public class NumericArithmetic { */ @ExecFunction(name = "round") public static Expression round(DecimalV3Literal first, IntegerLiteral second) { - return first.round(second.getValue()); + return castDecimalV3Literal(first.round(second.getValue()), first.getValue().precision()); } /** @@ -733,7 +738,7 @@ public class NumericArithmetic { */ @ExecFunction(name = "ceil") public static Expression ceil(DecimalV3Literal first) { - return first.roundCeiling(0); + return castDecimalV3Literal(first.roundCeiling(0), first.getValue().precision()); } /** @@ -741,7 +746,7 @@ public class NumericArithmetic { */ @ExecFunction(name = "ceil") public static Expression ceil(DecimalV3Literal first, IntegerLiteral second) { - return first.roundCeiling(second.getValue()); + return castDecimalV3Literal(first.roundCeiling(second.getValue()), first.getValue().precision()); } /** @@ -767,7 +772,7 @@ public class NumericArithmetic { */ @ExecFunction(name = "floor") public static Expression floor(DecimalV3Literal first) { - return first.roundFloor(0); + return castDecimalV3Literal(first.roundFloor(0), first.getValue().precision()); } /** @@ -775,7 +780,7 @@ public class NumericArithmetic { */ @ExecFunction(name = "floor") public static Expression floor(DecimalV3Literal first, IntegerLiteral second) { - return first.roundFloor(second.getValue()); + return castDecimalV3Literal(first.roundFloor(second.getValue()), first.getValue().precision()); } /** @@ -1136,21 +1141,10 @@ public class NumericArithmetic { public static Expression truncate(DecimalV3Literal first, IntegerLiteral second) { if (first.getValue().compareTo(BigDecimal.ZERO) == 0) { return first; + } else if (first.getValue().compareTo(BigDecimal.ZERO) < 0) { + return castDecimalV3Literal(first.roundCeiling(second.getValue()), first.getValue().precision()); } else { - if (first.getValue().scale() < second.getValue()) { - return first; - } - if (second.getValue() < 0) { - double factor = Math.pow(10, Math.abs(second.getValue())); - return new DecimalV3Literal( - DecimalV3Type.createDecimalV3Type(first.getValue().precision(), 0), - BigDecimal.valueOf(Math.floor(first.getDouble() / factor) * factor)); - } - if (first.getValue().compareTo(BigDecimal.ZERO) == -1) { - return first.roundCeiling(second.getValue()); - } else { - return first.roundFloor(second.getValue()); - } + return castDecimalV3Literal(first.roundFloor(second.getValue()), first.getValue().precision()); } } diff --git a/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_numeric_arithmatic.groovy b/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_numeric_arithmatic.groovy index 5f72865126..dbfd3fad7b 100644 --- a/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_numeric_arithmatic.groovy +++ b/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_numeric_arithmatic.groovy @@ -409,4 +409,29 @@ test { //Additional cases for Xor, Conv, and other mathematical functions testFoldConst("SELECT CONV(-10, 10, 2) AS conv_invalid_base") //Conv with negative input (may be undefined) + + // fix floor/ceil/round function return type with DecimalV3 input + testFoldConst("with cte as (select floor(300.343) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select round(300.343) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select ceil(300.343) order by 1 limit 1) select * from cte") + + testFoldConst("with cte as (select floor(300.343, 2) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select round(300.343, 2) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select ceil(300.343, 2) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select truncate(300.343, 2) order by 1 limit 1) select * from cte") + + testFoldConst("with cte as (select floor(300.343, 0) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select round(300.343, 0) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select ceil(300.343, 0) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select truncate(300.343, 0) order by 1 limit 1) select * from cte") + + testFoldConst("with cte as (select floor(300.343, -1) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select round(300.343, -1) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select ceil(300.343, -1) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select truncate(300.343, -1) order by 1 limit 1) select * from cte") + + testFoldConst("with cte as (select floor(300.343, -4) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select round(300.343, -4) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select ceil(300.343, -4) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select truncate(300.343, -4) order by 1 limit 1) select * from cte") }