branch-2.1: [Fix](Nereids) fix floor/round/ceil/truncate functions type compute precision problem (#43782)
Cherry-picked from #43422 Co-authored-by: LiBinfeng <libinfeng@selectdb.com>
This commit is contained in:
committed by
GitHub
parent
2e64491ee3
commit
9d4405db6f
@ -694,12 +694,17 @@ public class NumericArithmetic {
|
||||
return input;
|
||||
}
|
||||
|
||||
private static Expression castDecimalV3Literal(DecimalV3Literal literal, int precision) {
|
||||
return new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(precision, literal.getValue().scale()),
|
||||
literal.getValue());
|
||||
}
|
||||
|
||||
/**
|
||||
* round
|
||||
*/
|
||||
@ExecFunction(name = "round")
|
||||
public static Expression round(DecimalV3Literal first) {
|
||||
return first.round(0);
|
||||
return castDecimalV3Literal(first.round(0), first.getValue().precision());
|
||||
}
|
||||
|
||||
/**
|
||||
@ -707,7 +712,7 @@ public class NumericArithmetic {
|
||||
*/
|
||||
@ExecFunction(name = "round")
|
||||
public static Expression round(DecimalV3Literal first, IntegerLiteral second) {
|
||||
return first.round(second.getValue());
|
||||
return castDecimalV3Literal(first.round(second.getValue()), first.getValue().precision());
|
||||
}
|
||||
|
||||
/**
|
||||
@ -733,7 +738,7 @@ public class NumericArithmetic {
|
||||
*/
|
||||
@ExecFunction(name = "ceil")
|
||||
public static Expression ceil(DecimalV3Literal first) {
|
||||
return first.roundCeiling(0);
|
||||
return castDecimalV3Literal(first.roundCeiling(0), first.getValue().precision());
|
||||
}
|
||||
|
||||
/**
|
||||
@ -741,7 +746,7 @@ public class NumericArithmetic {
|
||||
*/
|
||||
@ExecFunction(name = "ceil")
|
||||
public static Expression ceil(DecimalV3Literal first, IntegerLiteral second) {
|
||||
return first.roundCeiling(second.getValue());
|
||||
return castDecimalV3Literal(first.roundCeiling(second.getValue()), first.getValue().precision());
|
||||
}
|
||||
|
||||
/**
|
||||
@ -767,7 +772,7 @@ public class NumericArithmetic {
|
||||
*/
|
||||
@ExecFunction(name = "floor")
|
||||
public static Expression floor(DecimalV3Literal first) {
|
||||
return first.roundFloor(0);
|
||||
return castDecimalV3Literal(first.roundFloor(0), first.getValue().precision());
|
||||
}
|
||||
|
||||
/**
|
||||
@ -775,7 +780,7 @@ public class NumericArithmetic {
|
||||
*/
|
||||
@ExecFunction(name = "floor")
|
||||
public static Expression floor(DecimalV3Literal first, IntegerLiteral second) {
|
||||
return first.roundFloor(second.getValue());
|
||||
return castDecimalV3Literal(first.roundFloor(second.getValue()), first.getValue().precision());
|
||||
}
|
||||
|
||||
/**
|
||||
@ -1136,21 +1141,10 @@ public class NumericArithmetic {
|
||||
public static Expression truncate(DecimalV3Literal first, IntegerLiteral second) {
|
||||
if (first.getValue().compareTo(BigDecimal.ZERO) == 0) {
|
||||
return first;
|
||||
} else if (first.getValue().compareTo(BigDecimal.ZERO) < 0) {
|
||||
return castDecimalV3Literal(first.roundCeiling(second.getValue()), first.getValue().precision());
|
||||
} else {
|
||||
if (first.getValue().scale() < second.getValue()) {
|
||||
return first;
|
||||
}
|
||||
if (second.getValue() < 0) {
|
||||
double factor = Math.pow(10, Math.abs(second.getValue()));
|
||||
return new DecimalV3Literal(
|
||||
DecimalV3Type.createDecimalV3Type(first.getValue().precision(), 0),
|
||||
BigDecimal.valueOf(Math.floor(first.getDouble() / factor) * factor));
|
||||
}
|
||||
if (first.getValue().compareTo(BigDecimal.ZERO) == -1) {
|
||||
return first.roundCeiling(second.getValue());
|
||||
} else {
|
||||
return first.roundFloor(second.getValue());
|
||||
}
|
||||
return castDecimalV3Literal(first.roundFloor(second.getValue()), first.getValue().precision());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -409,4 +409,29 @@ test {
|
||||
|
||||
//Additional cases for Xor, Conv, and other mathematical functions
|
||||
testFoldConst("SELECT CONV(-10, 10, 2) AS conv_invalid_base") //Conv with negative input (may be undefined)
|
||||
|
||||
// fix floor/ceil/round function return type with DecimalV3 input
|
||||
testFoldConst("with cte as (select floor(300.343) order by 1 limit 1) select * from cte")
|
||||
testFoldConst("with cte as (select round(300.343) order by 1 limit 1) select * from cte")
|
||||
testFoldConst("with cte as (select ceil(300.343) order by 1 limit 1) select * from cte")
|
||||
|
||||
testFoldConst("with cte as (select floor(300.343, 2) order by 1 limit 1) select * from cte")
|
||||
testFoldConst("with cte as (select round(300.343, 2) order by 1 limit 1) select * from cte")
|
||||
testFoldConst("with cte as (select ceil(300.343, 2) order by 1 limit 1) select * from cte")
|
||||
testFoldConst("with cte as (select truncate(300.343, 2) order by 1 limit 1) select * from cte")
|
||||
|
||||
testFoldConst("with cte as (select floor(300.343, 0) order by 1 limit 1) select * from cte")
|
||||
testFoldConst("with cte as (select round(300.343, 0) order by 1 limit 1) select * from cte")
|
||||
testFoldConst("with cte as (select ceil(300.343, 0) order by 1 limit 1) select * from cte")
|
||||
testFoldConst("with cte as (select truncate(300.343, 0) order by 1 limit 1) select * from cte")
|
||||
|
||||
testFoldConst("with cte as (select floor(300.343, -1) order by 1 limit 1) select * from cte")
|
||||
testFoldConst("with cte as (select round(300.343, -1) order by 1 limit 1) select * from cte")
|
||||
testFoldConst("with cte as (select ceil(300.343, -1) order by 1 limit 1) select * from cte")
|
||||
testFoldConst("with cte as (select truncate(300.343, -1) order by 1 limit 1) select * from cte")
|
||||
|
||||
testFoldConst("with cte as (select floor(300.343, -4) order by 1 limit 1) select * from cte")
|
||||
testFoldConst("with cte as (select round(300.343, -4) order by 1 limit 1) select * from cte")
|
||||
testFoldConst("with cte as (select ceil(300.343, -4) order by 1 limit 1) select * from cte")
|
||||
testFoldConst("with cte as (select truncate(300.343, -4) order by 1 limit 1) select * from cte")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user