From 7206ca39eca8154ec2fd343ef8651523be769946 Mon Sep 17 00:00:00 2001 From: Jerry Hu Date: Thu, 2 Jan 2025 18:42:04 +0800 Subject: [PATCH] [fix](DECIMAL) error DECIMAL cat to BOOLEAN (#44326) (#46275) In the past, there were issues with converting `double` and `decimal` to `boolean`. For example, a `double` value like 0.13 would first be cast to `uint8`, resulting in 0. Then, it would be converted to `bool`, yielding 0 (incorrect, as the expected result is 1). Similarly, `decimal` values were directly cast to `uint8`, leading to non-0/1 values for `bool`. This issue arises because Doris internally uses `uint8` to represent `boolean`. before ``` mysql> select cast(40.123 as BOOLEAN); +-------------------------+ | cast(40.123 as BOOLEAN) | +-------------------------+ | 40 | +-------------------------+ ``` now ``` mysql> select cast(40.123 as BOOLEAN); +-------------------------+ | cast(40.123 as BOOLEAN) | +-------------------------+ | 1 | +-------------------------+ ``` ### What problem does this PR solve? pick #44326 Related PR: #[44326](https://github.com/mrhhsg/doris/tree/pick_44326) Problem Summary: --- be/src/vec/data_types/data_type_decimal.h | 19 ++++--- be/src/vec/functions/function_cast.h | 43 ++++++++++----- .../test_cast_decimalv3_as_bool.out | 17 ++++++ .../test_cast_decimalv3_as_bool.groovy | 55 +++++++++++++++++++ .../test_case_function_null.groovy | 18 ++++-- 5 files changed, 124 insertions(+), 28 deletions(-) create mode 100644 regression-test/data/correctness/test_cast_decimalv3_as_bool.out create mode 100644 regression-test/suites/correctness/test_cast_decimalv3_as_bool.groovy diff --git a/be/src/vec/data_types/data_type_decimal.h b/be/src/vec/data_types/data_type_decimal.h index d95750c9f2..4880b2ceb8 100644 --- a/be/src/vec/data_types/data_type_decimal.h +++ b/be/src/vec/data_types/data_type_decimal.h @@ -509,15 +509,20 @@ void convert_from_decimals(RealTo* dst, const RealFrom* src, UInt32 precicion_fr MaxFieldType multiplier = DataTypeDecimal::get_scale_multiplier(scale_from); FromDataType from_data_type(precicion_from, scale_from); for (size_t i = 0; i < size; i++) { - auto tmp = static_cast(src[i]).value / multiplier.value; - if constexpr (narrow_integral) { - if (tmp < min_result.value || tmp > max_result.value) { - THROW_DECIMAL_CONVERT_OVERFLOW_EXCEPTION(from_data_type.to_string(src[i]), - from_data_type.get_name(), - OrigToDataType {}.get_name()); + // uint8_t now use as boolean in doris + if constexpr (std::is_same_v) { + dst[i] = static_cast(src[i]).value != 0; + } else { + auto tmp = static_cast(src[i]).value / multiplier.value; + if constexpr (narrow_integral) { + if (tmp < min_result.value || tmp > max_result.value) { + THROW_DECIMAL_CONVERT_OVERFLOW_EXCEPTION(from_data_type.to_string(src[i]), + from_data_type.get_name(), + OrigToDataType {}.get_name()); + } } + dst[i] = tmp; } - dst[i] = tmp; } } diff --git a/be/src/vec/functions/function_cast.h b/be/src/vec/functions/function_cast.h index 68b1eb85f2..cf3ea7b079 100644 --- a/be/src/vec/functions/function_cast.h +++ b/be/src/vec/functions/function_cast.h @@ -256,6 +256,21 @@ struct ConvertImpl { using FromFieldType = typename FromDataType::FieldType; using ToFieldType = typename ToDataType::FieldType; + // `static_cast_set` is introduced to wrap `static_cast` and handle special cases. + // Doris uses `uint8` to represent boolean values internally. + // Directly `static_cast` to `uint8` can result in non-0/1 values, + // To address this, `static_cast_set` performs an additional check: + // For `uint8` types, it explicitly uses `static_cast` to ensure + // the result is either 0 or 1. + static void static_cast_set(ToFieldType& to, const FromFieldType& from) { + // uint8_t now use as boolean in doris + if constexpr (std::is_same_v) { + to = static_cast(from); + } else { + to = static_cast(from); + } + } + template static Status execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count, @@ -375,8 +390,9 @@ struct ConvertImpl { } else if constexpr (IsDateTimeV2Type) { DataTypeDateTimeV2::cast_from_date(vec_from[i], vec_to[i]); } else { - vec_to[i] = - reinterpret_cast(vec_from[i]).to_int64(); + static_cast_set( + vec_to[i], + reinterpret_cast(vec_from[i]).to_int64()); } } } else if constexpr (IsTimeV2Type) { @@ -407,13 +423,16 @@ struct ConvertImpl { } } else { if constexpr (IsDateTimeV2Type) { - vec_to[i] = reinterpret_cast&>( - vec_from[i]) - .to_int64(); + static_cast_set( + vec_to[i], + reinterpret_cast&>( + vec_from[i]) + .to_int64()); } else { - vec_to[i] = reinterpret_cast&>( - vec_from[i]) - .to_int64(); + static_cast_set(vec_to[i], + reinterpret_cast&>( + vec_from[i]) + .to_int64()); } } } @@ -435,16 +454,10 @@ struct ConvertImpl { return Status::OK(); } else { for (size_t i = 0; i < size; ++i) { - vec_to[i] = static_cast(vec_from[i]); + static_cast_set(vec_to[i], vec_from[i]); } } } - // TODO: support boolean cast more reasonable - if constexpr (std::is_same_v) { - for (int i = 0; i < size; ++i) { - vec_to[i] = static_cast(vec_to[i]); - } - } block.replace_by_position(result, std::move(col_to)); } else { diff --git a/regression-test/data/correctness/test_cast_decimalv3_as_bool.out b/regression-test/data/correctness/test_cast_decimalv3_as_bool.out new file mode 100644 index 0000000000..4f41130b00 --- /dev/null +++ b/regression-test/data/correctness/test_cast_decimalv3_as_bool.out @@ -0,0 +1,17 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !select1 -- +0.000 13131.213132100 0E-16 +0.000 2131231.231000000 2.3323000E-9 +3.141 0E-9 123123.2131231231322130 + +-- !select2 -- +false true false +false true true +true false true + +-- !select3 -- +true 1 true false + +-- !select3 -- +true 1 true false + diff --git a/regression-test/suites/correctness/test_cast_decimalv3_as_bool.groovy b/regression-test/suites/correctness/test_cast_decimalv3_as_bool.groovy new file mode 100644 index 0000000000..768da49325 --- /dev/null +++ b/regression-test/suites/correctness/test_cast_decimalv3_as_bool.groovy @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_cast_decimalv3_as_bool") { + sql """ DROP TABLE IF EXISTS cast_decimalv3_as_bool """ + sql """ + CREATE TABLE IF NOT EXISTS cast_decimalv3_as_bool ( + `id` int(11) , + `k1` decimalv3(9,3) , + `k2` decimalv3(18,9) , + `k3` decimalv3(38,16) , + ) + UNIQUE KEY(`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 10 + PROPERTIES ( + "enable_unique_key_merge_on_write" = "true", + "replication_num" = "1" + ); + """ + sql """ + set enable_nereids_planner=true,enable_fold_constant_by_be = false + """ + sql """ + INSERT INTO cast_decimalv3_as_bool VALUES + (1,0.00001,13131.2131321,0.000000000000000000), + (2,0.00000,2131231.231,0.0000000023323), + (3,3.141414,0.0000000000,123123.213123123132213); + """ + qt_select1 """ + select k1,k2,k3 from cast_decimalv3_as_bool order by id + """ + qt_select2 """ + select cast(k1 as boolean), cast(k2 as boolean) , cast(k3 as boolean) from cast_decimalv3_as_bool order by id + """ + qt_select3""" + select cast(3.00001 as boolean), cast(cast(3.00001 as boolean) as int),cast(0.001 as boolean),cast(0.000 as boolean); + """ + qt_select3""" + select cast(cast(3.00001 as double)as boolean), cast(cast(cast(3.00001 as double) as boolean) as int),cast(cast(0.001 as double) as boolean),cast(cast(0.000 as double) as boolean); + """ +} \ No newline at end of file diff --git a/regression-test/suites/query_p0/sql_functions/case_function/test_case_function_null.groovy b/regression-test/suites/query_p0/sql_functions/case_function/test_case_function_null.groovy index 5138db6e73..a91c86b5f4 100644 --- a/regression-test/suites/query_p0/sql_functions/case_function/test_case_function_null.groovy +++ b/regression-test/suites/query_p0/sql_functions/case_function/test_case_function_null.groovy @@ -185,10 +185,11 @@ suite("test_case_function_null", "query,p0,arrow_flight_sql") { c2, c1; """ - + // There is a behavior change. The 0.4cast boolean used to be 0 in the past, but now it has changed to 1. + // Therefore, we need to update the case accordingly. qt_sql_case1 """ SELECT SUM( - CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.4716 AS BOOLEAN))) + CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.0 AS BOOLEAN))) WHEN ((CAST('-1530390546' AS VARCHAR)) LIKE ('-1678299490')) THEN (- (+ case_null2.c0)) WHEN CASE (NULL IN (NULL)) @@ -197,9 +198,10 @@ suite("test_case_function_null", "query,p0,arrow_flight_sql") { END) FROM case_null2; """ - + // There is a behavior change. The 0.4cast boolean used to be 0 in the past, but now it has changed to 1. + // Therefore, we need to update the case accordingly. qt_sql_case2 """ - SELECT SUM(CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.4716 AS BOOLEAN))) + SELECT SUM(CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.0 AS BOOLEAN))) WHEN ((CAST('-1530390546' AS VARCHAR)) LIKE ('-1678299490')) THEN (- (+ case_null2.c0)) END) @@ -209,9 +211,11 @@ suite("test_case_function_null", "query,p0,arrow_flight_sql") { sql "SET experimental_enable_nereids_planner=true" sql "SET enable_fallback_to_original_planner=false" + // There is a behavior change. The 0.4cast boolean used to be 0 in the past, but now it has changed to 1. + // Therefore, we need to update the case accordingly. qt_sql_case1 """ SELECT SUM( - CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.4716 AS BOOLEAN))) + CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.0 AS BOOLEAN))) WHEN ((CAST('-1530390546' AS VARCHAR)) LIKE ('-1678299490')) THEN (- (+ case_null2.c0)) WHEN CASE (NULL IN (NULL)) @@ -221,8 +225,10 @@ suite("test_case_function_null", "query,p0,arrow_flight_sql") { FROM case_null2; """ + // There is a behavior change. The 0.4cast boolean used to be 0 in the past, but now it has changed to 1. + // Therefore, we need to update the case accordingly. qt_sql_case2 """ - SELECT SUM(CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.4716 AS BOOLEAN))) + SELECT SUM(CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.0 AS BOOLEAN))) WHEN ((CAST('-1530390546' AS VARCHAR)) LIKE ('-1678299490')) THEN (- (+ case_null2.c0)) END)