[fix](round) Fix incorrect decimal scale inference in round functions (#34471)

* FIX NEEDED

* FORMAT

* FORMAT

* FIX TEST
This commit is contained in:
zhiqiang
2024-05-10 16:09:46 +08:00
committed by yiguolei
parent 0a79c547ff
commit 58c19e33b3
4 changed files with 237 additions and 42 deletions

View File

@ -21,13 +21,17 @@
#pragma once
#include <cstddef>
#include <memory>
#include "common/exception.h"
#include "common/status.h"
#include "vec/columns/column_const.h"
#include "vec/columns/columns_number.h"
#include "vec/common/assert_cast.h"
#include "vec/core/column_with_type_and_name.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/functions/function.h"
#if defined(__SSE4_1__) || defined(__aarch64__)
#include "util/sse_util.hpp"
@ -430,7 +434,10 @@ struct Dispatcher {
FloatRoundingImpl<T, rounding_mode, scale_mode, tie_breaking_mode>,
IntegerRoundingImpl<T, rounding_mode, scale_mode, tie_breaking_mode>>>;
static ColumnPtr apply_vec_const(const IColumn* col_general, Int16 scale_arg) {
// scale_arg: scale for function computation
// result_scale: scale for result decimal, this scale is got from planner
static ColumnPtr apply_vec_const(const IColumn* col_general, const Int16 scale_arg,
[[maybe_unused]] Int16 result_scale) {
if constexpr (IsNumber<T>) {
const auto* const col = check_and_get_column<ColumnVector<T>>(col_general);
auto col_res = ColumnVector<T>::create();
@ -457,10 +464,7 @@ struct Dispatcher {
} else if constexpr (IsDecimalNumber<T>) {
const auto* const decimal_col = check_and_get_column<ColumnDecimal<T>>(col_general);
const auto& vec_src = decimal_col->get_data();
UInt32 result_scale =
std::min(static_cast<UInt32>(std::max(scale_arg, static_cast<Int16>(0))),
decimal_col->get_scale());
const size_t input_rows_count = vec_src.size();
auto col_res = ColumnDecimal<T>::create(vec_src.size(), result_scale);
auto& vec_res = col_res->get_data();
@ -468,6 +472,27 @@ struct Dispatcher {
FunctionRoundingImpl<ScaleMode::Negative>::apply(
decimal_col->get_data(), decimal_col->get_scale(), vec_res, scale_arg);
}
// We need to always make sure result decimal's scale is as expected as its in plan
// So we need to append enough zero to result.
// Case 0: scale_arg <= -(integer part digits count)
// do nothing, because result is 0
// Case 1: scale_arg <= 0 && scale_arg > -(integer part digits count)
// decimal parts has been erased, so add them back by multiply 10^(result_scale)
// Case 2: scale_arg > 0 && scale_arg < result_scale
// decimal part now has scale_arg digits, so multiply 10^(result_scale - scal_arg)
// Case 3: scale_arg >= input_scale
// do nothing
if (scale_arg <= 0) {
for (size_t i = 0; i < input_rows_count; ++i) {
vec_res[i].value *= int_exp10(result_scale);
}
} else if (scale_arg > 0 && scale_arg < result_scale) {
for (size_t i = 0; i < input_rows_count; ++i) {
vec_res[i].value *= int_exp10(result_scale - scale_arg);
}
}
return col_res;
} else {
@ -477,7 +502,9 @@ struct Dispatcher {
}
}
static ColumnPtr apply_vec_vec(const IColumn* col_general, const IColumn* col_scale) {
// result_scale: scale for result decimal, this scale is got from planner
static ColumnPtr apply_vec_vec(const IColumn* col_general, const IColumn* col_scale,
[[maybe_unused]] Int16 result_scale) {
const auto& col_scale_i32 = assert_cast<const ColumnInt32&>(*col_scale);
const size_t input_row_count = col_scale_i32.size();
for (size_t i = 0; i < input_row_count; ++i) {
@ -515,10 +542,8 @@ struct Dispatcher {
return col_res;
} else if constexpr (IsDecimalNumber<T>) {
const auto* decimal_col = assert_cast<const ColumnDecimal<T>*>(col_general);
// ALWAYS use SAME scale with source Decimal column
const Int32 input_scale = decimal_col->get_scale();
auto col_res = ColumnDecimal<T>::create(input_row_count, input_scale);
auto col_res = ColumnDecimal<T>::create(input_row_count, result_scale);
for (size_t i = 0; i < input_row_count; ++i) {
DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>::apply(
@ -534,15 +559,15 @@ struct Dispatcher {
// do nothing, because result is 0
// Case 1: scale_arg <= 0 && scale_arg > -(integer part digits count)
// decimal parts has been erased, so add them back by multiply 10^(scale_arg)
// Case 2: scale_arg > 0 && scale_arg < decimal part digits count
// decimal part now has scale_arg digits, so multiply 10^(input_scale - scal_arg)
// Case 2: scale_arg > 0 && scale_arg < result_scale
// decimal part now has scale_arg digits, so multiply 10^(result_scale - scal_arg)
// Case 3: scale_arg >= input_scale
// do nothing
const Int32 scale_arg = col_scale_i32.get_data()[i];
if (scale_arg <= 0) {
col_res->get_element(i).value *= int_exp10(input_scale);
} else if (scale_arg > 0 && scale_arg < input_scale) {
col_res->get_element(i).value *= int_exp10(input_scale - scale_arg);
col_res->get_element(i).value *= int_exp10(result_scale);
} else if (scale_arg > 0 && scale_arg < result_scale) {
col_res->get_element(i).value *= int_exp10(result_scale - scale_arg);
}
}
@ -554,8 +579,9 @@ struct Dispatcher {
}
}
static ColumnPtr apply_const_vec(const ColumnConst* const_col_general,
const IColumn* col_scale) {
// result_scale: scale for result decimal, this scale is got from planner
static ColumnPtr apply_const_vec(const ColumnConst* const_col_general, const IColumn* col_scale,
[[maybe_unused]] Int16 result_scale) {
const auto& col_scale_i32 = assert_cast<const ColumnInt32&>(*col_scale);
const size_t input_rows_count = col_scale->size();
@ -575,8 +601,7 @@ struct Dispatcher {
assert_cast<const ColumnDecimal<T>&>(const_col_general->get_data_column());
const T& general_val = data_col_general.get_data()[0];
Int32 input_scale = data_col_general.get_scale();
auto col_res = ColumnDecimal<T>::create(input_rows_count, input_scale);
auto col_res = ColumnDecimal<T>::create(input_rows_count, result_scale);
for (size_t i = 0; i < input_rows_count; ++i) {
DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>::apply(
@ -592,15 +617,15 @@ struct Dispatcher {
// do nothing, because result is 0
// Case 1: scale_arg <= 0 && scale_arg > -(integer part digits count)
// decimal parts has been erased, so add them back by multiply 10^(scale_arg)
// Case 2: scale_arg > 0 && scale_arg < decimal part digits count
// decimal part now has scale_arg digits, so multiply 10^(input_scale - scal_arg)
// Case 2: scale_arg > 0 && scale_arg < result_scale
// decimal part now has scale_arg digits, so multiply 10^(result_scale - scal_arg)
// Case 3: scale_arg >= input_scale
// do nothing
const Int32 scale_arg = col_scale_i32.get_data()[i];
if (scale_arg <= 0) {
col_res->get_element(i).value *= int_exp10(input_scale);
} else if (scale_arg > 0 && scale_arg < input_scale) {
col_res->get_element(i).value *= int_exp10(input_scale - scale_arg);
col_res->get_element(i).value *= int_exp10(result_scale);
} else if (scale_arg > 0 && scale_arg < result_scale) {
col_res->get_element(i).value *= int_exp10(result_scale - scale_arg);
}
}
@ -679,26 +704,23 @@ public:
return Status::OK();
}
/// SELECT number, truncate(123.345, 1) FROM number("numbers"="10")
/// should NOT behave like two column arguments, so we can not use const column default implementation
bool use_default_implementation_for_constants() const override { return false; }
bool use_default_implementation_for_constants() const override { return true; }
//// We moved and optimized the execute_impl logic of function_truncate.h from PR#32746,
//// as well as make it suitable for all functions.
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) const override {
const ColumnWithTypeAndName& column_general = block.get_by_position(arguments[0]);
ColumnWithTypeAndName& column_result = block.get_by_position(result);
const DataTypePtr result_type = block.get_by_position(result).type;
const bool is_col_general_const = is_column_const(*column_general.column);
const auto* col_general = is_col_general_const
? assert_cast<const ColumnConst&>(*column_general.column)
.get_data_column_ptr()
: column_general.column.get();
ColumnPtr res;
/// potential argument types:
/// if the SECOND argument is MISSING(would be considered as ZERO const) or CONST, then we have the following type:
/// 1. func(Column), func(ColumnConst), func(Column, ColumnConst), func(ColumnConst, ColumnConst)
/// 1. func(Column), func(Column, ColumnConst)
/// otherwise, the SECOND arugment is COLUMN, we have another type:
/// 2. func(Column, Column), func(ColumnConst, Column)
@ -706,6 +728,23 @@ public:
using Types = std::decay_t<decltype(types)>;
using DataType = typename Types::LeftType;
// For decimal, we will always make sure result Decimal has exactly same precision and scale with
// arguments from query plan.
Int16 result_scale = 0;
if constexpr (IsDataTypeDecimal<DataType>) {
if (column_result.type->get_type_id() == TypeIndex::Nullable) {
if (auto nullable_type = std::dynamic_pointer_cast<const DataTypeNullable>(
column_result.type)) {
result_scale = nullable_type->get_nested_type()->get_scale();
} else {
throw doris::Exception(ErrorCode::INTERNAL_ERROR,
"Illegal nullable column");
}
} else {
result_scale = column_result.type->get_scale();
}
}
if constexpr (IsDataTypeNumber<DataType> || IsDataTypeDecimal<DataType>) {
using FieldType = typename DataType::FieldType;
if (arguments.size() == 1 ||
@ -718,23 +757,20 @@ public:
}
res = Dispatcher<FieldType, rounding_mode, tie_breaking_mode>::apply_vec_const(
col_general, scale_arg);
if (is_col_general_const) {
// Important, make sure the result column has the same size as the input column
res = ColumnConst::create(std::move(res), input_rows_count);
}
col_general, scale_arg, result_scale);
} else {
// the SECOND arugment is COLUMN
if (is_col_general_const) {
res = Dispatcher<FieldType, rounding_mode, tie_breaking_mode>::
apply_const_vec(
&assert_cast<const ColumnConst&>(*column_general.column),
block.get_by_position(arguments[1]).column.get());
block.get_by_position(arguments[1]).column.get(),
result_scale);
} else {
res = Dispatcher<FieldType, rounding_mode, tie_breaking_mode>::
apply_vec_vec(col_general,
block.get_by_position(arguments[1]).column.get());
block.get_by_position(arguments[1]).column.get(),
result_scale);
}
}
return true;
@ -758,7 +794,7 @@ public:
column_general.type->get_name(), name);
}
block.replace_by_position(result, std::move(res));
column_result.column = std::move(res);
return Status::OK();
}
};

View File

@ -37,9 +37,12 @@ public interface ComputePrecisionForRound extends ComputePrecision {
Expression floatLength = getArgument(1);
int scale;
if (floatLength.isLiteral() || (floatLength instanceof Cast && floatLength.child(0).isLiteral()
// If scale arg is an integer literal, or it is a cast(Integer as Integer)
// then we will try to use its value as result scale
// In any other cases, we will make sure result decimal has same scale with input.
if ((floatLength.isLiteral() && floatLength.getDataType() instanceof Int32OrLessType)
|| (floatLength instanceof Cast && floatLength.child(0).isLiteral()
&& floatLength.child(0).getDataType() instanceof Int32OrLessType)) {
// Scale argument is a literal or cast from other literal
if (floatLength instanceof Cast) {
scale = ((IntegerLikeLiteral) floatLength.child(0)).getIntValue();
} else {

View File

@ -1,4 +1,115 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !select --
123.100
-- !select --
123.100
123.100
123.100
123.100
123.100
123.100
123.100
123.100
123.100
123.100
-- !select --
120.000
120.000
120.000
120.000
120.000
120.000
120.000
120.000
120.000
120.000
-- !select --
123.100
123.100
123.100
123.100
123.100
123.100
123.100
123.100
123.100
123.100
-- !select --
120.000
120.000
120.000
120.000
120.000
120.000
120.000
120.000
120.000
120.000
-- !select --
123.200
123.200
123.200
123.200
123.200
123.200
123.200
123.200
123.200
123.200
-- !select --
130.000
130.000
130.000
130.000
130.000
130.000
130.000
130.000
130.000
130.000
-- !select --
123.100
123.100
123.100
123.100
123.100
123.100
123.100
123.100
123.100
123.100
-- !select --
120.000
120.000
120.000
120.000
120.000
120.000
120.000
120.000
120.000
120.000
-- !select --
4434.41
-- !select --
0
-- !select --
false \N 4434
-- !select --
0
-- !select --
10
@ -97,6 +208,18 @@
-- !select --
16.025 16.02500 16.02500
-- !select_fix --
16.025 16.02500 16.02500
-- !select_fix --
16.025 16.02500 16.02500
-- !select_fix --
16.025 16.02500 16.02500
-- !select_fix --
16.025 16.02500 16.02500
-- !nereids_round_arg1 --
10

View File

@ -15,7 +15,35 @@
// specific language governing permissions and limitations
// under the License.
suite("test_round") {
suite("test_round") {
sql "set enable_fold_constant_by_be=false;"
sql "SET enable_nereids_planner=true"
sql "SET enable_fallback_to_original_planner=false"
qt_select "SELECT round(123.123, 1.123);"
qt_select """SELECT round(123.123, 1.123) FROM numbers("number"="10");"""
qt_select """SELECT round(123.123, -1.123) FROM numbers("number"="10");"""
qt_select """SELECT truncate(123.123, 1.123) FROM numbers("number"="10");"""
qt_select """SELECT truncate(123.123, -1.123) FROM numbers("number"="10");"""
qt_select """SELECT ceil(123.123, 1.123) FROM numbers("number"="10");"""
qt_select """SELECT ceil(123.123, -1.123) FROM numbers("number"="10");"""
qt_select """SELECT round_bankers(123.123, 1.123) FROM numbers("number"="10");"""
qt_select """SELECT round_bankers(123.123, -1.123) FROM numbers("number"="10");"""
sql """drop table if exists test_round_1; """
sql """
create table test_round_1(big_key bigint not NULL)
DISTRIBUTED BY HASH(big_key) BUCKETS 1 PROPERTIES ("replication_num" = "1");
"""
qt_select """SELECT truncate(cast(round(8990.65 - 4556.2354, 2.4652) as Decimal(9,4)), 2);"""
qt_select """SELECT cast(round(round(465.56,min(-5.987)),2) as DECIMAL)"""
qt_select """
SELECT truncate(100,2)<-2308.57 , cast(round(round(465.56,min(-5.987)),2) as DECIMAL) , cast(truncate(round(8990.65-4556.2354,2.4652),2)as DECIMAL) from test_round_1;
"""
qt_select """
SELECT truncate(123456789.123456789, -9);
"""
qt_select "SELECT round(10.12345)"
qt_select "SELECT round(10.12345, 2)"
qt_select "SELECT round_bankers(10.12345)"
@ -62,6 +90,11 @@
qt_select """ SELECT truncate(col1, 7), truncate(col2, 7), truncate(col3, 7) FROM `${tableName}`; """
qt_select """ SELECT round_bankers(col1, 7), round_bankers(col2, 7), round_bankers(col3, 7) FROM `${tableName}`; """
qt_select_fix """ SELECT round(col1, 6.234), round(col2, 6.234), round(col3, 6.234) FROM `${tableName}`; """
qt_select_fix """ SELECT floor(col1, 6.234), floor(col2, 6.234), floor(col3, 6.234) FROM `${tableName}`; """
qt_select_fix """ SELECT truncate(col1, 6.234), truncate(col2, 6.234), truncate(col3, 6.234) FROM `${tableName}`; """
qt_select_fix """ SELECT round_bankers(col1, 6.234), round_bankers(col2, 6.234), round_bankers(col3, 6.234) FROM `${tableName}`; """
sql """ DROP TABLE IF EXISTS `${tableName}` """
sql "SET enable_nereids_planner=true"