branch-2.1: [Fix](Nereids) fix floor/round/ceil/truncate functions type compute precision problem (#43782)

Cherry-picked from #43422

Co-authored-by: LiBinfeng <libinfeng@selectdb.com>
This commit is contained in:
github-actions[bot]
2024-11-13 11:52:14 +08:00
committed by GitHub
parent 2e64491ee3
commit 9d4405db6f
2 changed files with 39 additions and 20 deletions

View File

@ -694,12 +694,17 @@ public class NumericArithmetic {
return input;
}
private static Expression castDecimalV3Literal(DecimalV3Literal literal, int precision) {
return new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(precision, literal.getValue().scale()),
literal.getValue());
}
/**
* round
*/
@ExecFunction(name = "round")
public static Expression round(DecimalV3Literal first) {
return first.round(0);
return castDecimalV3Literal(first.round(0), first.getValue().precision());
}
/**
@ -707,7 +712,7 @@ public class NumericArithmetic {
*/
@ExecFunction(name = "round")
public static Expression round(DecimalV3Literal first, IntegerLiteral second) {
return first.round(second.getValue());
return castDecimalV3Literal(first.round(second.getValue()), first.getValue().precision());
}
/**
@ -733,7 +738,7 @@ public class NumericArithmetic {
*/
@ExecFunction(name = "ceil")
public static Expression ceil(DecimalV3Literal first) {
return first.roundCeiling(0);
return castDecimalV3Literal(first.roundCeiling(0), first.getValue().precision());
}
/**
@ -741,7 +746,7 @@ public class NumericArithmetic {
*/
@ExecFunction(name = "ceil")
public static Expression ceil(DecimalV3Literal first, IntegerLiteral second) {
return first.roundCeiling(second.getValue());
return castDecimalV3Literal(first.roundCeiling(second.getValue()), first.getValue().precision());
}
/**
@ -767,7 +772,7 @@ public class NumericArithmetic {
*/
@ExecFunction(name = "floor")
public static Expression floor(DecimalV3Literal first) {
return first.roundFloor(0);
return castDecimalV3Literal(first.roundFloor(0), first.getValue().precision());
}
/**
@ -775,7 +780,7 @@ public class NumericArithmetic {
*/
@ExecFunction(name = "floor")
public static Expression floor(DecimalV3Literal first, IntegerLiteral second) {
return first.roundFloor(second.getValue());
return castDecimalV3Literal(first.roundFloor(second.getValue()), first.getValue().precision());
}
/**
@ -1136,21 +1141,10 @@ public class NumericArithmetic {
public static Expression truncate(DecimalV3Literal first, IntegerLiteral second) {
if (first.getValue().compareTo(BigDecimal.ZERO) == 0) {
return first;
} else if (first.getValue().compareTo(BigDecimal.ZERO) < 0) {
return castDecimalV3Literal(first.roundCeiling(second.getValue()), first.getValue().precision());
} else {
if (first.getValue().scale() < second.getValue()) {
return first;
}
if (second.getValue() < 0) {
double factor = Math.pow(10, Math.abs(second.getValue()));
return new DecimalV3Literal(
DecimalV3Type.createDecimalV3Type(first.getValue().precision(), 0),
BigDecimal.valueOf(Math.floor(first.getDouble() / factor) * factor));
}
if (first.getValue().compareTo(BigDecimal.ZERO) == -1) {
return first.roundCeiling(second.getValue());
} else {
return first.roundFloor(second.getValue());
}
return castDecimalV3Literal(first.roundFloor(second.getValue()), first.getValue().precision());
}
}

View File

@ -409,4 +409,29 @@ test {
//Additional cases for Xor, Conv, and other mathematical functions
testFoldConst("SELECT CONV(-10, 10, 2) AS conv_invalid_base") //Conv with negative input (may be undefined)
// fix floor/ceil/round function return type with DecimalV3 input
testFoldConst("with cte as (select floor(300.343) order by 1 limit 1) select * from cte")
testFoldConst("with cte as (select round(300.343) order by 1 limit 1) select * from cte")
testFoldConst("with cte as (select ceil(300.343) order by 1 limit 1) select * from cte")
testFoldConst("with cte as (select floor(300.343, 2) order by 1 limit 1) select * from cte")
testFoldConst("with cte as (select round(300.343, 2) order by 1 limit 1) select * from cte")
testFoldConst("with cte as (select ceil(300.343, 2) order by 1 limit 1) select * from cte")
testFoldConst("with cte as (select truncate(300.343, 2) order by 1 limit 1) select * from cte")
testFoldConst("with cte as (select floor(300.343, 0) order by 1 limit 1) select * from cte")
testFoldConst("with cte as (select round(300.343, 0) order by 1 limit 1) select * from cte")
testFoldConst("with cte as (select ceil(300.343, 0) order by 1 limit 1) select * from cte")
testFoldConst("with cte as (select truncate(300.343, 0) order by 1 limit 1) select * from cte")
testFoldConst("with cte as (select floor(300.343, -1) order by 1 limit 1) select * from cte")
testFoldConst("with cte as (select round(300.343, -1) order by 1 limit 1) select * from cte")
testFoldConst("with cte as (select ceil(300.343, -1) order by 1 limit 1) select * from cte")
testFoldConst("with cte as (select truncate(300.343, -1) order by 1 limit 1) select * from cte")
testFoldConst("with cte as (select floor(300.343, -4) order by 1 limit 1) select * from cte")
testFoldConst("with cte as (select round(300.343, -4) order by 1 limit 1) select * from cte")
testFoldConst("with cte as (select ceil(300.343, -4) order by 1 limit 1) select * from cte")
testFoldConst("with cte as (select truncate(300.343, -4) order by 1 limit 1) select * from cte")
}