From 341c45974c42f0880fe2c4cfaac77b64b69941cd Mon Sep 17 00:00:00 2001 From: Gabriel Date: Thu, 27 Jul 2023 10:00:36 +0800 Subject: [PATCH] [round](decimalv2) round precise decimalv2 value (#22258) --- be/src/vec/functions/round.h | 6 ++--- .../doris/analysis/FunctionCallExpr.java | 8 +++++++ .../org/apache/doris/catalog/FunctionSet.java | 14 ++++++++++++ .../org/apache/doris/qe/SessionVariable.java | 5 +++++ gensrc/script/doris_builtins_functions.py | 14 +++++------- .../math_functions/test_round.out | 22 ++++++++++++++----- .../math_functions/test_round.groovy | 6 +++++ 7 files changed, 57 insertions(+), 18 deletions(-) diff --git a/be/src/vec/functions/round.h b/be/src/vec/functions/round.h index b19dc26383..66753f00c0 100644 --- a/be/src/vec/functions/round.h +++ b/be/src/vec/functions/round.h @@ -151,7 +151,6 @@ private: public: static NO_INLINE void apply(const Container& in, UInt32 in_scale, Container& out, Int16 out_scale) { - constexpr bool is_decimalv2 = IsDecimalV2; Int16 scale_arg = in_scale - out_scale; if (scale_arg > 0) { size_t scale = int_exp10(scale_arg); @@ -162,14 +161,13 @@ public: if (out_scale < 0) { while (p_in < end_in) { - Op::compute(p_in, scale, p_out, - is_decimalv2 ? int_exp10(9 - out_scale) : int_exp10(-out_scale)); + Op::compute(p_in, scale, p_out, int_exp10(-out_scale)); ++p_in; ++p_out; } } else { while (p_in < end_in) { - Op::compute(p_in, scale, p_out, is_decimalv2 ? scale : 1); + Op::compute(p_in, scale, p_out, 1); ++p_in; ++p_out; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index e9117bcf9c..e28587f84f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -1689,6 +1689,14 @@ public class FunctionCallExpr extends Expr { && (args[ix].isArrayType()) && ((ArrayType) args[ix]).getItemType().isDecimalV3()))) { continue; + } else if (!argTypes[i].matchesType(args[ix]) + && ROUND_FUNCTION_SET.contains(fnName.getFunction()) + && ConnectContext.get() != null + && ConnectContext.get().getSessionVariable().roundPreciseDecimalV2Value + && argTypes[i].isDecimalV2() + && args[ix].isDecimalV3()) { + uncheckedCastChild(ScalarType.createDecimalV3Type(ScalarType.MAX_DECIMALV2_PRECISION, + ((ScalarType) argTypes[i]).getScalarScale()), i); } else if (!argTypes[i].matchesType(args[ix]) && !(argTypes[i].isDecimalV3OrContainsDecimalV3() && args[ix].isDecimalV3OrContainsDecimalV3())) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java index 5b3ec5ba83..881b507fee 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -21,12 +21,14 @@ import org.apache.doris.analysis.ArithmeticExpr; import org.apache.doris.analysis.BinaryPredicate; import org.apache.doris.analysis.CastExpr; import org.apache.doris.analysis.CompoundPredicate; +import org.apache.doris.analysis.FunctionCallExpr; import org.apache.doris.analysis.InPredicate; import org.apache.doris.analysis.IsNullPredicate; import org.apache.doris.analysis.LikePredicate; import org.apache.doris.analysis.MatchPredicate; import org.apache.doris.builtins.ScalarBuiltins; import org.apache.doris.catalog.Function.NullableMode; +import org.apache.doris.qe.ConnectContext; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; @@ -453,6 +455,18 @@ public class FunctionSet { return false; } } + // If set `roundPreciseDecimalV2Value`, only use decimalv3 as target type to execute round function + if (ConnectContext.get() != null + && ConnectContext.get().getSessionVariable().roundPreciseDecimalV2Value + && FunctionCallExpr.ROUND_FUNCTION_SET.contains(desc.functionName()) + && descArgType.isDecimalV2() && candicateArgType.getPrimitiveType() != PrimitiveType.DECIMAL128) { + return false; + } else if (ConnectContext.get() != null + && ConnectContext.get().getSessionVariable().roundPreciseDecimalV2Value + && FunctionCallExpr.ROUND_FUNCTION_SET.contains(desc.functionName()) + && descArgType.isDecimalV2() && candicateArgType.getPrimitiveType() == PrimitiveType.DECIMAL128) { + return true; + } if ((descArgType.isDecimalV3() && candicateArgType.isDecimalV2()) || (descArgType.isDecimalV2() && candicateArgType.isDecimalV3())) { return false; diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java index 0594161a25..7ea1d49323 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java @@ -363,6 +363,8 @@ public class SessionVariable implements Serializable, Writable { public static final String CBO_NET_WEIGHT = "cbo_net_weight"; + public static final String ROUND_PRECISE_DECIMALV2_VALUE = "round_precise_decimalv2_value"; + public static final List DEBUG_VARIABLES = ImmutableList.of( SKIP_DELETE_PREDICATE, SKIP_DELETE_BITMAP, @@ -377,6 +379,9 @@ public class SessionVariable implements Serializable, Writable { // if it is setStmt, we needn't collect session origin value public boolean isSingleSetVar = false; + @VariableMgr.VarAttr(name = ROUND_PRECISE_DECIMALV2_VALUE) + public boolean roundPreciseDecimalV2Value = false; + @VariableMgr.VarAttr(name = INSERT_VISIBLE_TIMEOUT_MS, needForward = true) public long insertVisibleTimeoutMs = DEFAULT_INSERT_VISIBLE_TIMEOUT_MS; diff --git a/gensrc/script/doris_builtins_functions.py b/gensrc/script/doris_builtins_functions.py index f79f1c0a4e..31c1bef74d 100644 --- a/gensrc/script/doris_builtins_functions.py +++ b/gensrc/script/doris_builtins_functions.py @@ -1242,10 +1242,11 @@ visible_functions = { [['floor', 'dfloor'], 'DOUBLE', ['DOUBLE'], ''], [['round', 'dround'], 'DOUBLE', ['DOUBLE'], ''], [['round_bankers'], 'DOUBLE', ['DOUBLE'], ''], - [['ceil', 'ceiling', 'dceil'], 'DECIMALV2', ['DECIMALV2'], ''], - [['floor', 'dfloor'], 'DECIMALV2', ['DECIMALV2'], ''], - [['round', 'dround'], 'DECIMALV2', ['DECIMALV2'], ''], - [['round_bankers'], 'DECIMALV2', ['DECIMALV2'], ''], + [['ceil', 'ceiling', 'dceil'], 'DOUBLE', ['DOUBLE', 'INT'], ''], + [['floor', 'dfloor'], 'DOUBLE', ['DOUBLE', 'INT'], ''], + [['round', 'dround'], 'DOUBLE', ['DOUBLE', 'INT'], ''], + [['round_bankers'], 'DOUBLE', ['DOUBLE', 'INT'], ''], + [['truncate'], 'DOUBLE', ['DOUBLE'], ''], [['ceil', 'ceiling', 'dceil'], 'DECIMAL32', ['DECIMAL32'], ''], [['floor', 'dfloor'], 'DECIMAL32', ['DECIMAL32'], ''], [['round', 'dround'], 'DECIMAL32', ['DECIMAL32'], ''], @@ -1259,25 +1260,20 @@ visible_functions = { [['round', 'dround'], 'DECIMAL128', ['DECIMAL128'], ''], [['round_bankers'], 'DECIMAL128', ['DECIMAL128'], ''], [['round', 'dround'], 'DOUBLE', ['DOUBLE', 'INT'], ''], - [['round', 'dround'], 'DECIMALV2', ['DECIMALV2', 'INT'], ''], [['round', 'dround'], 'DECIMAL32', ['DECIMAL32', 'INT'], ''], [['round', 'dround'], 'DECIMAL64', ['DECIMAL64', 'INT'], ''], [['round', 'dround'], 'DECIMAL128', ['DECIMAL128', 'INT'], ''], [['round_bankers', 'round_bankers'], 'DOUBLE', ['DOUBLE', 'INT'], ''], - [['round_bankers', 'round_bankers'], 'DECIMALV2', ['DECIMALV2', 'INT'], ''], [['round_bankers'], 'DECIMAL32', ['DECIMAL32', 'INT'], ''], [['round_bankers'], 'DECIMAL64', ['DECIMAL64', 'INT'], ''], [['round_bankers'], 'DECIMAL128', ['DECIMAL128', 'INT'], ''], - [['floor', 'dfloor'], 'DECIMALV2', ['DECIMALV2', 'INT'], ''], [['floor', 'dfloor'], 'DECIMAL32', ['DECIMAL32', 'INT'], ''], [['floor', 'dfloor'], 'DECIMAL64', ['DECIMAL64', 'INT'], ''], [['floor', 'dfloor'], 'DECIMAL128', ['DECIMAL128', 'INT'], ''], - [['ceil', 'dceil'], 'DECIMALV2', ['DECIMALV2', 'INT'], ''], [['ceil', 'dceil'], 'DECIMAL32', ['DECIMAL32', 'INT'], ''], [['ceil', 'dceil'], 'DECIMAL64', ['DECIMAL64', 'INT'], ''], [['ceil', 'dceil'], 'DECIMAL128', ['DECIMAL128', 'INT'], ''], [['truncate'], 'DOUBLE', ['DOUBLE', 'INT'], ''], - [['truncate'], 'DECIMALV2', ['DECIMALV2', 'INT'], ''], [['truncate'], 'DECIMAL32', ['DECIMAL32', 'INT'], ''], [['truncate'], 'DECIMAL64', ['DECIMAL64', 'INT'], ''], [['truncate'], 'DECIMAL128', ['DECIMAL128', 'INT'], ''], diff --git a/regression-test/data/query_p0/sql_functions/math_functions/test_round.out b/regression-test/data/query_p0/sql_functions/math_functions/test_round.out index d6a4c290e7..09f7105980 100644 --- a/regression-test/data/query_p0/sql_functions/math_functions/test_round.out +++ b/regression-test/data/query_p0/sql_functions/math_functions/test_round.out @@ -12,9 +12,9 @@ 10.12 -- !truncate -- -1 1989 1001 123.1 0.1 6.3 -2 1986 1001 1243.5 20.2 789.2 -3 1989 1002 24453.3 78945.0 3653.9 +1.0 1989.0 1001.0 123.1 0.1 6.3 +2.0 1986.0 1001.0 1243.5 20.2 789.2 +3.0 1989.0 1002.0 24453.3 78945.0 3654.0 -- !select -- 16 16 16 @@ -116,6 +116,12 @@ -- !query -- 0.000 0.000 0.000 +-- !query -- +16.02 + +-- !query -- +16.02 + -- !query -- 16.03 @@ -123,8 +129,14 @@ 16.02 -- !query -- -16.030000000 +16.03 -- !query -- -16.020000000 +16.02 + +-- !query -- +16.03 + +-- !query -- +16.02 diff --git a/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy b/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy index 71fb2677ec..a95bf4414c 100644 --- a/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy +++ b/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy @@ -145,6 +145,12 @@ sql """ DROP TABLE IF EXISTS `test_decimalv2` """ sql """ CREATE TABLE `test_decimalv2` ( id int, decimal_col DECIMAL(19,5)) ENGINE=OLAP duplicate KEY (id) DISTRIBUTED BY HASH(id) BUCKETS 1 PROPERTIES ( "replication_allocation" = "tag.location.default: 1"); """ sql """ insert into test_decimalv2 values (1, 16.025); """ + sql """ set round_precise_decimalv2_value=false; """ + qt_query """ select round(decimal_col,2) from test_decimalv2; """ + qt_query """ select truncate(decimal_col,2) from test_decimalv2; """ + qt_query """ select ceil(decimal_col,2) from test_decimalv2; """ + qt_query """ select floor(decimal_col,2) from test_decimalv2; """ + sql """ set round_precise_decimalv2_value=true; """ qt_query """ select round(decimal_col,2) from test_decimalv2; """ qt_query """ select truncate(decimal_col,2) from test_decimalv2; """ qt_query """ select ceil(decimal_col,2) from test_decimalv2; """