From 3b462424834944e489b429374084030bd34e6955 Mon Sep 17 00:00:00 2001 From: Gabriel Date: Thu, 14 Jul 2022 10:50:50 +0800 Subject: [PATCH] [feature-wip] Optimize Decimal type (#10794) * [feature-wip](decimalv3) support decimalv3 * [feature-wip] Optimize Decimal type Co-authored-by: liaoxin --- be/src/common/config.h | 2 + be/src/common/consts.h | 4 + be/src/exec/olap_common.cpp | 12 - be/src/exec/olap_common.h | 132 +++++++--- be/src/exec/tablet_sink.cpp | 2 + be/src/exprs/anyval_util.cpp | 24 ++ be/src/exprs/anyval_util.h | 12 + be/src/exprs/binary_predicate.cpp | 21 ++ be/src/exprs/cast_functions.cpp | 249 ++++++++++++++++++ be/src/exprs/cast_functions.h | 21 ++ be/src/exprs/create_predicate_function.h | 21 +- be/src/exprs/expr.cpp | 24 ++ be/src/exprs/expr.h | 106 +++++--- be/src/exprs/expr_context.cpp | 26 +- be/src/exprs/expr_context.h | 2 +- be/src/exprs/hybrid_set.h | 40 +-- be/src/exprs/literal.cpp | 50 +++- be/src/exprs/literal.h | 3 + be/src/exprs/runtime_filter.cpp | 188 ++++++++++--- be/src/exprs/scalar_fn_call.cpp | 33 +++ be/src/exprs/scalar_fn_call.h | 4 + be/src/exprs/slot_ref.cpp | 30 +++ be/src/exprs/slot_ref.h | 3 + be/src/olap/aggregate_func.cpp | 24 ++ be/src/olap/bloom_filter_predicate.cpp | 5 +- be/src/olap/column_vector.cpp | 15 ++ be/src/olap/delete_handler.cpp | 6 + be/src/olap/field.h | 38 ++- be/src/olap/key_coder.cpp | 3 + be/src/olap/olap_common.h | 6 +- be/src/olap/olap_cond.cpp | 4 +- be/src/olap/reader.cpp | 81 ++++++ be/src/olap/row_block2.cpp | 33 +++ be/src/olap/row_cursor.cpp | 3 +- .../rowset/segment_v2/bitmap_index_writer.cpp | 9 + .../olap/rowset/segment_v2/column_reader.cpp | 2 +- be/src/olap/rowset/segment_v2/column_reader.h | 7 +- .../olap/rowset/segment_v2/encoding_info.cpp | 12 + be/src/olap/rowset/segment_v2/segment.cpp | 3 +- .../rowset/segment_v2/segment_iterator.cpp | 2 + be/src/olap/schema.cpp | 6 + be/src/olap/schema_change.cpp | 5 +- be/src/olap/tablet_meta.cpp | 7 +- be/src/olap/tablet_schema.cpp | 21 ++ be/src/olap/types.cpp | 10 + be/src/olap/types.h | 153 +++++++++-- be/src/olap/wrapper_field.h | 5 +- be/src/runtime/decimalv2_value.cpp | 3 +- be/src/runtime/decimalv2_value.h | 14 + be/src/runtime/primitive_type.cpp | 51 ++++ be/src/runtime/primitive_type.h | 18 ++ be/src/runtime/raw_value.cpp | 77 ++++++ be/src/runtime/raw_value.h | 42 +++ be/src/runtime/types.cpp | 27 +- be/src/runtime/types.h | 16 +- be/src/udf/udf.h | 87 +++++- be/src/util/string_parser.cpp | 16 ++ be/src/util/string_parser.hpp | 73 +---- be/src/util/symbols_util.cpp | 9 + .../aggregate_function_avg.h | 16 +- .../aggregate_function_min_max.cpp | 12 +- .../aggregate_function_min_max.h | 29 +- .../aggregate_function_min_max_by.cpp | 27 +- .../aggregate_function_stddev.cpp | 19 +- .../aggregate_function_stddev.h | 17 +- be/src/vec/columns/column.h | 3 + be/src/vec/columns/column_decimal.cpp | 33 +++ be/src/vec/columns/column_decimal.h | 15 +- be/src/vec/columns/column_nullable.h | 2 + be/src/vec/data_types/data_type_decimal.cpp | 111 +++++++- be/src/vec/data_types/data_type_decimal.h | 79 +++++- be/src/vec/data_types/data_type_factory.cpp | 19 +- be/src/vec/data_types/data_type_factory.hpp | 12 +- be/src/vec/exec/join/vhash_join_node.cpp | 19 +- be/src/vec/exec/vaggregation_node.cpp | 20 +- be/src/vec/exec/volap_scan_node.cpp | 44 +++- be/src/vec/exec/vschema_scan_node.cpp | 9 +- be/src/vec/exec/vset_operation_node.cpp | 3 + be/src/vec/exprs/vliteral.cpp | 53 ++++ be/src/vec/functions/function.h | 7 + .../functions/function_binary_arithmetic.h | 119 ++++++++- be/src/vec/functions/function_cast.h | 52 ++-- be/src/vec/io/io_helper.h | 17 +- be/src/vec/olap/olap_data_convertor.cpp | 14 +- be/src/vec/olap/olap_data_convertor.h | 29 +- be/src/vec/runtime/vfile_result_writer.cpp | 12 + be/src/vec/sink/vmysql_result_writer.cpp | 102 ++++++- be/src/vec/sink/vmysql_table_writer.cpp | 7 + be/test/exprs/runtime_filter_test.cpp | 2 +- .../segment_v2/column_reader_writer_test.cpp | 16 +- fe/fe-core/src/main/cup/sql_parser.cup | 10 +- .../doris/analysis/AggregateInfoBase.java | 2 +- .../apache/doris/analysis/ArithmeticExpr.java | 231 ++++++++++++---- .../doris/analysis/BinaryPredicate.java | 3 + .../analysis/BuiltinAggregateFunction.java | 4 +- .../org/apache/doris/analysis/CastExpr.java | 24 +- .../org/apache/doris/analysis/ColumnDef.java | 3 + .../doris/analysis/CreateFunctionStmt.java | 16 ++ .../apache/doris/analysis/DecimalLiteral.java | 54 +++- .../java/org/apache/doris/analysis/Expr.java | 11 +- .../doris/analysis/ExpressionFunctions.java | 2 +- .../apache/doris/analysis/FloatLiteral.java | 4 +- .../doris/analysis/FunctionCallExpr.java | 43 ++- .../org/apache/doris/analysis/IndexDef.java | 4 +- .../doris/analysis/LargeIntLiteral.java | 2 +- .../apache/doris/analysis/LiteralExpr.java | 3 + .../analysis/MVColumnHLLUnionPattern.java | 2 +- .../apache/doris/analysis/OutFileClause.java | 6 + .../apache/doris/analysis/StringLiteral.java | 3 + .../org/apache/doris/analysis/TypeDef.java | 51 +++- .../doris/catalog/AggregateFunction.java | 4 +- .../java/org/apache/doris/catalog/Column.java | 7 +- .../org/apache/doris/catalog/ColumnType.java | 21 ++ .../org/apache/doris/catalog/Function.java | 24 +- .../org/apache/doris/catalog/FunctionSet.java | 136 +++++++++- .../catalog/HiveMetaStoreClientHelper.java | 5 +- .../apache/doris/catalog/PrimitiveType.java | 209 ++++++++++++++- .../apache/doris/catalog/ScalarFunction.java | 10 +- .../org/apache/doris/catalog/ScalarType.java | 168 +++++++++--- .../java/org/apache/doris/catalog/Type.java | 210 ++++++++++++++- .../java/org/apache/doris/common/Config.java | 7 + .../org/apache/doris/common/util/Util.java | 3 + .../doris/datasource/InternalDataSource.java | 5 +- .../doris/external/hive/util/HiveUtil.java | 3 +- .../iceberg/util/DorisTypeToType.java | 3 +- .../iceberg/util/TypeToDorisType.java | 2 +- .../doris/httpv2/rest/TableSchemaAction.java | 2 +- .../load/loadv2/SparkLoadPendingTask.java | 2 +- .../apache/doris/planner/OriginalPlanner.java | 6 +- .../apache/doris/qe/cache/PartitionRange.java | 3 + .../org/apache/doris/rewrite/FEFunctions.java | 27 ++ .../rewrite/RewriteBinaryPredicatesRule.java | 60 ++++- .../doris/rewrite/RewriteInPredicateRule.java | 4 +- .../apache/doris/statistics/ColumnStats.java | 3 + .../doris/task/HadoopLoadPendingTask.java | 5 +- .../doris/analysis/ArithmeticExprTest.java | 175 ++++++++++++ .../CreateMaterializedViewStmtTest.java | 2 +- .../doris/analysis/DecimalLiteralTest.java | 32 +++ .../doris/analysis/FunctionCallExprTest.java | 82 ++++++ .../apache/doris/backup/CatalogMocker.java | 2 +- .../catalog/ColumnGsonSerializationTest.java | 2 +- .../apache/doris/catalog/ColumnTypeTest.java | 4 +- .../java/org/apache/doris/udf/UdfUtils.java | 2 +- .../doris/load/loadv2/dpp/ColumnParser.java | 5 +- .../doris/load/loadv2/dpp/DppUtils.java | 6 + .../doris/load/loadv2/dpp/SparkDpp.java | 3 + .../load/loadv2/dpp/SparkRDDAggregator.java | 9 + gensrc/proto/internal_service.proto | 5 + gensrc/thrift/Types.thrift | 3 + 149 files changed, 4011 insertions(+), 549 deletions(-) create mode 100644 fe/fe-core/src/test/java/org/apache/doris/analysis/ArithmeticExprTest.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/analysis/FunctionCallExprTest.java diff --git a/be/src/common/config.h b/be/src/common/config.h index a9ff1624f9..d1c377edf4 100644 --- a/be/src/common/config.h +++ b/be/src/common/config.h @@ -773,6 +773,8 @@ CONF_Bool(parquet_predicate_push_down, "false"); // if it is lower than a specific threshold, the predicate will be disabled. CONF_mInt32(bloom_filter_predicate_check_row_num, "1000"); +CONF_Bool(enable_decimalv3, "false"); + //whether turn on quick compaction feature CONF_Bool(enable_quick_compaction, "false"); // For continuous versions that rows less than quick_compaction_max_rows will trigger compaction quickly diff --git a/be/src/common/consts.h b/be/src/common/consts.h index 5a89973b6e..22aaae512c 100644 --- a/be/src/common/consts.h +++ b/be/src/common/consts.h @@ -24,5 +24,9 @@ namespace BeConsts { const std::string CSV = "csv"; const std::string CSV_WITH_NAMES = "csv_with_names"; const std::string CSV_WITH_NAMES_AND_TYPES = "csv_with_names_and_types"; + +constexpr int MAX_DECIMAL32_PRECISION = 9; +constexpr int MAX_DECIMAL64_PRECISION = 18; +constexpr int MAX_DECIMAL128_PRECISION = 38; } // namespace BeConsts } // namespace doris diff --git a/be/src/exec/olap_common.cpp b/be/src/exec/olap_common.cpp index 4760b5a151..8069c47a17 100644 --- a/be/src/exec/olap_common.cpp +++ b/be/src/exec/olap_common.cpp @@ -29,18 +29,6 @@ namespace doris { -template <> -std::string cast_to_string(__int128 value) { - std::stringstream ss; - ss << value; - return ss.str(); -} - -template <> -std::string cast_to_string(int8_t value) { - return std::to_string(static_cast(value)); -} - template <> void ColumnValueRange::convert_to_fixed_value() { return; diff --git a/be/src/exec/olap_common.h b/be/src/exec/olap_common.h index 0484eb93c0..abb856cef3 100644 --- a/be/src/exec/olap_common.h +++ b/be/src/exec/olap_common.h @@ -29,19 +29,35 @@ #include "olap/tuple.h" #include "runtime/primitive_type.h" #include "runtime/type_limit.h" +#include "vec/io/io_helper.h" namespace doris { -template -std::string cast_to_string(T value) { - return boost::lexical_cast(value); +template +std::string cast_to_string(T value, int scale) { + if constexpr (primitive_type == TYPE_DECIMAL32) { + std::stringstream ss; + vectorized::write_text((int32_t)value, scale, ss); + return ss.str(); + } else if constexpr (primitive_type == TYPE_DECIMAL64) { + std::stringstream ss; + vectorized::write_text((int64_t)value, scale, ss); + return ss.str(); + } else if constexpr (primitive_type == TYPE_DECIMAL128) { + std::stringstream ss; + vectorized::write_text((int128_t)value, scale, ss); + return ss.str(); + } else if constexpr (primitive_type == TYPE_TINYINT) { + return std::to_string(static_cast(value)); + } else if constexpr (primitive_type == TYPE_LARGEINT) { + std::stringstream ss; + ss << value; + return ss.str(); + } else { + return boost::lexical_cast(value); + } } -// TYPE_TINYINT should cast to int32_t to first -// because it need to convert to num not char for build Olap fetch Query -template <> -std::string cast_to_string(int8_t); - /** * @brief Column's value range **/ @@ -58,6 +74,11 @@ public: ColumnValueRange(std::string col_name, const CppType& min, const CppType& max, bool contain_null); + ColumnValueRange(std::string col_name, int precision, int scale); + + ColumnValueRange(std::string col_name, const CppType& min, const CppType& max, + bool contain_null, int precision, int scale); + // should add fixed value before add range Status add_fixed_value(const CppType& value); @@ -138,7 +159,8 @@ public: if (TYPE_MIN != _low_value || FILTER_LARGER_OR_EQUAL != _low_op) { low.__set_column_name(_column_name); low.__set_condition_op((_low_op == FILTER_LARGER_OR_EQUAL ? ">=" : ">>")); - low.condition_values.push_back(cast_to_string(_low_value)); + low.condition_values.push_back( + cast_to_string(_low_value, _scale)); } if (low.condition_values.size() != 0) { @@ -149,7 +171,8 @@ public: if (TYPE_MAX != _high_value || FILTER_LESS_OR_EQUAL != _high_op) { high.__set_column_name(_column_name); high.__set_condition_op((_high_op == FILTER_LESS_OR_EQUAL ? "<=" : "<<")); - high.condition_values.push_back(cast_to_string(_high_value)); + high.condition_values.push_back( + cast_to_string(_high_value, _scale)); } if (high.condition_values.size() != 0) { @@ -176,7 +199,8 @@ public: condition.__set_condition_op(is_in ? "*=" : "!*="); for (const auto& value : _fixed_values) { - condition.condition_values.push_back(cast_to_string(value)); + condition.condition_values.push_back( + cast_to_string(value, _scale)); } if (condition.condition_values.size() != 0) { @@ -214,6 +238,8 @@ public: _contain_null = contain_null; }; + const int scale() { return _scale; } + static void add_fixed_value_range(ColumnValueRange& range, CppType* value) { range.add_fixed_value(*value); } @@ -231,6 +257,18 @@ public: return ColumnValueRange(col_name, TYPE_MAX, TYPE_MIN, false); } + static ColumnValueRange create_empty_column_value_range(int precision, + int scale) { + return ColumnValueRange::create_empty_column_value_range("", precision, + scale); + } + + static ColumnValueRange create_empty_column_value_range( + const std::string& col_name, int precision, int scale) { + return ColumnValueRange(col_name, TYPE_MAX, TYPE_MIN, false, precision, + scale); + } + protected: bool is_in_range(const CppType& value); @@ -247,6 +285,8 @@ private: std::set _fixed_values; // Column's fixed int value bool _contain_null; + int _precision; + int _scale; }; class OlapScanKeys { @@ -314,13 +354,14 @@ private: bool _is_convertible; }; -typedef std::variant, ColumnValueRange, - ColumnValueRange, ColumnValueRange, - ColumnValueRange, ColumnValueRange, - ColumnValueRange, ColumnValueRange, - ColumnValueRange, ColumnValueRange, - ColumnValueRange, ColumnValueRange, - ColumnValueRange, ColumnValueRange> +typedef std::variant< + ColumnValueRange, ColumnValueRange, ColumnValueRange, + ColumnValueRange, ColumnValueRange, ColumnValueRange, + ColumnValueRange, ColumnValueRange, ColumnValueRange, + ColumnValueRange, ColumnValueRange, + ColumnValueRange, ColumnValueRange, + ColumnValueRange, ColumnValueRange, + ColumnValueRange, ColumnValueRange> ColumnValueRangeType; template @@ -333,7 +374,8 @@ const typename ColumnValueRange::CppType type_limit::CppType>::max(); template -ColumnValueRange::ColumnValueRange() : _column_type(INVALID_TYPE) {} +ColumnValueRange::ColumnValueRange() + : _column_type(INVALID_TYPE), _precision(-1), _scale(-1) {} template ColumnValueRange::ColumnValueRange(std::string col_name) @@ -348,7 +390,27 @@ ColumnValueRange::ColumnValueRange(std::string col_name, const C _high_value(max), _low_op(FILTER_LARGER_OR_EQUAL), _high_op(FILTER_LESS_OR_EQUAL), - _contain_null(contain_null) {} + _contain_null(contain_null), + _precision(-1), + _scale(-1) {} + +template +ColumnValueRange::ColumnValueRange(std::string col_name, const CppType& min, + const CppType& max, bool contain_null, + int precision, int scale) + : _column_name(std::move(col_name)), + _column_type(primitive_type), + _low_value(min), + _high_value(max), + _low_op(FILTER_LARGER_OR_EQUAL), + _high_op(FILTER_LESS_OR_EQUAL), + _contain_null(contain_null), + _precision(precision), + _scale(scale) {} + +template +ColumnValueRange::ColumnValueRange(std::string col_name, int precision, int scale) + : ColumnValueRange(std::move(col_name), TYPE_MIN, TYPE_MAX, true, precision, scale) {} template Status ColumnValueRange::add_fixed_value(const CppType& value) { @@ -803,9 +865,11 @@ Status OlapScanKeys::extend_scan_key(ColumnValueRange& range, for (; iter != fixed_value_set.end(); ++iter) { _begin_scan_keys.emplace_back(); - _begin_scan_keys.back().add_value(cast_to_string(*iter)); + _begin_scan_keys.back().add_value( + cast_to_string(*iter, range.scale())); _end_scan_keys.emplace_back(); - _end_scan_keys.back().add_value(cast_to_string(*iter)); + _end_scan_keys.back().add_value( + cast_to_string(*iter, range.scale())); } if (range.contain_null()) { @@ -828,14 +892,18 @@ Status OlapScanKeys::extend_scan_key(ColumnValueRange& range, for (; iter != fixed_value_set.end(); ++iter) { // alter the first ScanKey in original place if (iter == fixed_value_set.begin()) { - _begin_scan_keys[i].add_value(cast_to_string(*iter)); - _end_scan_keys[i].add_value(cast_to_string(*iter)); + _begin_scan_keys[i].add_value( + cast_to_string(*iter, range.scale())); + _end_scan_keys[i].add_value( + cast_to_string(*iter, range.scale())); } // append follow ScanKey else { _begin_scan_keys.push_back(start_base_key_range); - _begin_scan_keys.back().add_value(cast_to_string(*iter)); + _begin_scan_keys.back().add_value( + cast_to_string(*iter, range.scale())); _end_scan_keys.push_back(end_base_key_range); - _end_scan_keys.back().add_value(cast_to_string(*iter)); + _end_scan_keys.back().add_value( + cast_to_string(*iter, range.scale())); } } @@ -856,18 +924,22 @@ Status OlapScanKeys::extend_scan_key(ColumnValueRange& range, if (_begin_scan_keys.empty()) { _begin_scan_keys.emplace_back(); - _begin_scan_keys.back().add_value(cast_to_string(range.get_range_min_value()), + _begin_scan_keys.back().add_value(cast_to_string( + range.get_range_min_value(), range.scale()), range.contain_null()); _end_scan_keys.emplace_back(); - _end_scan_keys.back().add_value(cast_to_string(range.get_range_max_value())); + _end_scan_keys.back().add_value(cast_to_string( + range.get_range_max_value(), range.scale())); } else { for (int i = 0; i < _begin_scan_keys.size(); ++i) { - _begin_scan_keys[i].add_value(cast_to_string(range.get_range_min_value()), + _begin_scan_keys[i].add_value(cast_to_string( + range.get_range_min_value(), range.scale()), range.contain_null()); } for (int i = 0; i < _end_scan_keys.size(); ++i) { - _end_scan_keys[i].add_value(cast_to_string(range.get_range_max_value())); + _end_scan_keys[i].add_value(cast_to_string( + range.get_range_max_value(), range.scale())); } } diff --git a/be/src/exec/tablet_sink.cpp b/be/src/exec/tablet_sink.cpp index 14f01f9327..141c760f8e 100644 --- a/be/src/exec/tablet_sink.cpp +++ b/be/src/exec/tablet_sink.cpp @@ -768,6 +768,8 @@ Status OlapTableSink::prepare(RuntimeState* state) { for (int i = 0; i < _output_tuple_desc->slots().size(); ++i) { auto slot = _output_tuple_desc->slots()[i]; switch (slot->type().type) { + // For DECIMAL32,DECIMAL64,DECIMAL128, we have done precision and scale conversion so just + // skip data validation here. case TYPE_DECIMALV2: _max_decimalv2_val[i].to_max_decimal(slot->type().precision, slot->type().scale); _min_decimalv2_val[i].to_min_decimal(slot->type().precision, slot->type().scale); diff --git a/be/src/exprs/anyval_util.cpp b/be/src/exprs/anyval_util.cpp index 374c5a3130..4c8a96e498 100644 --- a/be/src/exprs/anyval_util.cpp +++ b/be/src/exprs/anyval_util.cpp @@ -93,6 +93,15 @@ AnyVal* create_any_val(ObjectPool* pool, const TypeDescriptor& type) { case TYPE_DECIMALV2: return pool->add(new DecimalV2Val); + case TYPE_DECIMAL32: + return pool->add(new IntVal); + + case TYPE_DECIMAL64: + return pool->add(new BigIntVal); + + case TYPE_DECIMAL128: + return pool->add(new LargeIntVal); + case TYPE_DATE: return pool->add(new DateTimeVal); @@ -145,6 +154,21 @@ FunctionContext::TypeDesc AnyValUtil::column_type_to_type_desc(const TypeDescrip case TYPE_DATEV2: out.type = FunctionContext::TYPE_DATEV2; break; + case TYPE_DECIMAL32: + out.type = FunctionContext::TYPE_DECIMAL32; + out.precision = type.precision; + out.scale = type.scale; + break; + case TYPE_DECIMAL64: + out.type = FunctionContext::TYPE_DECIMAL64; + out.precision = type.precision; + out.scale = type.scale; + break; + case TYPE_DECIMAL128: + out.type = FunctionContext::TYPE_DECIMAL128; + out.precision = type.precision; + out.scale = type.scale; + break; case TYPE_VARCHAR: out.type = FunctionContext::TYPE_VARCHAR; out.len = type.len; diff --git a/be/src/exprs/anyval_util.h b/be/src/exprs/anyval_util.h index f0246d4c53..a6b8063be3 100644 --- a/be/src/exprs/anyval_util.h +++ b/be/src/exprs/anyval_util.h @@ -392,6 +392,18 @@ public: reinterpret_cast(dst)->val = reinterpret_cast(slot)->value; return; + case TYPE_DECIMAL32: + reinterpret_cast(dst)->val = + *reinterpret_cast(slot); + return; + case TYPE_DECIMAL64: + reinterpret_cast(dst)->val = + *reinterpret_cast(slot); + return; + case TYPE_DECIMAL128: + memcpy(&reinterpret_cast(dst)->val, slot, sizeof(__int128)); + return; + case TYPE_DATE: reinterpret_cast(slot)->to_datetime_val( reinterpret_cast(dst)); diff --git a/be/src/exprs/binary_predicate.cpp b/be/src/exprs/binary_predicate.cpp index ac815b67ac..533f2a6d82 100644 --- a/be/src/exprs/binary_predicate.cpp +++ b/be/src/exprs/binary_predicate.cpp @@ -36,10 +36,13 @@ Expr* BinaryPredicate::from_thrift(const TExprNode& node) { case TPrimitiveType::SMALLINT: return new EqSmallIntValPred(node); case TPrimitiveType::INT: + case TPrimitiveType::DECIMAL32: return new EqIntValPred(node); case TPrimitiveType::BIGINT: + case TPrimitiveType::DECIMAL64: return new EqBigIntValPred(node); case TPrimitiveType::LARGEINT: + case TPrimitiveType::DECIMAL128: return new EqLargeIntValPred(node); case TPrimitiveType::FLOAT: return new EqFloatValPred(node); @@ -67,10 +70,13 @@ Expr* BinaryPredicate::from_thrift(const TExprNode& node) { case TPrimitiveType::SMALLINT: return new NeSmallIntValPred(node); case TPrimitiveType::INT: + case TPrimitiveType::DECIMAL32: return new NeIntValPred(node); case TPrimitiveType::BIGINT: + case TPrimitiveType::DECIMAL64: return new NeBigIntValPred(node); case TPrimitiveType::LARGEINT: + case TPrimitiveType::DECIMAL128: return new NeLargeIntValPred(node); case TPrimitiveType::FLOAT: return new NeFloatValPred(node); @@ -98,10 +104,13 @@ Expr* BinaryPredicate::from_thrift(const TExprNode& node) { case TPrimitiveType::SMALLINT: return new LtSmallIntValPred(node); case TPrimitiveType::INT: + case TPrimitiveType::DECIMAL32: return new LtIntValPred(node); case TPrimitiveType::BIGINT: + case TPrimitiveType::DECIMAL64: return new LtBigIntValPred(node); case TPrimitiveType::LARGEINT: + case TPrimitiveType::DECIMAL128: return new LtLargeIntValPred(node); case TPrimitiveType::FLOAT: return new LtFloatValPred(node); @@ -129,10 +138,13 @@ Expr* BinaryPredicate::from_thrift(const TExprNode& node) { case TPrimitiveType::SMALLINT: return new LeSmallIntValPred(node); case TPrimitiveType::INT: + case TPrimitiveType::DECIMAL32: return new LeIntValPred(node); case TPrimitiveType::BIGINT: + case TPrimitiveType::DECIMAL64: return new LeBigIntValPred(node); case TPrimitiveType::LARGEINT: + case TPrimitiveType::DECIMAL128: return new LeLargeIntValPred(node); case TPrimitiveType::FLOAT: return new LeFloatValPred(node); @@ -160,10 +172,13 @@ Expr* BinaryPredicate::from_thrift(const TExprNode& node) { case TPrimitiveType::SMALLINT: return new GtSmallIntValPred(node); case TPrimitiveType::INT: + case TPrimitiveType::DECIMAL32: return new GtIntValPred(node); case TPrimitiveType::BIGINT: + case TPrimitiveType::DECIMAL64: return new GtBigIntValPred(node); case TPrimitiveType::LARGEINT: + case TPrimitiveType::DECIMAL128: return new GtLargeIntValPred(node); case TPrimitiveType::FLOAT: return new GtFloatValPred(node); @@ -191,10 +206,13 @@ Expr* BinaryPredicate::from_thrift(const TExprNode& node) { case TPrimitiveType::SMALLINT: return new GeSmallIntValPred(node); case TPrimitiveType::INT: + case TPrimitiveType::DECIMAL32: return new GeIntValPred(node); case TPrimitiveType::BIGINT: + case TPrimitiveType::DECIMAL64: return new GeBigIntValPred(node); case TPrimitiveType::LARGEINT: + case TPrimitiveType::DECIMAL128: return new GeLargeIntValPred(node); case TPrimitiveType::FLOAT: return new GeFloatValPred(node); @@ -222,10 +240,13 @@ Expr* BinaryPredicate::from_thrift(const TExprNode& node) { case TPrimitiveType::SMALLINT: return new EqForNullSmallIntValPred(node); case TPrimitiveType::INT: + case TPrimitiveType::DECIMAL32: return new EqForNullIntValPred(node); case TPrimitiveType::BIGINT: + case TPrimitiveType::DECIMAL64: return new EqForNullBigIntValPred(node); case TPrimitiveType::LARGEINT: + case TPrimitiveType::DECIMAL128: return new EqForNullLargeIntValPred(node); case TPrimitiveType::FLOAT: return new EqForNullFloatValPred(node); diff --git a/be/src/exprs/cast_functions.cpp b/be/src/exprs/cast_functions.cpp index 3584775ed5..65cf27d6d4 100644 --- a/be/src/exprs/cast_functions.cpp +++ b/be/src/exprs/cast_functions.cpp @@ -31,6 +31,7 @@ #include "util/array_parser.h" #include "util/mysql_global.h" #include "util/string_parser.hpp" +#include "vec/data_types/data_type_decimal.h" namespace doris { @@ -360,6 +361,254 @@ DateTimeVal CastFunctions::cast_to_date_val(FunctionContext* ctx, const StringVa return result; } +#define CAST_TYPE_DECIMAL32(from_type) \ + Decimal32Val CastFunctions::cast_to_decimal32_val(FunctionContext* ctx, \ + const from_type& val) { \ + if (val.is_null) { \ + return Decimal32Val::null(); \ + } \ + auto scale_to = ctx->get_return_type().scale; \ + return Decimal32Val( \ + val.val * \ + vectorized::DataTypeDecimal::get_scale_multiplier( \ + scale_to)); \ + } + +#define CAST_TYPE_DECIMAL64(from_type) \ + Decimal64Val CastFunctions::cast_to_decimal64_val(FunctionContext* ctx, \ + const from_type& val) { \ + if (val.is_null) { \ + return Decimal64Val::null(); \ + } \ + auto scale_to = ctx->get_return_type().scale; \ + return Decimal64Val( \ + val.val * \ + vectorized::DataTypeDecimal::get_scale_multiplier( \ + scale_to)); \ + } + +#define CAST_TYPE_DECIMAL128(from_type) \ + Decimal128Val CastFunctions::cast_to_decimal128_val(FunctionContext* ctx, \ + const from_type& val) { \ + if (val.is_null) { \ + return Decimal128Val::null(); \ + } \ + auto scale_to = ctx->get_return_type().scale; \ + return Decimal128Val( \ + val.val * \ + vectorized::DataTypeDecimal::get_scale_multiplier( \ + scale_to)); \ + } + +#define CAST_TYPE_DECIMAL(to_type) \ + CAST_TYPE_##to_type(TinyIntVal) CAST_TYPE_##to_type(SmallIntVal) CAST_TYPE_##to_type(IntVal) \ + CAST_TYPE_##to_type(BigIntVal) CAST_TYPE_##to_type(LargeIntVal) \ + CAST_TYPE_##to_type(FloatVal) CAST_TYPE_##to_type(DoubleVal) + +CAST_TYPE_DECIMAL(DECIMAL32) +CAST_TYPE_DECIMAL(DECIMAL64) +CAST_TYPE_DECIMAL(DECIMAL128) + +Decimal32Val CastFunctions::cast_to_decimal32_val(FunctionContext* context, + const DateTimeVal& val) { + if (val.is_null) { + return Decimal32Val::null(); + } + auto scale_to = context->get_return_type().scale; + DateTimeValue dt_value = DateTimeValue::from_datetime_val(val); + return Decimal32Val( + dt_value.to_int64() * + vectorized::DataTypeDecimal::get_scale_multiplier(scale_to)); +} + +Decimal32Val CastFunctions::cast_to_decimal32_val(FunctionContext* context, const StringVal& val) { + if (val.is_null) { + return Decimal32Val::null(); + } + std::stringstream ss; + StringParser::ParseResult result; + int32_t v = StringParser::string_to_decimal((const char*)val.ptr, val.len, + context->get_return_type().precision, + context->get_return_type().scale, &result); + return Decimal32Val(v); +} + +Decimal32Val CastFunctions::cast_to_decimal32_val(FunctionContext* ctx, const Decimal32Val& val) { + if (ctx->get_arg_type(0)->scale == ctx->get_return_type().scale && + ctx->get_arg_type(0)->precision == ctx->get_return_type().precision) { + return val; + } + if (val.is_null) { + return Decimal32Val::null(); + } + auto scale_from = ctx->get_arg_type(0)->scale; + auto scale_to = ctx->get_return_type().scale; + if (scale_to > scale_from) { + return Decimal32Val( + val.val * vectorized::DataTypeDecimal::get_scale_multiplier( + scale_to - scale_from)); + } else { + return Decimal32Val( + val.val / vectorized::DataTypeDecimal::get_scale_multiplier( + scale_from - scale_to)); + } +} + +Decimal64Val CastFunctions::cast_to_decimal64_val(FunctionContext* context, + const DateTimeVal& val) { + if (val.is_null) { + return Decimal64Val::null(); + } + auto scale_to = context->get_return_type().scale; + DateTimeValue dt_value = DateTimeValue::from_datetime_val(val); + return Decimal64Val( + dt_value.to_int64() * + vectorized::DataTypeDecimal::get_scale_multiplier(scale_to)); +} + +Decimal64Val CastFunctions::cast_to_decimal64_val(FunctionContext* context, const StringVal& val) { + if (val.is_null) { + return Decimal64Val::null(); + } + std::stringstream ss; + StringParser::ParseResult result; + int64_t v = StringParser::string_to_decimal((const char*)val.ptr, val.len, + context->get_return_type().precision, + context->get_return_type().scale, &result); + return Decimal64Val(v); +} + +Decimal64Val CastFunctions::cast_to_decimal64_val(FunctionContext* ctx, const Decimal64Val& val) { + if (ctx->get_arg_type(0)->scale == ctx->get_return_type().scale && + ctx->get_arg_type(0)->precision == ctx->get_return_type().precision) { + return val; + } + if (val.is_null) { + return Decimal64Val::null(); + } + auto scale_from = ctx->get_arg_type(0)->scale; + auto scale_to = ctx->get_return_type().scale; + if (scale_to > scale_from) { + return Decimal64Val( + val.val * vectorized::DataTypeDecimal::get_scale_multiplier( + scale_to - scale_from)); + } else { + return Decimal64Val( + val.val / vectorized::DataTypeDecimal::get_scale_multiplier( + scale_from - scale_to)); + } +} + +Decimal128Val CastFunctions::cast_to_decimal128_val(FunctionContext* context, + const DateTimeVal& val) { + if (val.is_null) { + return Decimal128Val::null(); + } + auto scale_to = context->get_return_type().scale; + DateTimeValue dt_value = DateTimeValue::from_datetime_val(val); + return Decimal128Val( + dt_value.to_int64() * + vectorized::DataTypeDecimal::get_scale_multiplier(scale_to)); +} + +Decimal128Val CastFunctions::cast_to_decimal128_val(FunctionContext* context, + const StringVal& val) { + if (val.is_null) { + return Decimal128Val::null(); + } + std::stringstream ss; + StringParser::ParseResult result; + int128_t v = StringParser::string_to_decimal( + (const char*)val.ptr, val.len, context->get_return_type().precision, + context->get_return_type().scale, &result); + return Decimal128Val(v); +} + +Decimal128Val CastFunctions::cast_to_decimal128_val(FunctionContext* ctx, + const Decimal128Val& val) { + if (ctx->get_arg_type(0)->scale == ctx->get_return_type().scale && + ctx->get_arg_type(0)->precision == ctx->get_return_type().precision) { + return val; + } + if (val.is_null) { + return Decimal128Val::null(); + } + auto scale_from = ctx->get_arg_type(0)->scale; + auto scale_to = ctx->get_return_type().scale; + if (scale_to > scale_from) { + return Decimal128Val( + val.val * vectorized::DataTypeDecimal::get_scale_multiplier( + scale_to - scale_from)); + } else { + return Decimal128Val( + val.val / vectorized::DataTypeDecimal::get_scale_multiplier( + scale_from - scale_to)); + } +} + +Decimal128Val CastFunctions::cast_to_decimal128_val(FunctionContext* ctx, const Decimal32Val& val) { + if (ctx->get_arg_type(0)->scale == ctx->get_return_type().scale && + ctx->get_arg_type(0)->precision == ctx->get_return_type().precision) { + return Decimal128Val(val.val); + } + if (val.is_null) { + return Decimal128Val::null(); + } + auto scale_from = ctx->get_arg_type(0)->scale; + auto scale_to = ctx->get_return_type().scale; + if (scale_to > scale_from) { + return Decimal128Val( + val.val * vectorized::DataTypeDecimal::get_scale_multiplier( + scale_to - scale_from)); + } else { + return Decimal128Val( + val.val / vectorized::DataTypeDecimal::get_scale_multiplier( + scale_from - scale_to)); + } +} + +Decimal128Val CastFunctions::cast_to_decimal128_val(FunctionContext* ctx, const Decimal64Val& val) { + if (ctx->get_arg_type(0)->scale == ctx->get_return_type().scale && + ctx->get_arg_type(0)->precision == ctx->get_return_type().precision) { + return Decimal128Val(val.val); + } + if (val.is_null) { + return Decimal128Val::null(); + } + auto scale_from = ctx->get_arg_type(0)->scale; + auto scale_to = ctx->get_return_type().scale; + if (scale_to > scale_from) { + return Decimal128Val( + val.val * vectorized::DataTypeDecimal::get_scale_multiplier( + scale_to - scale_from)); + } else { + return Decimal128Val( + val.val / vectorized::DataTypeDecimal::get_scale_multiplier( + scale_from - scale_to)); + } +} + +Decimal64Val CastFunctions::cast_to_decimal64_val(FunctionContext* ctx, const Decimal32Val& val) { + if (ctx->get_arg_type(0)->scale == ctx->get_return_type().scale && + ctx->get_arg_type(0)->precision == ctx->get_return_type().precision) { + return Decimal64Val(val.val); + } + if (val.is_null) { + return Decimal64Val::null(); + } + auto scale_from = ctx->get_arg_type(0)->scale; + auto scale_to = ctx->get_return_type().scale; + if (scale_to > scale_from) { + return Decimal64Val( + val.val * vectorized::DataTypeDecimal::get_scale_multiplier( + scale_to - scale_from)); + } else { + return Decimal64Val( + val.val / vectorized::DataTypeDecimal::get_scale_multiplier( + scale_from - scale_to)); + } +} + CollectionVal CastFunctions::cast_to_array_val(FunctionContext* context, const StringVal& val) { CollectionVal array_val; Status status = ArrayParser::parse(array_val, context, val); diff --git a/be/src/exprs/cast_functions.h b/be/src/exprs/cast_functions.h index b45cfac3f6..497f49bf80 100644 --- a/be/src/exprs/cast_functions.h +++ b/be/src/exprs/cast_functions.h @@ -139,6 +139,27 @@ public: static DateTimeVal cast_to_date_val(FunctionContext* context, const DateTimeVal& val); static DateTimeVal cast_to_date_val(FunctionContext* context, const StringVal& val); +#define DECLARE_CAST_TO_DECIMAL(width) \ + static Decimal##width##Val cast_to_decimal##width##_val(FunctionContext*, const TinyIntVal&); \ + static Decimal##width##Val cast_to_decimal##width##_val(FunctionContext*, const SmallIntVal&); \ + static Decimal##width##Val cast_to_decimal##width##_val(FunctionContext*, const IntVal&); \ + static Decimal##width##Val cast_to_decimal##width##_val(FunctionContext*, const BigIntVal&); \ + static Decimal##width##Val cast_to_decimal##width##_val(FunctionContext*, const LargeIntVal&); \ + static Decimal##width##Val cast_to_decimal##width##_val(FunctionContext*, const FloatVal&); \ + static Decimal##width##Val cast_to_decimal##width##_val(FunctionContext*, const DoubleVal&); \ + static Decimal##width##Val cast_to_decimal##width##_val(FunctionContext*, const DateTimeVal&); \ + static Decimal##width##Val cast_to_decimal##width##_val(FunctionContext*, const StringVal&); \ + static Decimal##width##Val cast_to_decimal##width##_val(FunctionContext*, \ + const Decimal##width##Val&); + + DECLARE_CAST_TO_DECIMAL(32) + DECLARE_CAST_TO_DECIMAL(64) + DECLARE_CAST_TO_DECIMAL(128) + + static Decimal64Val cast_to_decimal64_val(FunctionContext*, const Decimal32Val&); + static Decimal128Val cast_to_decimal128_val(FunctionContext*, const Decimal32Val&); + static Decimal128Val cast_to_decimal128_val(FunctionContext*, const Decimal64Val&); + static CollectionVal cast_to_array_val(FunctionContext* context, const StringVal& val); }; diff --git a/be/src/exprs/create_predicate_function.h b/be/src/exprs/create_predicate_function.h index ebd268b8cf..36fed51383 100644 --- a/be/src/exprs/create_predicate_function.h +++ b/be/src/exprs/create_predicate_function.h @@ -38,17 +38,10 @@ public: using BasePtr = HybridSetBase*; template static BasePtr get_function() { - if constexpr (is_vec) { - using CppType = typename VecPrimitiveTypeTraits::CppType; - using Set = std::conditional_t, StringValueSet, - HybridSet>; - return new (std::nothrow) Set(); - } else { - using CppType = typename PrimitiveTypeTraits::CppType; - using Set = std::conditional_t, StringValueSet, - HybridSet>; - return new (std::nothrow) Set(); - } + using CppType = typename PrimitiveTypeTraits::CppType; + using Set = std::conditional_t, StringValueSet, + HybridSet>; + return new (std::nothrow) Set(); }; }; @@ -109,6 +102,12 @@ typename Traits::BasePtr create_predicate_function(PrimitiveType type) { return Creator::template create(); case TYPE_STRING: return Creator::template create(); + case TYPE_DECIMAL32: + return Creator::template create(); + case TYPE_DECIMAL64: + return Creator::template create(); + case TYPE_DECIMAL128: + return Creator::template create(); default: DCHECK(false) << "Invalid type."; diff --git a/be/src/exprs/expr.cpp b/be/src/exprs/expr.cpp index 695bc5a44e..1cbacf3445 100644 --- a/be/src/exprs/expr.cpp +++ b/be/src/exprs/expr.cpp @@ -701,6 +701,18 @@ doris_udf::AnyVal* Expr::get_const_val(ExprContext* context) { _constant_val.reset(new DecimalV2Val(get_decimalv2_val(context, nullptr))); break; } + case TYPE_DECIMAL32: { + _constant_val.reset(new Decimal32Val(get_decimal32_val(context, nullptr))); + break; + } + case TYPE_DECIMAL64: { + _constant_val.reset(new Decimal64Val(get_decimal64_val(context, nullptr))); + break; + } + case TYPE_DECIMAL128: { + _constant_val.reset(new Decimal128Val(get_decimal128_val(context, nullptr))); + break; + } case TYPE_NULL: { _constant_val.reset(new AnyVal(true)); break; @@ -760,6 +772,18 @@ LargeIntVal Expr::get_large_int_val(ExprContext* context, TupleRow* row) { return LargeIntVal::null(); // (*(int64_t*)get_value(row)); } +Decimal32Val Expr::get_decimal32_val(ExprContext* context, TupleRow* row) { + return Decimal32Val::null(); // (*(int32_t*)get_value(row)); +} + +Decimal64Val Expr::get_decimal64_val(ExprContext* context, TupleRow* row) { + return Decimal64Val::null(); +} + +Decimal128Val Expr::get_decimal128_val(ExprContext* context, TupleRow* row) { + return Decimal128Val::null(); +} + FloatVal Expr::get_float_val(ExprContext* context, TupleRow* row) { return FloatVal::null(); // (*(float*)get_value(row)); } diff --git a/be/src/exprs/expr.h b/be/src/exprs/expr.h index 2fe979cc17..9b114b664b 100644 --- a/be/src/exprs/expr.h +++ b/be/src/exprs/expr.h @@ -34,6 +34,9 @@ #include "runtime/tuple.h" #include "runtime/tuple_row.h" #include "udf/udf.h" +#include "util/string_parser.hpp" +#include "vec/data_types/data_type_decimal.h" +#include "vec/io/io_helper.h" #undef USING_DORIS_UDF #define USING_DORIS_UDF using namespace doris_udf @@ -104,6 +107,10 @@ public: virtual DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow*); virtual CollectionVal get_array_val(ExprContext* context, TupleRow*); + virtual Decimal32Val get_decimal32_val(ExprContext* context, TupleRow*); + virtual Decimal64Val get_decimal64_val(ExprContext* context, TupleRow*); + virtual Decimal128Val get_decimal128_val(ExprContext* context, TupleRow*); + // Get the number of digits after the decimal that should be displayed for this // value. Returns -1 if no scale has been specified (currently the scale is only set for // doubles set by RoundUpTo). get_value() must have already been called. @@ -447,39 +454,48 @@ private: int _fn_ctx_idx_end = 0; }; -template -Status create_texpr_literal_node(const void* data, TExprNode* node) { - if constexpr (std::is_same_v) { - auto origin_value = reinterpret_cast(data); +template +Status create_texpr_literal_node(const void* data, TExprNode* node, int precision = 0, + int scale = 0) { + if constexpr (T == TYPE_BOOLEAN) { + auto origin_value = reinterpret_cast(data); TBoolLiteral boolLiteral; (*node).__set_node_type(TExprNodeType::BOOL_LITERAL); boolLiteral.__set_value(*origin_value); (*node).__set_bool_literal(boolLiteral); (*node).__set_type(create_type_desc(PrimitiveType::TYPE_BOOLEAN)); - } else if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) { - auto origin_value = reinterpret_cast(data); + } else if constexpr (T == TYPE_TINYINT) { + auto origin_value = reinterpret_cast(data); (*node).__set_node_type(TExprNodeType::INT_LITERAL); TIntLiteral intLiteral; intLiteral.__set_value(*origin_value); (*node).__set_int_literal(intLiteral); - if constexpr (std::is_same_v) { - (*node).__set_type(create_type_desc(PrimitiveType::TYPE_TINYINT)); - } else if constexpr (std::is_same_v) { - (*node).__set_type(create_type_desc(PrimitiveType::TYPE_SMALLINT)); - } else if constexpr (std::is_same_v) { - (*node).__set_type(create_type_desc(PrimitiveType::TYPE_INT)); - } else if constexpr (std::is_same_v) { - (*node).__set_type(create_type_desc(PrimitiveType::TYPE_BIGINT)); - } - } else if constexpr (std::is_same_v<__int128_t, T>) { - auto origin_value = reinterpret_cast(data); + } else if constexpr (T == TYPE_SMALLINT) { + auto origin_value = reinterpret_cast(data); + (*node).__set_node_type(TExprNodeType::INT_LITERAL); + TIntLiteral intLiteral; + intLiteral.__set_value(*origin_value); + (*node).__set_int_literal(intLiteral); + } else if constexpr (T == TYPE_INT) { + auto origin_value = reinterpret_cast(data); + (*node).__set_node_type(TExprNodeType::INT_LITERAL); + TIntLiteral intLiteral; + intLiteral.__set_value(*origin_value); + (*node).__set_int_literal(intLiteral); + } else if constexpr (T == TYPE_BIGINT) { + auto origin_value = reinterpret_cast(data); + (*node).__set_node_type(TExprNodeType::INT_LITERAL); + TIntLiteral intLiteral; + intLiteral.__set_value(*origin_value); + (*node).__set_int_literal(intLiteral); + } else if constexpr (T == TYPE_LARGEINT) { + auto origin_value = reinterpret_cast(data); (*node).__set_node_type(TExprNodeType::LARGE_INT_LITERAL); TLargeIntLiteral large_int_literal; large_int_literal.__set_value(LargeIntValue::to_string(*origin_value)); (*node).__set_large_int_literal(large_int_literal); (*node).__set_type(create_type_desc(PrimitiveType::TYPE_LARGEINT)); - } else if constexpr (std::is_same_v) { + } else if constexpr ((T == TYPE_DATE) || (T == TYPE_DATETIME) || (T == TYPE_TIME)) { auto origin_value = reinterpret_cast(data); TDateLiteral date_literal; char convert_buffer[30]; @@ -494,7 +510,7 @@ Status create_texpr_literal_node(const void* data, TExprNode* node) { } else if (origin_value->type() == TimeType::TIME_TIME) { (*node).__set_type(create_type_desc(PrimitiveType::TYPE_TIME)); } - } else if constexpr (std::is_same_v) { + } else if constexpr (T == TYPE_DATEV2) { auto origin_value = reinterpret_cast(data); TDateLiteral date_literal; char convert_buffer[30]; @@ -503,25 +519,55 @@ Status create_texpr_literal_node(const void* data, TExprNode* node) { (*node).__set_date_literal(date_literal); (*node).__set_node_type(TExprNodeType::DATE_LITERAL); (*node).__set_type(create_type_desc(PrimitiveType::TYPE_DATEV2)); - } else if constexpr (std::is_same_v) { + } else if constexpr (T == TYPE_DECIMALV2) { auto origin_value = reinterpret_cast(data); (*node).__set_node_type(TExprNodeType::DECIMAL_LITERAL); TDecimalLiteral decimal_literal; decimal_literal.__set_value(origin_value->to_string()); (*node).__set_decimal_literal(decimal_literal); - (*node).__set_type(create_type_desc(PrimitiveType::TYPE_DECIMALV2)); - } else if constexpr (std::is_same_v || std::is_same_v) { - auto origin_value = reinterpret_cast(data); + (*node).__set_type(create_type_desc(PrimitiveType::TYPE_DECIMALV2, precision, scale)); + } else if constexpr (T == TYPE_DECIMAL32) { + auto origin_value = reinterpret_cast(data); + (*node).__set_node_type(TExprNodeType::DECIMAL_LITERAL); + TDecimalLiteral decimal_literal; + std::stringstream ss; + vectorized::write_text(*origin_value, scale, ss); + decimal_literal.__set_value(ss.str()); + (*node).__set_decimal_literal(decimal_literal); + (*node).__set_type(create_type_desc(PrimitiveType::TYPE_DECIMAL32, precision, scale)); + } else if constexpr (T == TYPE_DECIMAL64) { + auto origin_value = reinterpret_cast(data); + (*node).__set_node_type(TExprNodeType::DECIMAL_LITERAL); + TDecimalLiteral decimal_literal; + std::stringstream ss; + vectorized::write_text(*origin_value, scale, ss); + decimal_literal.__set_value(ss.str()); + (*node).__set_decimal_literal(decimal_literal); + (*node).__set_type(create_type_desc(PrimitiveType::TYPE_DECIMAL64, precision, scale)); + } else if constexpr (T == TYPE_DECIMAL128) { + auto origin_value = reinterpret_cast(data); + (*node).__set_node_type(TExprNodeType::DECIMAL_LITERAL); + TDecimalLiteral decimal_literal; + std::stringstream ss; + vectorized::write_text(*origin_value, scale, ss); + decimal_literal.__set_value(ss.str()); + (*node).__set_decimal_literal(decimal_literal); + (*node).__set_type(create_type_desc(PrimitiveType::TYPE_DECIMAL128, precision, scale)); + } else if constexpr (T == TYPE_FLOAT) { + auto origin_value = reinterpret_cast(data); (*node).__set_node_type(TExprNodeType::FLOAT_LITERAL); TFloatLiteral float_literal; float_literal.__set_value(*origin_value); (*node).__set_float_literal(float_literal); - if constexpr (std::is_same_v) { - (*node).__set_type(create_type_desc(PrimitiveType::TYPE_FLOAT)); - } else if constexpr (std::is_same_v) { - (*node).__set_type(create_type_desc(PrimitiveType::TYPE_DOUBLE)); - } - } else if constexpr (std::is_same_v) { + (*node).__set_type(create_type_desc(PrimitiveType::TYPE_FLOAT)); + } else if constexpr (T == TYPE_DOUBLE) { + auto origin_value = reinterpret_cast(data); + (*node).__set_node_type(TExprNodeType::FLOAT_LITERAL); + TFloatLiteral float_literal; + float_literal.__set_value(*origin_value); + (*node).__set_float_literal(float_literal); + (*node).__set_type(create_type_desc(PrimitiveType::TYPE_DOUBLE)); + } else if constexpr ((T == TYPE_STRING) || (T == TYPE_CHAR) || (T == TYPE_VARCHAR)) { auto origin_value = reinterpret_cast(data); (*node).__set_node_type(TExprNodeType::STRING_LITERAL); TStringLiteral string_literal; diff --git a/be/src/exprs/expr_context.cpp b/be/src/exprs/expr_context.cpp index 199076cdcb..139127b5c8 100644 --- a/be/src/exprs/expr_context.cpp +++ b/be/src/exprs/expr_context.cpp @@ -171,7 +171,7 @@ bool ExprContext::is_nullable() { return false; } -void* ExprContext::get_value(Expr* e, TupleRow* row) { +void* ExprContext::get_value(Expr* e, TupleRow* row, int precision, int scale) { switch (e->_type.type) { case TYPE_NULL: { return nullptr; @@ -274,6 +274,30 @@ void* ExprContext::get_value(Expr* e, TupleRow* row) { _result.decimalv2_val = DecimalV2Value::from_decimal_val(v); return &_result.decimalv2_val; } + case TYPE_DECIMAL32: { + doris_udf::Decimal32Val v = e->get_decimal32_val(this, row); + if (v.is_null) { + return nullptr; + } + _result.int_val = v.val; + return &_result.int_val; + } + case TYPE_DECIMAL64: { + doris_udf::Decimal64Val v = e->get_decimal64_val(this, row); + if (v.is_null) { + return nullptr; + } + _result.bigint_val = v.val; + return &_result.bigint_val; + } + case TYPE_DECIMAL128: { + doris_udf::Decimal128Val v = e->get_decimal128_val(this, row); + if (v.is_null) { + return nullptr; + } + _result.large_int_val = v.val; + return &_result.large_int_val; + } case TYPE_ARRAY: { doris_udf::CollectionVal v = e->get_array_val(this, row); if (v.is_null) { diff --git a/be/src/exprs/expr_context.h b/be/src/exprs/expr_context.h index 6543ff1b9c..14cfa86df2 100644 --- a/be/src/exprs/expr_context.h +++ b/be/src/exprs/expr_context.h @@ -194,7 +194,7 @@ private: /// Calls the appropriate Get*Val() function on 'e' and stores the result in result_. /// This is used by Exprs to call GetValue() on a child expr, rather than root_. - void* get_value(Expr* e, TupleRow* row); + void* get_value(Expr* e, TupleRow* row, int precision = 0, int scale = 0); }; inline void* ExprContext::get_value(TupleRow* row) { diff --git a/be/src/exprs/hybrid_set.h b/be/src/exprs/hybrid_set.h index 642d99af1b..16b740e951 100644 --- a/be/src/exprs/hybrid_set.h +++ b/be/src/exprs/hybrid_set.h @@ -49,7 +49,8 @@ public: virtual bool find(void* data, size_t) = 0; virtual Status to_vexpr_list(doris::ObjectPool* pool, - std::vector* vexpr_list) = 0; + std::vector* vexpr_list, int precision, + int scale) = 0; class IteratorBase { public: IteratorBase() {} @@ -62,21 +63,25 @@ public: virtual IteratorBase* begin() = 0; }; -template +template class HybridSet : public HybridSetBase { public: + using CppType = std::conditional_t::CppType, + typename PrimitiveTypeTraits::CppType>; + HybridSet() = default; ~HybridSet() override = default; - virtual Status to_vexpr_list(doris::ObjectPool* pool, - std::vector* vexpr_list) override { + Status to_vexpr_list(doris::ObjectPool* pool, + std::vector* vexpr_list, int precision, + int scale) override { HybridSetBase::IteratorBase* it = begin(); DCHECK(it != nullptr); while (it->has_next()) { TExprNode node; const void* v = it->get_value(); - create_texpr_literal_node(v, &node); + create_texpr_literal_node(v, &node, precision, scale); vexpr_list->push_back(pool->add(new doris::vectorized::VLiteral(node))); it->next(); } @@ -86,26 +91,26 @@ public: void insert(const void* data) override { if (data == nullptr) return; - if constexpr (sizeof(T) >= 16) { - // for largeint, it will core dump with no memcpy - T value; - memcpy(&value, data, sizeof(T)); + if constexpr (sizeof(CppType) >= 16) { + // for large int, it will core dump with no memcpy + CppType value; + memcpy(&value, data, sizeof(CppType)); _set.insert(value); } else { - _set.insert(*reinterpret_cast(data)); + _set.insert(*reinterpret_cast(data)); } } void insert(void* data, size_t) override { insert(data); } void insert(HybridSetBase* set) override { - HybridSet* hybrid_set = reinterpret_cast*>(set); + HybridSet* hybrid_set = reinterpret_cast*>(set); _set.insert(hybrid_set->_set.begin(), hybrid_set->_set.end()); } int size() override { return _set.size(); } bool find(void* data) override { - auto it = _set.find(*reinterpret_cast(data)); + auto it = _set.find(*reinterpret_cast(data)); return !(it == _set.end()); } @@ -128,11 +133,11 @@ public: }; IteratorBase* begin() override { - return _pool.add(new (std::nothrow) Iterator(_set.begin(), _set.end())); + return _pool.add(new (std::nothrow) Iterator(_set.begin(), _set.end())); } private: - phmap::flat_hash_set _set; + phmap::flat_hash_set _set; ObjectPool _pool; }; @@ -142,14 +147,15 @@ public: ~StringValueSet() override = default; - virtual Status to_vexpr_list(doris::ObjectPool* pool, - std::vector* vexpr_list) override { + Status to_vexpr_list(doris::ObjectPool* pool, + std::vector* vexpr_list, int precision, + int scale) override { HybridSetBase::IteratorBase* it = begin(); DCHECK(it != nullptr); while (it->has_next()) { TExprNode node; const void* v = it->get_value(); - create_texpr_literal_node(v, &node); + create_texpr_literal_node(v, &node); vexpr_list->push_back(pool->add(new doris::vectorized::VLiteral(node))); it->next(); } diff --git a/be/src/exprs/literal.cpp b/be/src/exprs/literal.cpp index 462ec259be..de0e214bfe 100644 --- a/be/src/exprs/literal.cpp +++ b/be/src/exprs/literal.cpp @@ -100,6 +100,14 @@ Literal::Literal(const TExprNode& node) : Expr(node) { _value.decimalv2_val = DecimalV2Value(node.decimal_literal.value); break; } + case TYPE_DECIMAL32: + case TYPE_DECIMAL64: + case TYPE_DECIMAL128: { + DCHECK_EQ(node.node_type, TExprNodeType::DECIMAL_LITERAL); + DCHECK(node.__isset.decimal_literal); + _value.set_string_val(node.decimal_literal.value); + break; + } case TYPE_ARRAY: { DCHECK_EQ(node.node_type, TExprNodeType::ARRAY_LITERAL); // init in prepare @@ -129,20 +137,56 @@ SmallIntVal Literal::get_small_int_val(ExprContext* context, TupleRow* row) { } IntVal Literal::get_int_val(ExprContext* context, TupleRow* row) { - DCHECK_EQ(_type.type, TYPE_INT) << _type; + DCHECK(_type.type == TYPE_INT) << _type; return IntVal(_value.int_val); } BigIntVal Literal::get_big_int_val(ExprContext* context, TupleRow* row) { - DCHECK_EQ(_type.type, TYPE_BIGINT) << _type; + DCHECK(_type.type == TYPE_BIGINT) << _type; return BigIntVal(_value.bigint_val); } LargeIntVal Literal::get_large_int_val(ExprContext* context, TupleRow* row) { - DCHECK_EQ(_type.type, TYPE_LARGEINT) << _type; + DCHECK(_type.type == TYPE_LARGEINT) << _type; return LargeIntVal(_value.large_int_val); } +Decimal32Val Literal::get_decimal32_val(ExprContext* context, TupleRow* row) { + DCHECK(_type.type == TYPE_DECIMAL32) << _type; + StringParser::ParseResult result; + auto decimal32_value = StringParser::string_to_decimal( + _value.string_val.ptr, _value.string_val.len, _type.precision, _type.scale, &result); + if (result == StringParser::ParseResult::PARSE_SUCCESS) { + return Decimal32Val(decimal32_value); + } else { + return Decimal32Val::null(); + } +} + +Decimal64Val Literal::get_decimal64_val(ExprContext* context, TupleRow* row) { + DCHECK(_type.type == TYPE_DECIMAL64) << _type; + StringParser::ParseResult result; + auto decimal_value = StringParser::string_to_decimal( + _value.string_val.ptr, _value.string_val.len, _type.precision, _type.scale, &result); + if (result == StringParser::ParseResult::PARSE_SUCCESS) { + return Decimal64Val(decimal_value); + } else { + return Decimal64Val::null(); + } +} + +Decimal128Val Literal::get_decimal128_val(ExprContext* context, TupleRow* row) { + DCHECK(_type.type == TYPE_DECIMAL128) << _type; + StringParser::ParseResult result; + auto decimal_value = StringParser::string_to_decimal( + _value.string_val.ptr, _value.string_val.len, _type.precision, _type.scale, &result); + if (result == StringParser::ParseResult::PARSE_SUCCESS) { + return Decimal128Val(decimal_value); + } else { + return Decimal128Val::null(); + } +} + FloatVal Literal::get_float_val(ExprContext* context, TupleRow* row) { DCHECK_EQ(_type.type, TYPE_FLOAT) << _type; return FloatVal(_value.float_val); diff --git a/be/src/exprs/literal.h b/be/src/exprs/literal.h index fc77e06ad7..87372b1590 100644 --- a/be/src/exprs/literal.h +++ b/be/src/exprs/literal.h @@ -46,6 +46,9 @@ public: virtual DateTimeVal get_datetime_val(ExprContext* context, TupleRow*) override; virtual StringVal get_string_val(ExprContext* context, TupleRow* row) override; virtual CollectionVal get_array_val(ExprContext* context, TupleRow*) override; + virtual Decimal32Val get_decimal32_val(ExprContext* context, TupleRow*) override; + virtual Decimal64Val get_decimal64_val(ExprContext* context, TupleRow*) override; + virtual Decimal128Val get_decimal128_val(ExprContext* context, TupleRow*) override; // init val before use virtual Status prepare(RuntimeState* state, const RowDescriptor& row_desc, ExprContext* context) override; diff --git a/be/src/exprs/runtime_filter.cpp b/be/src/exprs/runtime_filter.cpp index fed22825b1..262247f864 100644 --- a/be/src/exprs/runtime_filter.cpp +++ b/be/src/exprs/runtime_filter.cpp @@ -68,6 +68,9 @@ TExprNodeType::type get_expr_node_type(PrimitiveType type) { return TExprNodeType::FLOAT_LITERAL; break; + case TYPE_DECIMAL32: + case TYPE_DECIMAL64: + case TYPE_DECIMAL128: case TYPE_DECIMALV2: return TExprNodeType::DECIMAL_LITERAL; @@ -118,6 +121,12 @@ PColumnType to_proto(PrimitiveType type) { return PColumnType::COLUMN_TYPE_DATETIME; case TYPE_DECIMALV2: return PColumnType::COLUMN_TYPE_DECIMALV2; + case TYPE_DECIMAL32: + return PColumnType::COLUMN_TYPE_DECIMAL32; + case TYPE_DECIMAL64: + return PColumnType::COLUMN_TYPE_DECIMAL64; + case TYPE_DECIMAL128: + return PColumnType::COLUMN_TYPE_DECIMAL128; case TYPE_CHAR: return PColumnType::COLUMN_TYPE_CHAR; case TYPE_VARCHAR: @@ -159,6 +168,12 @@ PrimitiveType to_primitive_type(PColumnType type) { return TYPE_DATETIME; case PColumnType::COLUMN_TYPE_DECIMALV2: return TYPE_DECIMALV2; + case PColumnType::COLUMN_TYPE_DECIMAL32: + return TYPE_DECIMAL32; + case PColumnType::COLUMN_TYPE_DECIMAL64: + return TYPE_DECIMAL64; + case PColumnType::COLUMN_TYPE_DECIMAL128: + return TYPE_DECIMAL128; case PColumnType::COLUMN_TYPE_VARCHAR: return TYPE_VARCHAR; case PColumnType::COLUMN_TYPE_CHAR: @@ -204,59 +219,80 @@ PFilterType get_type(RuntimeFilterType type) { } template -Status create_literal(ObjectPool* pool, PrimitiveType type, const void* data, void** expr) { +Status create_literal(ObjectPool* pool, const TypeDescriptor& type, const void* data, void** expr) { TExprNode node; - switch (type) { + switch (type.type) { case TYPE_BOOLEAN: { - create_texpr_literal_node(data, &node); + create_texpr_literal_node(data, &node); break; } case TYPE_TINYINT: { - create_texpr_literal_node(data, &node); + create_texpr_literal_node(data, &node); break; } case TYPE_SMALLINT: { - create_texpr_literal_node(data, &node); + create_texpr_literal_node(data, &node); break; } case TYPE_INT: { - create_texpr_literal_node(data, &node); + create_texpr_literal_node(data, &node); break; } case TYPE_BIGINT: { - create_texpr_literal_node(data, &node); + create_texpr_literal_node(data, &node); break; } case TYPE_LARGEINT: { - create_texpr_literal_node(data, &node); + create_texpr_literal_node(data, &node); break; } case TYPE_FLOAT: { - create_texpr_literal_node(data, &node); + create_texpr_literal_node(data, &node); break; } case TYPE_DOUBLE: { - create_texpr_literal_node(data, &node); + create_texpr_literal_node(data, &node); break; } case TYPE_DATEV2: { - create_texpr_literal_node(data, &node); + create_texpr_literal_node(data, &node); + break; + } + case TYPE_DATE: { + create_texpr_literal_node(data, &node); break; } - case TYPE_DATE: case TYPE_DATETIME: { - create_texpr_literal_node(data, &node); + create_texpr_literal_node(data, &node); break; } case TYPE_DECIMALV2: { - create_texpr_literal_node(data, &node); + create_texpr_literal_node(data, &node, type.precision, type.scale); + break; + } + case TYPE_DECIMAL32: { + create_texpr_literal_node(data, &node, type.precision, type.scale); + break; + } + case TYPE_DECIMAL64: { + create_texpr_literal_node(data, &node, type.precision, type.scale); + break; + } + case TYPE_DECIMAL128: { + create_texpr_literal_node(data, &node, type.precision, type.scale); + break; + } + case TYPE_CHAR: { + create_texpr_literal_node(data, &node); + break; + } + case TYPE_VARCHAR: { + create_texpr_literal_node(data, &node); break; } - case TYPE_CHAR: - case TYPE_VARCHAR: case TYPE_STRING: { - create_texpr_literal_node(data, &node); + create_texpr_literal_node(data, &node); break; } default: @@ -293,7 +329,7 @@ BinaryPredicate* create_bin_predicate(ObjectPool* pool, PrimitiveType prim_type, return (BinaryPredicate*)pool->add(BinaryPredicate::from_thrift(node)); } -Status create_vbin_predicate(ObjectPool* pool, PrimitiveType prim_type, TExprOpcode::type opcode, +Status create_vbin_predicate(ObjectPool* pool, const TypeDescriptor& type, TExprOpcode::type opcode, doris::vectorized::VExpr** expr, TExprNode* tnode) { TExprNode node; TScalarType tscalar_type; @@ -306,9 +342,9 @@ Status create_vbin_predicate(ObjectPool* pool, PrimitiveType prim_type, TExprOpc node.__set_type(t_type_desc); node.__set_opcode(opcode); node.__set_vector_opcode(opcode); - node.__set_child_type(to_thrift(prim_type)); + node.__set_child_type(to_thrift(type.type)); node.__set_num_children(2); - node.__set_output_scale(-1); + node.__set_output_scale(type.scale); node.__set_node_type(TExprNodeType::BINARY_PRED); TFunction fn; TFunctionName fn_name; @@ -330,7 +366,9 @@ Status create_vbin_predicate(ObjectPool* pool, PrimitiveType prim_type, TExprOpc TTypeNode type_node; type_node.__set_type(TTypeNodeType::SCALAR); TScalarType scalar_type; - scalar_type.__set_type(to_thrift(prim_type)); + scalar_type.__set_type(to_thrift(type.type)); + scalar_type.__set_precision(type.precision); + scalar_type.__set_scale(type.scale); type_node.__set_scalar_type(scalar_type); std::vector type_nodes; @@ -708,6 +746,34 @@ public: }); break; } + case TYPE_DECIMAL32: { + batch_assign(in_filter, [](std::unique_ptr& set, PColumnValue& column, + ObjectPool* pool) { + int32_t decimal_32_val = column.intval(); + set->insert(&decimal_32_val); + }); + break; + } + case TYPE_DECIMAL64: { + batch_assign(in_filter, [](std::unique_ptr& set, PColumnValue& column, + ObjectPool* pool) { + int64_t decimal_64_val = column.longval(); + set->insert(&decimal_64_val); + }); + break; + } + case TYPE_DECIMAL128: { + batch_assign(in_filter, [](std::unique_ptr& set, PColumnValue& column, + ObjectPool* pool) { + auto string_val = column.stringval(); + StringParser::ParseResult result; + int128_t int128_val = StringParser::string_to_int( + string_val.c_str(), string_val.length(), &result); + DCHECK(result == StringParser::PARSE_SUCCESS); + set->insert(&int128_val); + }); + break; + } case TYPE_VARCHAR: case TYPE_CHAR: case TYPE_STRING: { @@ -814,6 +880,28 @@ public: DecimalV2Value max_val(max_val_ref); return _minmax_func->assign(&min_val, &max_val); } + case TYPE_DECIMAL32: { + int32_t min_val = minmax_filter->min_val().intval(); + int32_t max_val = minmax_filter->max_val().intval(); + return _minmax_func->assign(&min_val, &max_val); + } + case TYPE_DECIMAL64: { + int64_t min_val = minmax_filter->min_val().longval(); + int64_t max_val = minmax_filter->max_val().longval(); + return _minmax_func->assign(&min_val, &max_val); + } + case TYPE_DECIMAL128: { + auto min_string_val = minmax_filter->min_val().stringval(); + auto max_string_val = minmax_filter->max_val().stringval(); + StringParser::ParseResult result; + int128_t min_val = StringParser::string_to_int( + min_string_val.c_str(), min_string_val.length(), &result); + DCHECK(result == StringParser::PARSE_SUCCESS); + int128_t max_val = StringParser::string_to_int( + max_string_val.c_str(), max_string_val.length(), &result); + DCHECK(result == StringParser::PARSE_SUCCESS); + return _minmax_func->assign(&min_val, &max_val); + } case TYPE_VARCHAR: case TYPE_CHAR: case TYPE_STRING: { @@ -1299,6 +1387,24 @@ void IRuntimeFilter::to_protobuf(PInFilter* filter) { }); return; } + case TYPE_DECIMAL32: { + batch_copy(filter, it, [](PColumnValue* column, const int32_t* value) { + column->set_intval(*value); + }); + return; + } + case TYPE_DECIMAL64: { + batch_copy(filter, it, [](PColumnValue* column, const int64_t* value) { + column->set_longval(*value); + }); + return; + } + case TYPE_DECIMAL128: { + batch_copy(filter, it, [](PColumnValue* column, const int128_t* value) { + column->set_stringval(LargeIntValue::to_string(*value)); + }); + return; + } case TYPE_CHAR: case TYPE_VARCHAR: case TYPE_STRING: { @@ -1386,6 +1492,23 @@ void IRuntimeFilter::to_protobuf(PMinMaxFilter* filter) { reinterpret_cast(max_data)->to_string()); return; } + case TYPE_DECIMAL32: { + filter->mutable_min_val()->set_intval(*reinterpret_cast(min_data)); + filter->mutable_max_val()->set_intval(*reinterpret_cast(max_data)); + return; + } + case TYPE_DECIMAL64: { + filter->mutable_min_val()->set_longval(*reinterpret_cast(min_data)); + filter->mutable_max_val()->set_longval(*reinterpret_cast(max_data)); + return; + } + case TYPE_DECIMAL128: { + filter->mutable_min_val()->set_stringval( + LargeIntValue::to_string(*reinterpret_cast(min_data))); + filter->mutable_max_val()->set_stringval( + LargeIntValue::to_string(*reinterpret_cast(max_data))); + return; + } case TYPE_CHAR: case TYPE_VARCHAR: case TYPE_STRING: { @@ -1465,16 +1588,16 @@ Status RuntimePredicateWrapper::get_push_context(T* container, RuntimeState* sta // create max filter Expr* max_literal = nullptr; auto max_pred = create_bin_predicate(_pool, _column_return_type, TExprOpcode::LE); - RETURN_IF_ERROR(create_literal(_pool, _column_return_type, _minmax_func->get_max(), - (void**)&max_literal)); + RETURN_IF_ERROR(create_literal(_pool, prob_expr->root()->type(), + _minmax_func->get_max(), (void**)&max_literal)); max_pred->add_child(Expr::copy(_pool, prob_expr->root())); max_pred->add_child(max_literal); container->push_back(_pool->add(new ExprContext(max_pred))); // create min filter Expr* min_literal = nullptr; auto min_pred = create_bin_predicate(_pool, _column_return_type, TExprOpcode::GE); - RETURN_IF_ERROR(create_literal(_pool, _column_return_type, _minmax_func->get_min(), - (void**)&min_literal)); + RETURN_IF_ERROR(create_literal(_pool, prob_expr->root()->type(), + _minmax_func->get_min(), (void**)&min_literal)); min_pred->add_child(Expr::copy(_pool, prob_expr->root())); min_pred->add_child(min_literal); container->push_back(_pool->add(new ExprContext(min_pred))); @@ -1533,7 +1656,8 @@ Status RuntimePredicateWrapper::get_push_vexprs(std::vectoradd_child(cloned_vexpr); auto& children = const_cast&>(expr->children()); - _hybrid_set->to_vexpr_list(_pool, &children); + _hybrid_set->to_vexpr_list(_pool, &children, vprob_expr->root()->type().precision, + vprob_expr->root()->type().scale); container->push_back( _pool->add(new doris::vectorized::VRuntimeFilterWrapper(node, expr))); } @@ -1543,11 +1667,11 @@ Status RuntimePredicateWrapper::get_push_vexprs(std::vectorroot()->type(), TExprOpcode::LE, &max_pred, &max_pred_node)); doris::vectorized::VExpr* max_literal = nullptr; - RETURN_IF_ERROR(create_literal(_pool, _column_return_type, _minmax_func->get_max(), - (void**)&max_literal)); + RETURN_IF_ERROR(create_literal(_pool, vprob_expr->root()->type(), + _minmax_func->get_max(), (void**)&max_literal)); auto cloned_vexpr = vprob_expr->root()->clone(_pool); max_pred->add_child(cloned_vexpr); max_pred->add_child(max_literal); @@ -1557,11 +1681,11 @@ Status RuntimePredicateWrapper::get_push_vexprs(std::vectorroot()->type(), TExprOpcode::GE, &min_pred, &min_pred_node)); doris::vectorized::VExpr* min_literal = nullptr; - RETURN_IF_ERROR(create_literal(_pool, _column_return_type, _minmax_func->get_min(), - (void**)&min_literal)); + RETURN_IF_ERROR(create_literal(_pool, vprob_expr->root()->type(), + _minmax_func->get_min(), (void**)&min_literal)); cloned_vexpr = vprob_expr->root()->clone(_pool); min_pred->add_child(cloned_vexpr); min_pred->add_child(min_literal); diff --git a/be/src/exprs/scalar_fn_call.cpp b/be/src/exprs/scalar_fn_call.cpp index 07f62a81fc..c5e2b61b0f 100644 --- a/be/src/exprs/scalar_fn_call.cpp +++ b/be/src/exprs/scalar_fn_call.cpp @@ -356,6 +356,9 @@ typedef DoubleVal (*DoubleWrapper)(ExprContext*, TupleRow*); typedef StringVal (*StringWrapper)(ExprContext*, TupleRow*); typedef DateTimeVal (*DatetimeWrapper)(ExprContext*, TupleRow*); typedef DecimalV2Val (*DecimalV2Wrapper)(ExprContext*, TupleRow*); +typedef Decimal32Val (*Decimal32Wrapper)(ExprContext*, TupleRow*); +typedef Decimal64Val (*Decimal64Wrapper)(ExprContext*, TupleRow*); +typedef Decimal128Val (*Decimal128Wrapper)(ExprContext*, TupleRow*); typedef CollectionVal (*ArrayWrapper)(ExprContext*, TupleRow*); // TODO: macroify this? @@ -470,6 +473,36 @@ DecimalV2Val ScalarFnCall::get_decimalv2_val(ExprContext* context, TupleRow* row return fn(context, row); } +Decimal32Val ScalarFnCall::get_decimal32_val(ExprContext* context, TupleRow* row) { + DCHECK_EQ(_type.type, TYPE_DECIMAL32); + DCHECK(context != nullptr); + if (_scalar_fn_wrapper == nullptr) { + return interpret_eval(context, row); + } + Decimal32Wrapper fn = reinterpret_cast(_scalar_fn_wrapper); + return fn(context, row); +} + +Decimal64Val ScalarFnCall::get_decimal64_val(ExprContext* context, TupleRow* row) { + DCHECK_EQ(_type.type, TYPE_DECIMAL64); + DCHECK(context != nullptr); + if (_scalar_fn_wrapper == nullptr) { + return interpret_eval(context, row); + } + Decimal64Wrapper fn = reinterpret_cast(_scalar_fn_wrapper); + return fn(context, row); +} + +Decimal128Val ScalarFnCall::get_decimal128_val(ExprContext* context, TupleRow* row) { + DCHECK_EQ(_type.type, TYPE_DECIMAL128); + DCHECK(context != nullptr); + if (_scalar_fn_wrapper == nullptr) { + return interpret_eval(context, row); + } + Decimal128Wrapper fn = reinterpret_cast(_scalar_fn_wrapper); + return fn(context, row); +} + CollectionVal ScalarFnCall::get_array_val(ExprContext* context, TupleRow* row) { DCHECK_EQ(_type.type, TYPE_ARRAY); DCHECK(context != nullptr); diff --git a/be/src/exprs/scalar_fn_call.h b/be/src/exprs/scalar_fn_call.h index 5242e4b56a..97287c37c7 100644 --- a/be/src/exprs/scalar_fn_call.h +++ b/be/src/exprs/scalar_fn_call.h @@ -87,6 +87,10 @@ protected: virtual doris_udf::DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow*) override; virtual CollectionVal get_array_val(ExprContext* context, TupleRow*) override; + virtual Decimal32Val get_decimal32_val(ExprContext* context, TupleRow*) override; + virtual Decimal64Val get_decimal64_val(ExprContext* context, TupleRow*) override; + virtual Decimal128Val get_decimal128_val(ExprContext* context, TupleRow*) override; + private: /// If this function has var args, children()[_vararg_start_idx] is the first vararg /// argument. diff --git a/be/src/exprs/slot_ref.cpp b/be/src/exprs/slot_ref.cpp index 59e33c7254..666b427565 100644 --- a/be/src/exprs/slot_ref.cpp +++ b/be/src/exprs/slot_ref.cpp @@ -235,6 +235,36 @@ DecimalV2Val SlotRef::get_decimalv2_val(ExprContext* context, TupleRow* row) { return DecimalV2Val(reinterpret_cast(t->get_slot(_slot_offset))->value); } +Decimal32Val SlotRef::get_decimal32_val(ExprContext* context, TupleRow* row) { + DCHECK_EQ(_type.type, TYPE_DECIMAL32); + Tuple* t = row->get_tuple(_tuple_idx); + if (t == nullptr || t->is_null(_null_indicator_offset)) { + return Decimal32Val::null(); + } + + return Decimal32Val(*reinterpret_cast(t->get_slot(_slot_offset))); +} + +Decimal64Val SlotRef::get_decimal64_val(ExprContext* context, TupleRow* row) { + DCHECK_EQ(_type.type, TYPE_DECIMAL64); + Tuple* t = row->get_tuple(_tuple_idx); + if (t == nullptr || t->is_null(_null_indicator_offset)) { + return Decimal64Val::null(); + } + + return Decimal64Val(*reinterpret_cast(t->get_slot(_slot_offset))); +} + +Decimal128Val SlotRef::get_decimal128_val(ExprContext* context, TupleRow* row) { + DCHECK_EQ(_type.type, TYPE_DECIMAL128); + Tuple* t = row->get_tuple(_tuple_idx); + if (t == nullptr || t->is_null(_null_indicator_offset)) { + return Decimal128Val::null(); + } + + return Decimal128Val(reinterpret_cast(t->get_slot(_slot_offset))->value); +} + doris_udf::CollectionVal SlotRef::get_array_val(ExprContext* context, TupleRow* row) { DCHECK_EQ(_type.type, TYPE_ARRAY); diff --git a/be/src/exprs/slot_ref.h b/be/src/exprs/slot_ref.h index 3e63bda69c..0e916dff9f 100644 --- a/be/src/exprs/slot_ref.h +++ b/be/src/exprs/slot_ref.h @@ -71,6 +71,9 @@ public: virtual doris_udf::DateTimeVal get_datetime_val(ExprContext* context, TupleRow*) override; virtual doris_udf::DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow*) override; virtual doris_udf::CollectionVal get_array_val(ExprContext* context, TupleRow*) override; + virtual Decimal32Val get_decimal32_val(ExprContext* context, TupleRow*) override; + virtual Decimal64Val get_decimal64_val(ExprContext* context, TupleRow*) override; + virtual Decimal128Val get_decimal128_val(ExprContext* context, TupleRow*) override; private: int _tuple_idx; // within row diff --git a/be/src/olap/aggregate_func.cpp b/be/src/olap/aggregate_func.cpp index dda4d0f780..36f60933aa 100644 --- a/be/src/olap/aggregate_func.cpp +++ b/be/src/olap/aggregate_func.cpp @@ -98,6 +98,9 @@ AggregateFuncResolver::AggregateFuncResolver() { add_aggregate_mapping(); add_aggregate_mapping(); add_aggregate_mapping(); + add_aggregate_mapping(); + add_aggregate_mapping(); + add_aggregate_mapping(); add_aggregate_mapping(); add_aggregate_mapping(); add_aggregate_mapping(); @@ -134,6 +137,12 @@ AggregateFuncResolver::AggregateFuncResolver() { OLAP_FIELD_TYPE_DATETIME>(); add_aggregate_mapping(); + add_aggregate_mapping(); + add_aggregate_mapping(); + add_aggregate_mapping(); add_aggregate_mapping(); @@ -146,6 +155,9 @@ AggregateFuncResolver::AggregateFuncResolver() { add_aggregate_mapping(); add_aggregate_mapping(); add_aggregate_mapping(); + add_aggregate_mapping(); + add_aggregate_mapping(); + add_aggregate_mapping(); add_aggregate_mapping(); add_aggregate_mapping(); add_aggregate_mapping(); @@ -162,6 +174,9 @@ AggregateFuncResolver::AggregateFuncResolver() { add_aggregate_mapping(); add_aggregate_mapping(); add_aggregate_mapping(); + add_aggregate_mapping(); + add_aggregate_mapping(); + add_aggregate_mapping(); add_aggregate_mapping(); add_aggregate_mapping(); add_aggregate_mapping(); @@ -178,6 +193,9 @@ AggregateFuncResolver::AggregateFuncResolver() { add_aggregate_mapping(); add_aggregate_mapping(); add_aggregate_mapping(); + add_aggregate_mapping(); + add_aggregate_mapping(); + add_aggregate_mapping(); // Replace Aggregate Function add_aggregate_mapping(); @@ -189,6 +207,9 @@ AggregateFuncResolver::AggregateFuncResolver() { add_aggregate_mapping(); add_aggregate_mapping(); add_aggregate_mapping(); + add_aggregate_mapping(); + add_aggregate_mapping(); + add_aggregate_mapping(); add_aggregate_mapping(); add_aggregate_mapping(); add_aggregate_mapping(); @@ -206,6 +227,9 @@ AggregateFuncResolver::AggregateFuncResolver() { add_aggregate_mapping(); add_aggregate_mapping(); add_aggregate_mapping(); + add_aggregate_mapping(); + add_aggregate_mapping(); + add_aggregate_mapping(); add_aggregate_mapping(); add_aggregate_mapping(); add_aggregate_mapping(); diff --git a/be/src/olap/bloom_filter_predicate.cpp b/be/src/olap/bloom_filter_predicate.cpp index 6b87a86a61..a1acb913c4 100644 --- a/be/src/olap/bloom_filter_predicate.cpp +++ b/be/src/olap/bloom_filter_predicate.cpp @@ -32,7 +32,10 @@ M(TYPE_DATEV2) \ M(TYPE_DATETIME) \ M(TYPE_VARCHAR) \ - M(TYPE_STRING) + M(TYPE_STRING) \ + M(TYPE_DECIMAL32) \ + M(TYPE_DECIMAL64) \ + M(TYPE_DECIMAL128) namespace doris { ColumnPredicate* BloomFilterColumnPredicateFactory::create_column_predicate( diff --git a/be/src/olap/column_vector.cpp b/be/src/olap/column_vector.cpp index 15957ffd2c..6d085b7a64 100644 --- a/be/src/olap/column_vector.cpp +++ b/be/src/olap/column_vector.cpp @@ -87,6 +87,21 @@ Status ColumnVectorBatch::create(size_t init_capacity, bool is_nullable, const T new ScalarColumnVectorBatch::CppType>( type_info, is_nullable)); break; + case OLAP_FIELD_TYPE_DECIMAL32: + local.reset( + new ScalarColumnVectorBatch::CppType>( + type_info, is_nullable)); + break; + case OLAP_FIELD_TYPE_DECIMAL64: + local.reset( + new ScalarColumnVectorBatch::CppType>( + type_info, is_nullable)); + break; + case OLAP_FIELD_TYPE_DECIMAL128: + local.reset( + new ScalarColumnVectorBatch::CppType>( + type_info, is_nullable)); + break; case OLAP_FIELD_TYPE_DATE: local.reset(new ScalarColumnVectorBatch::CppType>( type_info, is_nullable)); diff --git a/be/src/olap/delete_handler.cpp b/be/src/olap/delete_handler.cpp index 69e5f0b7af..99c6087575 100644 --- a/be/src/olap/delete_handler.cpp +++ b/be/src/olap/delete_handler.cpp @@ -139,6 +139,12 @@ bool DeleteConditionHandler::is_condition_value_valid(const TabletColumn& column return valid_unsigned_number(value_str); case OLAP_FIELD_TYPE_DECIMAL: return valid_decimal(value_str, column.precision(), column.frac()); + case OLAP_FIELD_TYPE_DECIMAL32: + return valid_decimal(value_str, column.precision(), column.frac()); + case OLAP_FIELD_TYPE_DECIMAL64: + return valid_decimal(value_str, column.precision(), column.frac()); + case OLAP_FIELD_TYPE_DECIMAL128: + return valid_decimal(value_str, column.precision(), column.frac()); case OLAP_FIELD_TYPE_CHAR: case OLAP_FIELD_TYPE_VARCHAR: return value_str.size() <= column.length(); diff --git a/be/src/olap/field.h b/be/src/olap/field.h index 2973bc1fbb..dadd483581 100644 --- a/be/src/olap/field.h +++ b/be/src/olap/field.h @@ -243,7 +243,8 @@ public: // used by init scan key stored in string format // value_string should end with '\0' - Status from_string(char* buf, const std::string& value_string) const { + Status from_string(char* buf, const std::string& value_string, const int precision = 0, + const int scale = 0) const { if (type() == OLAP_FIELD_TYPE_STRING && !value_string.empty()) { auto slice = reinterpret_cast(buf); if (slice->size < value_string.size()) { @@ -252,7 +253,7 @@ public: slice->size = value_string.size(); } } - return _type_info->from_string(buf, value_string); + return _type_info->from_string(buf, value_string, precision, scale); } // convert inner value to string @@ -298,6 +299,11 @@ public: Field* get_sub_field(int i) const { return _sub_fields[i].get(); } size_t get_sub_field_count() const { return _sub_fields.size(); } + void set_precision(int32_t precision) { _precision = precision; } + void set_scale(int32_t scale) { _scale = scale; } + int32_t get_precision() const { return _precision; } + int32_t get_scale() const { return _scale; } + protected: TypeInfoPtr _type_info; const AggregateInfo* _agg_info; @@ -326,6 +332,8 @@ protected: other->_index_size = this->_index_size; other->_is_nullable = this->_is_nullable; other->_sub_fields.clear(); + other->_precision = this->_precision; + other->_scale = this->_scale; for (const auto& f : _sub_fields) { Field* item = f->clone(); other->add_sub_field(std::unique_ptr(item)); @@ -340,6 +348,8 @@ private: uint16_t _index_size; bool _is_nullable; std::vector> _sub_fields; + int32_t _precision; + int32_t _scale; }; template @@ -738,6 +748,18 @@ public: local->add_sub_field(std::move(item_field)); return local; } + case OLAP_FIELD_TYPE_DECIMAL: + [[fallthrough]]; + case OLAP_FIELD_TYPE_DECIMAL32: + [[fallthrough]]; + case OLAP_FIELD_TYPE_DECIMAL64: + [[fallthrough]]; + case OLAP_FIELD_TYPE_DECIMAL128: { + Field* field = new Field(column); + field->set_precision(column.precision()); + field->set_scale(column.frac()); + return field; + } default: return new Field(column); } @@ -764,6 +786,18 @@ public: local->add_sub_field(std::move(item_field)); return local; } + case OLAP_FIELD_TYPE_DECIMAL: + [[fallthrough]]; + case OLAP_FIELD_TYPE_DECIMAL32: + [[fallthrough]]; + case OLAP_FIELD_TYPE_DECIMAL64: + [[fallthrough]]; + case OLAP_FIELD_TYPE_DECIMAL128: { + Field* field = new Field(column); + field->set_precision(column.precision()); + field->set_scale(column.frac()); + return field; + } default: return new Field(column); } diff --git a/be/src/olap/key_coder.cpp b/be/src/olap/key_coder.cpp index 49a06a13d8..ce3a722893 100644 --- a/be/src/olap/key_coder.cpp +++ b/be/src/olap/key_coder.cpp @@ -74,6 +74,9 @@ private: add_mapping(); add_mapping(); add_mapping(); + add_mapping(); + add_mapping(); + add_mapping(); } template diff --git a/be/src/olap/olap_common.h b/be/src/olap/olap_common.h index 92a1d6f241..1763a7b94c 100644 --- a/be/src/olap/olap_common.h +++ b/be/src/olap/olap_common.h @@ -147,7 +147,11 @@ enum FieldType { OLAP_FIELD_TYPE_STRING = 26, OLAP_FIELD_TYPE_QUANTILE_STATE = 27, OLAP_FIELD_TYPE_DATEV2 = 28, - OLAP_FIELD_TYPE_DATETIMEV2 = 29 + OLAP_FIELD_TYPE_DATETIMEV2 = 29, + OLAP_FIELD_TYPE_TIMEV2 = 30, + OLAP_FIELD_TYPE_DECIMAL32 = 31, + OLAP_FIELD_TYPE_DECIMAL64 = 32, + OLAP_FIELD_TYPE_DECIMAL128 = 33 }; // Define all aggregation methods supported by Field diff --git a/be/src/olap/olap_cond.cpp b/be/src/olap/olap_cond.cpp index 743566b387..558ad98bf5 100644 --- a/be/src/olap/olap_cond.cpp +++ b/be/src/olap/olap_cond.cpp @@ -129,7 +129,7 @@ Status Cond::init(const TCondition& tcond, const TabletColumn& column) { << ", operand=" << operand->c_str() << ", op_type=" << op << "]"; return Status::OLAPInternalError(OLAP_ERR_INPUT_PARAMETER_ERROR); } - Status res = f->from_string(*operand); + Status res = f->from_string(*operand, column.precision(), column.frac()); if (!res.ok()) { LOG(WARNING) << "Convert from string failed. [name=" << tcond.column_name << ", operand=" << operand->c_str() << ", op_type=" << op << "]"; @@ -146,7 +146,7 @@ Status Cond::init(const TCondition& tcond, const TabletColumn& column) { << ", operand=" << operand.c_str() << ", op_type=" << op << "]"; return Status::OLAPInternalError(OLAP_ERR_INPUT_PARAMETER_ERROR); } - Status res = f->from_string(operand); + Status res = f->from_string(operand, column.precision(), column.frac()); if (!res.ok()) { LOG(WARNING) << "Convert from string failed. [name=" << tcond.column_name << ", operand=" << operand.c_str() << ", op_type=" << op << "]"; diff --git a/be/src/olap/reader.cpp b/be/src/olap/reader.cpp index a3286f4b33..4099355869 100644 --- a/be/src/olap/reader.cpp +++ b/be/src/olap/reader.cpp @@ -36,6 +36,7 @@ #include "runtime/mem_pool.h" #include "util/date_func.h" #include "util/mem_util.hpp" +#include "vec/data_types/data_type_decimal.h" using std::nothrow; using std::set; @@ -479,6 +480,31 @@ void TabletReader::_init_conditions_param(const ReaderParams& read_params) { predicate = new PREDICATE(index, value, opposite); \ break; \ } \ + case OLAP_FIELD_TYPE_DECIMAL32: { \ + int32_t value = 0; \ + StringParser::ParseResult result = StringParser::ParseResult::PARSE_SUCCESS; \ + value = (int32_t)StringParser::string_to_decimal( \ + cond.data(), cond.size(), column.precision(), column.frac(), &result); \ + \ + predicate = new PREDICATE(index, value, opposite); \ + break; \ + } \ + case OLAP_FIELD_TYPE_DECIMAL64: { \ + int64_t value = 0; \ + StringParser::ParseResult result = StringParser::ParseResult::PARSE_SUCCESS; \ + value = (int64_t)StringParser::string_to_decimal( \ + cond.data(), cond.size(), column.precision(), column.frac(), &result); \ + predicate = new PREDICATE(index, value, opposite); \ + break; \ + } \ + case OLAP_FIELD_TYPE_DECIMAL128: { \ + int128_t value = 0; \ + StringParser::ParseResult result; \ + value = StringParser::string_to_decimal( \ + cond.data(), cond.size(), column.precision(), column.frac(), &result); \ + predicate = new PREDICATE(index, value, opposite); \ + break; \ + } \ case OLAP_FIELD_TYPE_INT: { \ int32_t value = 0; \ std::from_chars(cond.data(), cond.data() + cond.size(), value); \ @@ -639,6 +665,61 @@ ColumnPredicate* TabletReader::_parse_to_predicate(const TCondition& condition, } break; } + case OLAP_FIELD_TYPE_DECIMAL32: { + phmap::flat_hash_set values; + for (auto& cond_val : condition.condition_values) { + StringParser::ParseResult result = StringParser::ParseResult::PARSE_SUCCESS; + int128_t val = StringParser::string_to_decimal( + cond_val.data(), cond_val.size(), column.precision(), column.frac(), + &result); + if (result == StringParser::ParseResult::PARSE_SUCCESS) { + values.insert((int32_t)val); + } + } + if (condition.condition_op == "*=") { + predicate = new InListPredicate(index, std::move(values), opposite); + } else { + predicate = new NotInListPredicate(index, std::move(values), opposite); + } + break; + } + case OLAP_FIELD_TYPE_DECIMAL64: { + phmap::flat_hash_set values; + for (auto& cond_val : condition.condition_values) { + StringParser::ParseResult result; + int128_t val = StringParser::string_to_decimal( + cond_val.data(), cond_val.size(), column.precision(), column.frac(), + &result); + if (result == StringParser::ParseResult::PARSE_SUCCESS) { + values.insert((int64_t)val); + } + } + if (condition.condition_op == "*=") { + predicate = new InListPredicate(index, std::move(values), opposite); + } else { + predicate = new NotInListPredicate(index, std::move(values), opposite); + } + break; + } + case OLAP_FIELD_TYPE_DECIMAL128: { + phmap::flat_hash_set values; + int128_t val; + for (auto& cond_val : condition.condition_values) { + StringParser::ParseResult result = StringParser::ParseResult::PARSE_SUCCESS; + val = StringParser::string_to_decimal(cond_val.data(), cond_val.size(), + column.precision(), column.frac(), + &result); + if (result == StringParser::ParseResult::PARSE_SUCCESS) { + values.insert(val); + } + } + if (condition.condition_op == "*=") { + predicate = new InListPredicate(index, std::move(values), opposite); + } else { + predicate = new NotInListPredicate(index, std::move(values), opposite); + } + break; + } case OLAP_FIELD_TYPE_INT: { phmap::flat_hash_set values; int32_t value = 0; diff --git a/be/src/olap/row_block2.cpp b/be/src/olap/row_block2.cpp index b850a50899..8f8632e5ea 100644 --- a/be/src/olap/row_block2.cpp +++ b/be/src/olap/row_block2.cpp @@ -287,6 +287,24 @@ Status RowBlockV2::_copy_data_to_column(int cid, } break; } + case OLAP_FIELD_TYPE_DECIMAL32: { + auto column_decimal = + assert_cast*>(column); + insert_data_directly(cid, column_decimal); + break; + } + case OLAP_FIELD_TYPE_DECIMAL64: { + auto column_decimal = + assert_cast*>(column); + insert_data_directly(cid, column_decimal); + break; + } + case OLAP_FIELD_TYPE_DECIMAL128: { + auto column_decimal = + assert_cast*>(column); + insert_data_directly(cid, column_decimal); + break; + } case OLAP_FIELD_TYPE_ARRAY: { auto column_array = assert_cast(column); auto nested_col = (*column_array->get_data_ptr()).assume_mutable(); @@ -554,6 +572,21 @@ Status RowBlockV2::_append_data_to_column(const ColumnVectorBatch* batch, size_t } break; } + case OLAP_FIELD_TYPE_DECIMAL32: { + auto column_decimal = + assert_cast*>(column); + insert_data_directly(batch, column_decimal, start, len); + } + case OLAP_FIELD_TYPE_DECIMAL64: { + auto column_decimal = + assert_cast*>(column); + insert_data_directly(batch, column_decimal, start, len); + } + case OLAP_FIELD_TYPE_DECIMAL128: { + auto column_decimal = + assert_cast*>(column); + insert_data_directly(batch, column_decimal, start, len); + } case OLAP_FIELD_TYPE_ARRAY: { auto array_batch = reinterpret_cast(batch); auto column_array = assert_cast(column); diff --git a/be/src/olap/row_cursor.cpp b/be/src/olap/row_cursor.cpp index 303f421f25..bb91469177 100644 --- a/be/src/olap/row_cursor.cpp +++ b/be/src/olap/row_cursor.cpp @@ -271,7 +271,8 @@ Status RowCursor::from_tuple(const OlapTuple& tuple) { } set_not_null(cid); char* buf = cell_ptr(cid); - Status res = field->from_string(buf, tuple.get_value(i)); + Status res = field->from_string(buf, tuple.get_value(i), field->get_precision(), + field->get_scale()); if (!res.ok()) { LOG(WARNING) << "fail to convert field from string. string=" << tuple.get_value(i) << ", res=" << res; diff --git a/be/src/olap/rowset/segment_v2/bitmap_index_writer.cpp b/be/src/olap/rowset/segment_v2/bitmap_index_writer.cpp index e2437d175a..fea31b0a5b 100644 --- a/be/src/olap/rowset/segment_v2/bitmap_index_writer.cpp +++ b/be/src/olap/rowset/segment_v2/bitmap_index_writer.cpp @@ -229,6 +229,15 @@ Status BitmapIndexWriter::create(const TypeInfo* type_info, case OLAP_FIELD_TYPE_DECIMAL: res->reset(new BitmapIndexWriterImpl(type_info)); break; + case OLAP_FIELD_TYPE_DECIMAL32: + res->reset(new BitmapIndexWriterImpl(type_info)); + break; + case OLAP_FIELD_TYPE_DECIMAL64: + res->reset(new BitmapIndexWriterImpl(type_info)); + break; + case OLAP_FIELD_TYPE_DECIMAL128: + res->reset(new BitmapIndexWriterImpl(type_info)); + break; case OLAP_FIELD_TYPE_BOOL: res->reset(new BitmapIndexWriterImpl(type_info)); break; diff --git a/be/src/olap/rowset/segment_v2/column_reader.cpp b/be/src/olap/rowset/segment_v2/column_reader.cpp index 23d8eb9e82..e2aef4e2d8 100644 --- a/be/src/olap/rowset/segment_v2/column_reader.cpp +++ b/be/src/olap/rowset/segment_v2/column_reader.cpp @@ -889,7 +889,7 @@ Status DefaultValueColumnIterator::init(const ColumnIteratorOptions& opts) { // TODO llj for Array default value return Status::NotSupported("Array default type is unsupported"); } else { - s = _type_info->from_string(_mem_value, _default_value); + s = _type_info->from_string(_mem_value, _default_value, _precision, _scale); } if (!s.ok()) { return s; diff --git a/be/src/olap/rowset/segment_v2/column_reader.h b/be/src/olap/rowset/segment_v2/column_reader.h index 9aa10dad01..9fe37bdae8 100644 --- a/be/src/olap/rowset/segment_v2/column_reader.h +++ b/be/src/olap/rowset/segment_v2/column_reader.h @@ -427,7 +427,8 @@ private: class DefaultValueColumnIterator : public ColumnIterator { public: DefaultValueColumnIterator(bool has_default_value, const std::string& default_value, - bool is_nullable, TypeInfoPtr type_info, size_t schema_length) + bool is_nullable, TypeInfoPtr type_info, size_t schema_length, + int precision, int scale) : _has_default_value(has_default_value), _default_value(default_value), _is_nullable(is_nullable), @@ -435,6 +436,8 @@ public: _schema_length(schema_length), _is_default_value_null(false), _type_size(0), + _precision(precision), + _scale(scale), _pool(new MemPool("DefaultValueColumnIterator")) {} Status init(const ColumnIteratorOptions& opts) override; @@ -476,6 +479,8 @@ private: size_t _schema_length; bool _is_default_value_null; size_t _type_size; + int _precision; + int _scale; void* _mem_value = nullptr; std::unique_ptr _pool; diff --git a/be/src/olap/rowset/segment_v2/encoding_info.cpp b/be/src/olap/rowset/segment_v2/encoding_info.cpp index 555886611d..1a49da8792 100644 --- a/be/src/olap/rowset/segment_v2/encoding_info.cpp +++ b/be/src/olap/rowset/segment_v2/encoding_info.cpp @@ -294,6 +294,18 @@ EncodingInfoResolver::EncodingInfoResolver() { _add_map(); _add_map(); + _add_map(); + _add_map(); + _add_map(); + + _add_map(); + _add_map(); + _add_map(); + + _add_map(); + _add_map(); + _add_map(); + _add_map(); _add_map(); diff --git a/be/src/olap/rowset/segment_v2/segment.cpp b/be/src/olap/rowset/segment_v2/segment.cpp index e18a000905..bc4a750788 100644 --- a/be/src/olap/rowset/segment_v2/segment.cpp +++ b/be/src/olap/rowset/segment_v2/segment.cpp @@ -205,7 +205,8 @@ Status Segment::new_column_iterator(uint32_t cid, ColumnIterator** iter) { std::unique_ptr default_value_iter( new DefaultValueColumnIterator( tablet_column.has_default_value(), tablet_column.default_value(), - tablet_column.is_nullable(), std::move(type_info), tablet_column.length())); + tablet_column.is_nullable(), std::move(type_info), tablet_column.length(), + tablet_column.precision(), tablet_column.frac())); ColumnIteratorOptions iter_opts; RETURN_IF_ERROR(default_value_iter->init(iter_opts)); diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.cpp b/be/src/olap/rowset/segment_v2/segment_iterator.cpp index cedc0a9c6f..921c37e575 100644 --- a/be/src/olap/rowset/segment_v2/segment_iterator.cpp +++ b/be/src/olap/rowset/segment_v2/segment_iterator.cpp @@ -812,6 +812,8 @@ void SegmentIterator::_init_current_block( current_columns[cid]->set_datetime_type(); } else if (column_desc->type() == OLAP_FIELD_TYPE_DATEV2) { current_columns[cid]->set_date_v2_type(); + } else if (column_desc->type() == OLAP_FIELD_TYPE_DECIMAL) { + current_columns[cid]->set_decimalv2_type(); } current_columns[cid]->reserve(_opts.block_row_max); } diff --git a/be/src/olap/schema.cpp b/be/src/olap/schema.cpp index dce43a1bdb..47fc3bcbdd 100644 --- a/be/src/olap/schema.cpp +++ b/be/src/olap/schema.cpp @@ -169,6 +169,12 @@ vectorized::IColumn::MutablePtr Schema::get_predicate_column_ptr(FieldType type) case OLAP_FIELD_TYPE_DECIMAL: return doris::vectorized::PredicateColumnType::create(); + case OLAP_FIELD_TYPE_DECIMAL32: + return doris::vectorized::PredicateColumnType::create(); + case OLAP_FIELD_TYPE_DECIMAL64: + return doris::vectorized::PredicateColumnType::create(); + case OLAP_FIELD_TYPE_DECIMAL128: + return doris::vectorized::PredicateColumnType::create(); default: LOG(FATAL) << "Unexpected type when choosing predicate column, type=" << type; diff --git a/be/src/olap/schema_change.cpp b/be/src/olap/schema_change.cpp index ba0609e03e..edc2b8ef4c 100644 --- a/be/src/olap/schema_change.cpp +++ b/be/src/olap/schema_change.cpp @@ -2402,6 +2402,8 @@ Status SchemaChangeHandler::_parse_request( auto column_new = new_tablet_schema.column(i); auto column_old = ref_tablet_schema.column(column_mapping->ref_column); if (column_new.type() != column_old.type() || + column_new.precision() != column_old.precision() || + column_new.frac() != column_old.frac() || column_new.length() != column_old.length() || column_new.is_bf_column() != column_old.is_bf_column() || column_new.has_bitmap_index() != column_old.has_bitmap_index()) { @@ -2437,7 +2439,8 @@ Status SchemaChangeHandler::_init_column_mapping(ColumnMapping* column_mapping, if (column_schema.is_nullable() && value.length() == 0) { column_mapping->default_value->set_null(); } else { - column_mapping->default_value->from_string(value); + column_mapping->default_value->from_string(value, column_schema.precision(), + column_schema.frac()); } return Status::OK(); diff --git a/be/src/olap/tablet_meta.cpp b/be/src/olap/tablet_meta.cpp index 5c1be52076..ac6fe44e4a 100644 --- a/be/src/olap/tablet_meta.cpp +++ b/be/src/olap/tablet_meta.cpp @@ -19,6 +19,7 @@ #include +#include "common/consts.h" #include "olap/file_helper.h" #include "olap/olap_common.h" #include "olap/olap_define.h" @@ -215,14 +216,12 @@ void TabletMeta::init_column_from_tcolumn(uint32_t unique_id, const TColumn& tco EnumToString(TPrimitiveType, tcolumn.column_type.type, data_type); column->set_type(data_type); - if (tcolumn.column_type.type == TPrimitiveType::DECIMALV2) { - column->set_precision(tcolumn.column_type.precision); - column->set_frac(tcolumn.column_type.scale); - } uint32_t length = TabletColumn::get_field_length_by_type(tcolumn.column_type.type, tcolumn.column_type.len); column->set_length(length); column->set_index_length(length); + column->set_precision(tcolumn.column_type.precision); + column->set_frac(tcolumn.column_type.scale); if (tcolumn.column_type.type == TPrimitiveType::VARCHAR || tcolumn.column_type.type == TPrimitiveType::STRING) { if (!tcolumn.column_type.__isset.index_len) { diff --git a/be/src/olap/tablet_schema.cpp b/be/src/olap/tablet_schema.cpp index a8659fc1e3..2dfda1fc07 100644 --- a/be/src/olap/tablet_schema.cpp +++ b/be/src/olap/tablet_schema.cpp @@ -63,6 +63,12 @@ FieldType TabletColumn::get_field_type_by_string(const std::string& type_str) { type = OLAP_FIELD_TYPE_DATEV2; } else if (0 == upper_type_str.compare("DATETIME")) { type = OLAP_FIELD_TYPE_DATETIME; + } else if (0 == upper_type_str.compare("DECIMAL32")) { + type = OLAP_FIELD_TYPE_DECIMAL32; + } else if (0 == upper_type_str.compare("DECIMAL64")) { + type = OLAP_FIELD_TYPE_DECIMAL64; + } else if (0 == upper_type_str.compare("DECIMAL128")) { + type = OLAP_FIELD_TYPE_DECIMAL128; } else if (0 == upper_type_str.compare(0, 7, "DECIMAL")) { type = OLAP_FIELD_TYPE_DECIMAL; } else if (0 == upper_type_str.compare(0, 7, "VARCHAR")) { @@ -177,6 +183,15 @@ std::string TabletColumn::get_string_by_field_type(FieldType type) { case OLAP_FIELD_TYPE_DECIMAL: return "DECIMAL"; + case OLAP_FIELD_TYPE_DECIMAL32: + return "DECIMAL32"; + + case OLAP_FIELD_TYPE_DECIMAL64: + return "DECIMAL64"; + + case OLAP_FIELD_TYPE_DECIMAL128: + return "DECIMAL128"; + case OLAP_FIELD_TYPE_VARCHAR: return "VARCHAR"; @@ -277,6 +292,12 @@ uint32_t TabletColumn::get_field_length_by_type(TPrimitiveType::type type, uint3 return string_length + sizeof(OLAP_STRING_MAX_LENGTH); case TPrimitiveType::ARRAY: return OLAP_ARRAY_MAX_LENGTH; + case TPrimitiveType::DECIMAL32: + return 4; + case TPrimitiveType::DECIMAL64: + return 8; + case TPrimitiveType::DECIMAL128: + return 16; case TPrimitiveType::DECIMALV2: return 12; // use 12 bytes in olap engine. default: diff --git a/be/src/olap/types.cpp b/be/src/olap/types.cpp index d1b2f345bf..1ae1f93ee6 100644 --- a/be/src/olap/types.cpp +++ b/be/src/olap/types.cpp @@ -84,6 +84,11 @@ const TypeInfo* get_scalar_type_info(FieldType field_type) { get_scalar_type_info(), get_scalar_type_info(), get_scalar_type_info(), + get_scalar_type_info(), + get_scalar_type_info(), + get_scalar_type_info(), + get_scalar_type_info(), + get_scalar_type_info(), }; return field_type_array[field_type]; } @@ -149,6 +154,11 @@ const TypeInfo* get_array_type_info(FieldType leaf_type, int32_t iterations) { INIT_ARRAY_TYPE_INFO_LIST(OLAP_FIELD_TYPE_STRING), INIT_ARRAY_TYPE_INFO_LIST(OLAP_FIELD_TYPE_QUANTILE_STATE), INIT_ARRAY_TYPE_INFO_LIST(OLAP_FIELD_TYPE_DATEV2), + INIT_ARRAY_TYPE_INFO_LIST(OLAP_FIELD_TYPE_DATETIMEV2), + INIT_ARRAY_TYPE_INFO_LIST(OLAP_FIELD_TYPE_TIMEV2), + INIT_ARRAY_TYPE_INFO_LIST(OLAP_FIELD_TYPE_DECIMAL32), + INIT_ARRAY_TYPE_INFO_LIST(OLAP_FIELD_TYPE_DECIMAL64), + INIT_ARRAY_TYPE_INFO_LIST(OLAP_FIELD_TYPE_DECIMAL128), }; return array_type_Info_arr[leaf_type][iterations]; } diff --git a/be/src/olap/types.h b/be/src/olap/types.h index 6f5bb48ef3..0f6a604aa0 100644 --- a/be/src/olap/types.h +++ b/be/src/olap/types.h @@ -78,7 +78,8 @@ public: virtual Status convert_from(void* dest, const void* src, const TypeInfo* src_type, MemPool* mem_pool, size_t variable_len = 0) const = 0; - virtual Status from_string(void* buf, const std::string& scan_key) const = 0; + virtual Status from_string(void* buf, const std::string& scan_key, const int precision = 0, + const int scale = 0) const = 0; virtual std::string to_string(const void* src) const = 0; @@ -121,8 +122,9 @@ public: return _convert_from(dest, src, src_type, mem_pool, variable_len); } - Status from_string(void* buf, const std::string& scan_key) const override { - return _from_string(buf, scan_key); + Status from_string(void* buf, const std::string& scan_key, const int precision = 0, + const int scale = 0) const override { + return _from_string(buf, scan_key, precision, scale); } std::string to_string(const void* src) const override { return _to_string(src); } @@ -167,7 +169,8 @@ private: Status (*_convert_from)(void* dest, const void* src, const TypeInfo* src_type, MemPool* mem_pool, size_t variable_len); - Status (*_from_string)(void* buf, const std::string& scan_key); + Status (*_from_string)(void* buf, const std::string& scan_key, const int precision, + const int scale); std::string (*_to_string)(const void* src); void (*_set_to_max)(void* buf); @@ -369,7 +372,8 @@ public: return Status::OLAPInternalError(OLAP_ERR_FUNC_NOT_IMPLEMENTED); } - Status from_string(void* buf, const std::string& scan_key) const override { + Status from_string(void* buf, const std::string& scan_key, const int precision = 0, + const int scale = 0) const override { return Status::OLAPInternalError(OLAP_ERR_FUNC_NOT_IMPLEMENTED); } @@ -495,6 +499,21 @@ struct CppTypeTraits { using UnsignedCppType = decimal12_t; }; template <> +struct CppTypeTraits { + using CppType = int32_t; + using UnsignedCppType = uint32_t; +}; +template <> +struct CppTypeTraits { + using CppType = int64_t; + using UnsignedCppType = uint64_t; +}; +template <> +struct CppTypeTraits { + using CppType = int128_t; + using UnsignedCppType = uint128_t; +}; +template <> struct CppTypeTraits { using CppType = uint24_t; using UnsignedCppType = uint24_t; @@ -505,6 +524,16 @@ struct CppTypeTraits { using UnsignedCppType = uint32_t; }; template <> +struct CppTypeTraits { + using CppType = uint64_t; + using UnsignedCppType = uint64_t; +}; +template <> +struct CppTypeTraits { + using CppType = uint64_t; + using UnsignedCppType = uint64_t; +}; +template <> struct CppTypeTraits { using CppType = int64_t; using UnsignedCppType = uint64_t; @@ -608,7 +637,8 @@ struct BaseFieldtypeTraits : public CppTypeTraits { return std::to_string(get_cpp_type_value(src)); } - static Status from_string(void* buf, const std::string& scan_key) { + static Status from_string(void* buf, const std::string& scan_key, const int precision, + const int scale) { CppType value = 0; if (scan_key.length() > 0) { value = static_cast(strtol(scan_key.c_str(), nullptr, 10)); @@ -707,7 +737,8 @@ struct FieldTypeTraits : public BaseFieldtypeTraits struct FieldTypeTraits : public NumericFieldtypeTraits { - static Status from_string(void* buf, const std::string& scan_key) { + static Status from_string(void* buf, const std::string& scan_key, const int precision, + const int scale) { int128_t value = 0; const char* value_string = scan_key.c_str(); @@ -812,7 +843,8 @@ struct FieldTypeTraits template <> struct FieldTypeTraits : public NumericFieldtypeTraits { - static Status from_string(void* buf, const std::string& scan_key) { + static Status from_string(void* buf, const std::string& scan_key, const int precision, + const int scale) { CppType value = 0.0f; if (scan_key.length() > 0) { value = static_cast(atof(scan_key.c_str())); @@ -833,7 +865,8 @@ struct FieldTypeTraits template <> struct FieldTypeTraits : public NumericFieldtypeTraits { - static Status from_string(void* buf, const std::string& scan_key) { + static Status from_string(void* buf, const std::string& scan_key, const int precision, + const int scale) { CppType value = 0.0; if (scan_key.length() > 0) { value = atof(scan_key.c_str()); @@ -879,7 +912,8 @@ struct FieldTypeTraits template <> struct FieldTypeTraits : public BaseFieldtypeTraits { - static Status from_string(void* buf, const std::string& scan_key) { + static Status from_string(void* buf, const std::string& scan_key, const int precision, + const int scale) { CppType* data_ptr = reinterpret_cast(buf); return data_ptr->from_string(scan_key); } @@ -899,9 +933,91 @@ struct FieldTypeTraits } }; +template <> +struct FieldTypeTraits + : public BaseFieldtypeTraits { + static Status from_string(void* buf, const std::string& scan_key, const int precision, + const int scale) { + StringParser::ParseResult result = StringParser::PARSE_SUCCESS; + int32_t value = StringParser::string_to_decimal(scan_key.c_str(), scan_key.size(), + 9, scale, &result); + + if (result == StringParser::PARSE_FAILURE) { + return Status::OLAPInternalError(OLAP_ERR_INPUT_PARAMETER_ERROR); + } + *reinterpret_cast(buf) = (int32_t)value; + return Status::OK(); + } + static void set_to_max(void* buf) { + CppType* data = reinterpret_cast(buf); + *data = 999999999; + } + static void set_to_min(void* buf) { + CppType* data = reinterpret_cast(buf); + *data = -999999999; + } +}; + +template <> +struct FieldTypeTraits + : public BaseFieldtypeTraits { + static Status from_string(void* buf, const std::string& scan_key, const int precision, + const int scale) { + StringParser::ParseResult result = StringParser::PARSE_SUCCESS; + int64_t value = StringParser::string_to_decimal(scan_key.c_str(), scan_key.size(), + 18, scale, &result); + if (result == StringParser::PARSE_FAILURE) { + return Status::OLAPInternalError(OLAP_ERR_INPUT_PARAMETER_ERROR); + } + *reinterpret_cast(buf) = (int64_t)value; + return Status::OK(); + } + static void set_to_max(void* buf) { + CppType* data = reinterpret_cast(buf); + *data = 999999999999999999ll; + } + static void set_to_min(void* buf) { + CppType* data = reinterpret_cast(buf); + *data = -999999999999999999ll; + } +}; + +template <> +struct FieldTypeTraits + : public BaseFieldtypeTraits { + static Status from_string(void* buf, const std::string& scan_key, const int precision, + const int scale) { + StringParser::ParseResult result = StringParser::PARSE_SUCCESS; + int128_t value = StringParser::string_to_decimal( + scan_key.c_str(), scan_key.size(), 38, scale, &result); + if (result == StringParser::PARSE_FAILURE) { + return Status::OLAPInternalError(OLAP_ERR_INPUT_PARAMETER_ERROR); + } + *reinterpret_cast(buf) = value; + return Status::OK(); + } + static std::string to_string(const void* src) { + int128_t value = reinterpret_cast(src)->value; + fmt::memory_buffer buffer; + fmt::format_to(buffer, "{}", value); + return std::string(buffer.data(), buffer.size()); + } + static void set_to_max(void* buf) { + *reinterpret_cast(buf) = + static_cast(999999999999999999ll) * 100000000000000000ll * 1000ll + + static_cast(99999999999999999ll) * 1000ll + 999ll; + } + static void set_to_min(void* buf) { + *reinterpret_cast(buf) = + -(static_cast(999999999999999999ll) * 100000000000000000ll * 1000ll + + static_cast(99999999999999999ll) * 1000ll + 999ll); + } +}; + template <> struct FieldTypeTraits : public BaseFieldtypeTraits { - static Status from_string(void* buf, const std::string& scan_key) { + static Status from_string(void* buf, const std::string& scan_key, const int precision, + const int scale) { tm time_tm; char* res = strptime(scan_key.c_str(), "%Y-%m-%d", &time_tm); @@ -991,7 +1107,8 @@ struct FieldTypeTraits : public BaseFieldtypeTraits struct FieldTypeTraits : public BaseFieldtypeTraits { - static Status from_string(void* buf, const std::string& scan_key) { + static Status from_string(void* buf, const std::string& scan_key, const int precision, + const int scale) { tm time_tm; char* res = strptime(scan_key.c_str(), "%Y-%m-%d", &time_tm); @@ -1089,7 +1206,8 @@ struct FieldTypeTraits template <> struct FieldTypeTraits : public BaseFieldtypeTraits { - static Status from_string(void* buf, const std::string& scan_key) { + static Status from_string(void* buf, const std::string& scan_key, const int precision, + const int scale) { tm time_tm; char* res = strptime(scan_key.c_str(), "%Y-%m-%d %H:%M:%S", &time_tm); @@ -1165,7 +1283,8 @@ struct FieldTypeTraits : public BaseFieldtypeTraits(right); return l_slice->compare(*r_slice); } - static Status from_string(void* buf, const std::string& scan_key) { + static Status from_string(void* buf, const std::string& scan_key, const int precision, + const int scale) { size_t value_len = scan_key.length(); if (value_len > OLAP_VARCHAR_MAX_LENGTH) { LOG(WARNING) << "the len of value string is too long, len=" << value_len @@ -1238,7 +1357,8 @@ struct FieldTypeTraits : public BaseFieldtypeTraits struct FieldTypeTraits : public FieldTypeTraits { - static Status from_string(void* buf, const std::string& scan_key) { + static Status from_string(void* buf, const std::string& scan_key, const int precision, + const int scale) { size_t value_len = scan_key.length(); if (value_len > OLAP_VARCHAR_MAX_LENGTH) { LOG(WARNING) << "the len of value string is too long, len=" << value_len @@ -1290,7 +1410,8 @@ struct FieldTypeTraits : public FieldTypeTraits struct FieldTypeTraits : public FieldTypeTraits { - static Status from_string(void* buf, const std::string& scan_key) { + static Status from_string(void* buf, const std::string& scan_key, const int precision, + const int scale) { size_t value_len = scan_key.length(); if (value_len > config::string_type_length_soft_limit_bytes) { LOG(WARNING) << "the len of value string is too long, len=" << value_len diff --git a/be/src/olap/wrapper_field.h b/be/src/olap/wrapper_field.h index 021c5d64e7..a4e15c5ecb 100644 --- a/be/src/olap/wrapper_field.h +++ b/be/src/olap/wrapper_field.h @@ -53,7 +53,8 @@ public: // Deserialize field value from incoming string. // // NOTE: the parameter must be a '\0' terminated string. It do not include the null flag. - Status from_string(const std::string& value_string) { + Status from_string(const std::string& value_string, const int precision = 0, + const int scale = 0) { if (_is_string_type) { if (value_string.size() > _var_length) { Slice* slice = reinterpret_cast(cell_ptr()); @@ -63,7 +64,7 @@ public: slice->data = _string_content.get(); } } - return _rep->from_string(_field_buf + 1, value_string); + return _rep->from_string(_field_buf + 1, value_string, precision, scale); } // Attach to a buf. diff --git a/be/src/runtime/decimalv2_value.cpp b/be/src/runtime/decimalv2_value.cpp index 4eef5e1558..41841b3bc5 100644 --- a/be/src/runtime/decimalv2_value.cpp +++ b/be/src/runtime/decimalv2_value.cpp @@ -352,7 +352,8 @@ int DecimalV2Value::parse_from_str(const char* decimal_str, int32_t length) { int32_t error = E_DEC_OK; StringParser::ParseResult result = StringParser::PARSE_SUCCESS; - _value = StringParser::string_to_decimal(decimal_str, length, PRECISION, SCALE, &result); + _value = StringParser::string_to_decimal<__int128>(decimal_str, length, PRECISION, SCALE, + &result); if (result == StringParser::PARSE_FAILURE) { error = E_DEC_BAD_NUM; diff --git a/be/src/runtime/decimalv2_value.h b/be/src/runtime/decimalv2_value.h index f8191eeee6..12c13bee69 100644 --- a/be/src/runtime/decimalv2_value.h +++ b/be/src/runtime/decimalv2_value.h @@ -223,6 +223,20 @@ public: return DecimalV2Value(MAX_INT_VALUE, MAX_FRAC_VALUE); } + static DecimalV2Value get_min_decimal(int precision, int scale) { + DCHECK(precision <= 27 && scale <= 9); + return DecimalV2Value( + -MAX_INT_VALUE % get_scale_base(18 - precision + scale), + MAX_FRAC_VALUE / get_scale_base(9 - scale) * get_scale_base(9 - scale)); + } + + static DecimalV2Value get_max_decimal(int precision, int scale) { + DCHECK(precision <= 27 && scale <= 9); + return DecimalV2Value( + MAX_INT_VALUE % get_scale_base(18 - precision + scale), + MAX_FRAC_VALUE / get_scale_base(9 - scale) * get_scale_base(9 - scale)); + } + static DecimalV2Value from_decimal_val(const DecimalV2Val& val) { return DecimalV2Value(val.value()); } diff --git a/be/src/runtime/primitive_type.cpp b/be/src/runtime/primitive_type.cpp index 1a69e3368b..8e90123c9a 100644 --- a/be/src/runtime/primitive_type.cpp +++ b/be/src/runtime/primitive_type.cpp @@ -43,6 +43,12 @@ PrimitiveType convert_type_to_primitive(FunctionContext::Type type) { return PrimitiveType::TYPE_DATETIME; case FunctionContext::Type::TYPE_DECIMALV2: return PrimitiveType::TYPE_DECIMALV2; + case FunctionContext::Type::TYPE_DECIMAL32: + return PrimitiveType::TYPE_DECIMAL32; + case FunctionContext::Type::TYPE_DECIMAL64: + return PrimitiveType::TYPE_DECIMAL64; + case FunctionContext::Type::TYPE_DECIMAL128: + return PrimitiveType::TYPE_DECIMAL128; case FunctionContext::Type::TYPE_BOOLEAN: return PrimitiveType::TYPE_BOOLEAN; case FunctionContext::Type::TYPE_ARRAY: @@ -90,6 +96,9 @@ bool is_enumeration_type(PrimitiveType type) { case TYPE_DATETIMEV2: case TYPE_TIMEV2: case TYPE_DECIMALV2: + case TYPE_DECIMAL32: + case TYPE_DECIMAL64: + case TYPE_DECIMAL128: case TYPE_BOOLEAN: case TYPE_ARRAY: case TYPE_HLL: @@ -147,17 +156,20 @@ int get_byte_size(PrimitiveType type) { case TYPE_INT: case TYPE_FLOAT: + case TYPE_DECIMAL32: return 4; case TYPE_BIGINT: case TYPE_DOUBLE: case TYPE_TIME: + case TYPE_DECIMAL64: return 8; case TYPE_DATETIME: case TYPE_DATE: case TYPE_LARGEINT: case TYPE_DECIMALV2: + case TYPE_DECIMAL128: return 16; case INVALID_TYPE: @@ -271,6 +283,15 @@ PrimitiveType thrift_to_type(TPrimitiveType::type ttype) { case TPrimitiveType::DECIMALV2: return TYPE_DECIMALV2; + case TPrimitiveType::DECIMAL32: + return TYPE_DECIMAL32; + + case TPrimitiveType::DECIMAL64: + return TYPE_DECIMAL64; + + case TPrimitiveType::DECIMAL128: + return TYPE_DECIMAL128; + case TPrimitiveType::CHAR: return TYPE_CHAR; @@ -353,6 +374,15 @@ TPrimitiveType::type to_thrift(PrimitiveType ptype) { case TYPE_DECIMALV2: return TPrimitiveType::DECIMALV2; + case TYPE_DECIMAL32: + return TPrimitiveType::DECIMAL32; + + case TYPE_DECIMAL64: + return TPrimitiveType::DECIMAL64; + + case TYPE_DECIMAL128: + return TPrimitiveType::DECIMAL128; + case TYPE_CHAR: return TPrimitiveType::CHAR; @@ -435,6 +465,15 @@ std::string type_to_string(PrimitiveType t) { case TYPE_DECIMALV2: return "DECIMALV2"; + case TYPE_DECIMAL32: + return "DECIMAL32"; + + case TYPE_DECIMAL64: + return "DECIMAL64"; + + case TYPE_DECIMAL128: + return "DECIMAL128"; + case TYPE_CHAR: return "CHAR"; @@ -518,6 +557,15 @@ std::string type_to_odbc_string(PrimitiveType t) { case TYPE_DECIMALV2: return "decimalv2"; + case TYPE_DECIMAL32: + return "decimal32"; + + case TYPE_DECIMAL64: + return "decimal64"; + + case TYPE_DECIMAL128: + return "decimal128"; + case TYPE_CHAR: return "char"; @@ -587,11 +635,13 @@ int get_slot_size(PrimitiveType type) { case TYPE_INT: case TYPE_DATEV2: case TYPE_FLOAT: + case TYPE_DECIMAL32: return 4; case TYPE_BIGINT: case TYPE_DOUBLE: case TYPE_TIME: + case TYPE_DECIMAL64: return 8; case TYPE_LARGEINT: @@ -605,6 +655,7 @@ int get_slot_size(PrimitiveType type) { return sizeof(DateTimeValue); case TYPE_DECIMALV2: + case TYPE_DECIMAL128: return 16; case INVALID_TYPE: diff --git a/be/src/runtime/primitive_type.h b/be/src/runtime/primitive_type.h index 35b4b3160a..bbe473e212 100644 --- a/be/src/runtime/primitive_type.h +++ b/be/src/runtime/primitive_type.h @@ -65,6 +65,9 @@ enum PrimitiveType { TYPE_DATEV2, /* 25 */ TYPE_DATETIMEV2, /* 26 */ TYPE_TIMEV2, /* 27 */ + TYPE_DECIMAL32, /* 28 */ + TYPE_DECIMAL64, /* 29 */ + TYPE_DECIMAL128, /* 30 */ }; PrimitiveType convert_type_to_primitive(FunctionContext::Type type); @@ -154,6 +157,21 @@ struct PrimitiveTypeTraits { using ColumnType = vectorized::ColumnDecimal; }; template <> +struct PrimitiveTypeTraits { + using CppType = int32_t; + using ColumnType = vectorized::ColumnDecimal; +}; +template <> +struct PrimitiveTypeTraits { + using CppType = int64_t; + using ColumnType = vectorized::ColumnDecimal; +}; +template <> +struct PrimitiveTypeTraits { + using CppType = __int128_t; + using ColumnType = vectorized::ColumnDecimal; +}; +template <> struct PrimitiveTypeTraits { using CppType = __int128_t; using ColumnType = vectorized::ColumnInt128; diff --git a/be/src/runtime/raw_value.cpp b/be/src/runtime/raw_value.cpp index 740ca3e47a..c86bf53f0c 100644 --- a/be/src/runtime/raw_value.cpp +++ b/be/src/runtime/raw_value.cpp @@ -22,10 +22,12 @@ #include +#include "common/consts.h" #include "runtime/collection_value.h" #include "runtime/large_int_value.h" #include "runtime/tuple.h" #include "util/types.h" +#include "vec/io/io_helper.h" namespace doris { @@ -92,6 +94,18 @@ void RawValue::print_value_as_bytes(const void* value, const TypeDescriptor& typ stream->write(chars, sizeof(DecimalV2Value)); break; + case TYPE_DECIMAL32: + stream->write(chars, 4); + break; + + case TYPE_DECIMAL64: + stream->write(chars, 8); + break; + + case TYPE_DECIMAL128: + stream->write(chars, 16); + break; + case TYPE_LARGEINT: stream->write(chars, sizeof(__int128)); break; @@ -174,6 +188,24 @@ void RawValue::print_value(const void* value, const TypeDescriptor& type, int sc *stream << DecimalV2Value(reinterpret_cast(value)->value).to_string(); break; + case TYPE_DECIMAL32: { + auto decimal_val = reinterpret_cast(value); + write_text(*decimal_val, type.scale, *stream); + break; + } + + case TYPE_DECIMAL64: { + auto decimal_val = reinterpret_cast(value); + write_text(*decimal_val, type.scale, *stream); + break; + } + + case TYPE_DECIMAL128: { + auto decimal_val = reinterpret_cast(value); + write_text(*decimal_val, type.scale, *stream); + break; + } + case TYPE_LARGEINT: *stream << reinterpret_cast(value)->value; break; @@ -310,6 +342,19 @@ void RawValue::write(const void* value, void* dst, const TypeDescriptor& type, M *reinterpret_cast(dst) = *reinterpret_cast(value); break; + case TYPE_DECIMAL32: + *reinterpret_cast(dst) = + *reinterpret_cast(value); + break; + case TYPE_DECIMAL64: + *reinterpret_cast(dst) = + *reinterpret_cast(value); + break; + case TYPE_DECIMAL128: + *reinterpret_cast(dst) = + *reinterpret_cast(value); + break; + case TYPE_OBJECT: case TYPE_HLL: case TYPE_QUANTILE_STATE: @@ -412,6 +457,19 @@ void RawValue::write(const void* value, const TypeDescriptor& type, void* dst, u *reinterpret_cast(dst) = *reinterpret_cast(value); break; + case TYPE_DECIMAL32: + *reinterpret_cast(dst) = + *reinterpret_cast(value); + break; + case TYPE_DECIMAL64: + *reinterpret_cast(dst) = + *reinterpret_cast(value); + break; + case TYPE_DECIMAL128: + *reinterpret_cast(dst) = + *reinterpret_cast(value); + break; + default: DCHECK(false) << "RawValue::write(): bad type: " << type.debug_string(); } @@ -510,6 +568,25 @@ int RawValue::compare(const void* v1, const void* v2, const TypeDescriptor& type return (decimal_value1 > decimal_value2) ? 1 : (decimal_value1 < decimal_value2 ? -1 : 0); } + case TYPE_DECIMAL32: { + i1 = *reinterpret_cast(v1); + i2 = *reinterpret_cast(v2); + return i1 > i2 ? 1 : (i1 < i2 ? -1 : 0); + } + + case TYPE_DECIMAL64: { + b1 = *reinterpret_cast(v1); + b2 = *reinterpret_cast(v2); + return b1 > b2 ? 1 : (b1 < b2 ? -1 : 0); + } + + case TYPE_DECIMAL128: { + __int128 large_int_value1 = reinterpret_cast(v1)->value; + __int128 large_int_value2 = reinterpret_cast(v2)->value; + return large_int_value1 > large_int_value2 ? 1 + : (large_int_value1 < large_int_value2 ? -1 : 0); + } + case TYPE_LARGEINT: { __int128 large_int_value1 = reinterpret_cast(v1)->value; __int128 large_int_value2 = reinterpret_cast(v2)->value; diff --git a/be/src/runtime/raw_value.h b/be/src/runtime/raw_value.h index 424ede11e8..a83a64bb1e 100644 --- a/be/src/runtime/raw_value.h +++ b/be/src/runtime/raw_value.h @@ -22,6 +22,7 @@ #include +#include "common/consts.h" #include "common/logging.h" #include "runtime/string_value.h" #include "runtime/types.h" @@ -168,6 +169,14 @@ inline bool RawValue::lt(const void* v1, const void* v2, const TypeDescriptor& t return reinterpret_cast(v1)->value < reinterpret_cast(v2)->value; + case TYPE_DECIMAL32: + return *reinterpret_cast(v1) < *reinterpret_cast(v2); + case TYPE_DECIMAL64: + return *reinterpret_cast(v1) < *reinterpret_cast(v2); + case TYPE_DECIMAL128: + return reinterpret_cast(v1)->value < + reinterpret_cast(v2)->value; + case TYPE_LARGEINT: return reinterpret_cast(v1)->value < reinterpret_cast(v2)->value; @@ -224,6 +233,14 @@ inline bool RawValue::eq(const void* v1, const void* v2, const TypeDescriptor& t return reinterpret_cast(v1)->value == reinterpret_cast(v2)->value; + case TYPE_DECIMAL32: + return *reinterpret_cast(v1) == *reinterpret_cast(v2); + case TYPE_DECIMAL64: + return *reinterpret_cast(v1) == *reinterpret_cast(v2); + case TYPE_DECIMAL128: + return reinterpret_cast(v1)->value == + reinterpret_cast(v2)->value; + case TYPE_LARGEINT: return reinterpret_cast(v1)->value == reinterpret_cast(v2)->value; @@ -286,6 +303,12 @@ inline uint32_t RawValue::get_hash_value(const void* v, const PrimitiveType& typ case TYPE_DECIMALV2: return HashUtil::hash(v, 16, seed); + case TYPE_DECIMAL32: + return HashUtil::hash(v, 4, seed); + case TYPE_DECIMAL64: + return HashUtil::hash(v, 8, seed); + case TYPE_DECIMAL128: + return HashUtil::hash(v, 16, seed); case TYPE_LARGEINT: return HashUtil::hash(v, 16, seed); @@ -346,6 +369,12 @@ inline uint32_t RawValue::get_hash_value_fvn(const void* v, const PrimitiveType& case TYPE_DECIMALV2: return HashUtil::fnv_hash(v, 16, seed); + case TYPE_DECIMAL32: + return HashUtil::fnv_hash(v, 4, seed); + case TYPE_DECIMAL64: + return HashUtil::fnv_hash(v, 8, seed); + case TYPE_DECIMAL128: + return HashUtil::fnv_hash(v, 16, seed); case TYPE_LARGEINT: return HashUtil::fnv_hash(v, 16, seed); @@ -420,6 +449,13 @@ inline uint32_t RawValue::zlib_crc32(const void* v, const TypeDescriptor& type, seed = HashUtil::zlib_crc_hash(&int_val, sizeof(int_val), seed); return HashUtil::zlib_crc_hash(&frac_val, sizeof(frac_val), seed); } + + case TYPE_DECIMAL32: + return HashUtil::zlib_crc_hash(v, 4, seed); + case TYPE_DECIMAL64: + return HashUtil::zlib_crc_hash(v, 8, seed); + case TYPE_DECIMAL128: + return HashUtil::zlib_crc_hash(v, 16, seed); default: DCHECK(false) << "invalid type: " << type; return 0; @@ -480,6 +516,12 @@ inline uint32_t RawValue::zlib_crc32(const void* v, size_t len, const TypeDescri seed = HashUtil::zlib_crc_hash(&int_val, sizeof(int_val), seed); return HashUtil::zlib_crc_hash(&frac_val, sizeof(frac_val), seed); } + case TYPE_DECIMAL32: + return HashUtil::zlib_crc_hash(v, 4, seed); + case TYPE_DECIMAL64: + return HashUtil::zlib_crc_hash(v, 8, seed); + case TYPE_DECIMAL128: + return HashUtil::zlib_crc_hash(v, 16, seed); default: DCHECK(false) << "invalid type: " << type; return 0; diff --git a/be/src/runtime/types.cpp b/be/src/runtime/types.cpp index c53febfc40..c5f3eb89b9 100644 --- a/be/src/runtime/types.cpp +++ b/be/src/runtime/types.cpp @@ -37,7 +37,8 @@ TypeDescriptor::TypeDescriptor(const std::vector& types, int* idx) if (type == TYPE_CHAR || type == TYPE_VARCHAR || type == TYPE_HLL) { DCHECK(scalar_type.__isset.len); len = scalar_type.len; - } else if (type == TYPE_DECIMALV2) { + } else if (type == TYPE_DECIMALV2 || type == TYPE_DECIMAL32 || type == TYPE_DECIMAL64 || + type == TYPE_DECIMAL128) { DCHECK(scalar_type.__isset.precision); DCHECK(scalar_type.__isset.scale); precision = scalar_type.precision; @@ -113,7 +114,8 @@ void TypeDescriptor::to_thrift(TTypeDesc* thrift_type) const { if (type == TYPE_CHAR || type == TYPE_VARCHAR || type == TYPE_HLL) { // DCHECK_NE(len, -1); scalar_type.__set_len(len); - } else if (type == TYPE_DECIMALV2) { + } else if (type == TYPE_DECIMALV2 || type == TYPE_DECIMAL32 || type == TYPE_DECIMAL64 || + type == TYPE_DECIMAL128) { DCHECK_NE(precision, -1); DCHECK_NE(scale, -1); scalar_type.__set_precision(precision); @@ -131,7 +133,8 @@ void TypeDescriptor::to_protobuf(PTypeDesc* ptype) const { scalar_type->set_type(doris::to_thrift(type)); if (type == TYPE_CHAR || type == TYPE_VARCHAR || type == TYPE_HLL) { scalar_type->set_len(len); - } else if (type == TYPE_DECIMALV2) { + } else if (type == TYPE_DECIMALV2 || type == TYPE_DECIMAL32 || type == TYPE_DECIMAL64 || + type == TYPE_DECIMAL128) { DCHECK_NE(precision, -1); DCHECK_NE(scale, -1); scalar_type->set_precision(precision); @@ -158,7 +161,8 @@ TypeDescriptor::TypeDescriptor(const google::protobuf::RepeatedPtrField node_type; node_type.emplace_back(); TScalarType scalarType; scalarType.__set_type(to_thrift(type)); scalarType.__set_len(-1); - scalarType.__set_precision(-1); - scalarType.__set_scale(-1); + scalarType.__set_precision(precision); + scalarType.__set_scale(scale); node_type.back().__set_scalar_type(scalarType); type_desc.__set_types(node_type); return type_desc; diff --git a/be/src/runtime/types.h b/be/src/runtime/types.h index 20cdc4663b..1eb392ab02 100644 --- a/be/src/runtime/types.h +++ b/be/src/runtime/types.h @@ -40,22 +40,22 @@ struct TypeDescriptor { PrimitiveType type; /// Only set if type == TYPE_CHAR or type == TYPE_VARCHAR int len; - static const int MAX_VARCHAR_LENGTH = OLAP_VARCHAR_MAX_LENGTH; - static const int MAX_CHAR_LENGTH = 255; - static const int MAX_CHAR_INLINE_LENGTH = 128; + static constexpr int MAX_VARCHAR_LENGTH = OLAP_VARCHAR_MAX_LENGTH; + static constexpr int MAX_CHAR_LENGTH = 255; + static constexpr int MAX_CHAR_INLINE_LENGTH = 128; /// Only set if type == TYPE_DECIMAL int precision; int scale; /// Must be kept in sync with FE's max precision/scale. - static const int MAX_PRECISION = 38; - static const int MAX_SCALE = MAX_PRECISION; + static constexpr int MAX_PRECISION = 38; + static constexpr int MAX_SCALE = MAX_PRECISION; /// The maximum precision representable by a 4-byte decimal (Decimal4Value) - static const int MAX_DECIMAL4_PRECISION = 9; + static constexpr int MAX_DECIMAL4_PRECISION = 9; /// The maximum precision representable by a 8-byte decimal (Decimal8Value) - static const int MAX_DECIMAL8_PRECISION = 18; + static constexpr int MAX_DECIMAL8_PRECISION = 18; // Empty for scalar types std::vector children; @@ -218,6 +218,6 @@ private: std::ostream& operator<<(std::ostream& os, const TypeDescriptor& type); -TTypeDesc create_type_desc(PrimitiveType type); +TTypeDesc create_type_desc(PrimitiveType type, int precision = 0, int scale = 0); } // namespace doris diff --git a/be/src/udf/udf.h b/be/src/udf/udf.h index 2a5296f477..c74d566d42 100644 --- a/be/src/udf/udf.h +++ b/be/src/udf/udf.h @@ -90,7 +90,10 @@ public: TYPE_QUANTILE_STATE, TYPE_DATEV2, TYPE_DATETIMEV2, - TYPE_TIMEV2 + TYPE_TIMEV2, + TYPE_DECIMAL32, + TYPE_DECIMAL64, + TYPE_DECIMAL128 }; struct TypeDesc { @@ -524,6 +527,85 @@ struct BigIntVal : public AnyVal { bool operator!=(const BigIntVal& other) const { return !(*this == other); } }; +struct Decimal32Val : public AnyVal { + int32_t val; + + Decimal32Val() : val(0) {} + Decimal32Val(int32_t val) : val(val) {} + + static Decimal32Val null() { + Decimal32Val result; + result.is_null = true; + return result; + } + + bool operator==(const Decimal32Val& other) const { + if (is_null && other.is_null) { + return true; + } + + if (is_null || other.is_null) { + return false; + } + + return val == other.val; + } + bool operator!=(const Decimal32Val& other) const { return !(*this == other); } +}; + +struct Decimal64Val : public AnyVal { + int64_t val; + + Decimal64Val() : val(0) {} + Decimal64Val(int64_t val) : val(val) {} + + static Decimal64Val null() { + Decimal64Val result; + result.is_null = true; + return result; + } + + bool operator==(const Decimal64Val& other) const { + if (is_null && other.is_null) { + return true; + } + + if (is_null || other.is_null) { + return false; + } + + return val == other.val; + } + bool operator!=(const Decimal64Val& other) const { return !(*this == other); } +}; + +struct Decimal128Val : public AnyVal { + __int128 val; + + Decimal128Val() : val(0) {} + + Decimal128Val(__int128 large_value) : val(large_value) {} + + static Decimal128Val null() { + Decimal128Val result; + result.is_null = true; + return result; + } + + bool operator==(const Decimal128Val& other) const { + if (is_null && other.is_null) { + return true; + } + + if (is_null || other.is_null) { + return false; + } + + return val == other.val; + } + bool operator!=(const Decimal128Val& other) const { return !(*this == other); } +}; + struct FloatVal : public AnyVal { float val; @@ -780,3 +862,6 @@ using doris_udf::DateTimeVal; using doris_udf::HllVal; using doris_udf::FunctionContext; using doris_udf::CollectionVal; +using doris_udf::Decimal32Val; +using doris_udf::Decimal64Val; +using doris_udf::Decimal128Val; diff --git a/be/src/util/string_parser.cpp b/be/src/util/string_parser.cpp index c3a59e47fe..16e24e6154 100644 --- a/be/src/util/string_parser.cpp +++ b/be/src/util/string_parser.cpp @@ -21,6 +21,7 @@ #include "string_parser.hpp" #include "runtime/large_int_value.h" +#include "vec/common/int_exp.h" namespace doris { @@ -29,4 +30,19 @@ __int128 StringParser::numeric_limits<__int128>(bool negative) { return negative ? MIN_INT128 : MAX_INT128; } +template <> +int32_t StringParser::get_scale_multiplier(int scale) { + return common::exp10_i32(scale); +} + +template <> +int64_t StringParser::get_scale_multiplier(int scale) { + return common::exp10_i64(scale); +} + +template <> +__int128 StringParser::get_scale_multiplier(int scale) { + return common::exp10_i128(scale); +} + } // namespace doris diff --git a/be/src/util/string_parser.hpp b/be/src/util/string_parser.hpp index 82549d186c..805b93a61f 100644 --- a/be/src/util/string_parser.hpp +++ b/be/src/util/string_parser.hpp @@ -66,7 +66,8 @@ public: template static T numeric_limits(bool negative); - static inline __int128 get_scale_multiplier(int scale); + template + static T get_scale_multiplier(int scale); // This is considerably faster than glibc's implementation (25x). // In the case of overflow, the max/min value for the data type will be returned. @@ -130,8 +131,9 @@ public: return string_to_bool_internal(s + i, len - i, result); } - static inline __int128 string_to_decimal(const char* s, int len, int type_precision, - int type_scale, ParseResult* result); + template + static inline T string_to_decimal(const char* s, int len, int type_precision, int type_scale, + ParseResult* result); template static Status split_string_to_map(const std::string& base, const T element_separator, @@ -611,56 +613,9 @@ inline int StringParser::StringParseTraits<__int128>::max_ascii_len() { return 39; } -inline __int128 StringParser::get_scale_multiplier(int scale) { - DCHECK_GE(scale, 0); - static const __int128 values[] = { - static_cast<__int128>(1ll), - static_cast<__int128>(10ll), - static_cast<__int128>(100ll), - static_cast<__int128>(1000ll), - static_cast<__int128>(10000ll), - static_cast<__int128>(100000ll), - static_cast<__int128>(1000000ll), - static_cast<__int128>(10000000ll), - static_cast<__int128>(100000000ll), - static_cast<__int128>(1000000000ll), - static_cast<__int128>(10000000000ll), - static_cast<__int128>(100000000000ll), - static_cast<__int128>(1000000000000ll), - static_cast<__int128>(10000000000000ll), - static_cast<__int128>(100000000000000ll), - static_cast<__int128>(1000000000000000ll), - static_cast<__int128>(10000000000000000ll), - static_cast<__int128>(100000000000000000ll), - static_cast<__int128>(1000000000000000000ll), - static_cast<__int128>(1000000000000000000ll) * 10ll, - static_cast<__int128>(1000000000000000000ll) * 100ll, - static_cast<__int128>(1000000000000000000ll) * 1000ll, - static_cast<__int128>(1000000000000000000ll) * 10000ll, - static_cast<__int128>(1000000000000000000ll) * 100000ll, - static_cast<__int128>(1000000000000000000ll) * 1000000ll, - static_cast<__int128>(1000000000000000000ll) * 10000000ll, - static_cast<__int128>(1000000000000000000ll) * 100000000ll, - static_cast<__int128>(1000000000000000000ll) * 1000000000ll, - static_cast<__int128>(1000000000000000000ll) * 10000000000ll, - static_cast<__int128>(1000000000000000000ll) * 100000000000ll, - static_cast<__int128>(1000000000000000000ll) * 1000000000000ll, - static_cast<__int128>(1000000000000000000ll) * 10000000000000ll, - static_cast<__int128>(1000000000000000000ll) * 100000000000000ll, - static_cast<__int128>(1000000000000000000ll) * 1000000000000000ll, - static_cast<__int128>(1000000000000000000ll) * 10000000000000000ll, - static_cast<__int128>(1000000000000000000ll) * 100000000000000000ll, - static_cast<__int128>(1000000000000000000ll) * 100000000000000000ll * 10ll, - static_cast<__int128>(1000000000000000000ll) * 100000000000000000ll * 100ll, - static_cast<__int128>(1000000000000000000ll) * 100000000000000000ll * 1000ll}; - if (scale >= 0 && scale < 39) { - return values[scale]; - } - return -1; // Overflow -} - -inline __int128 StringParser::string_to_decimal(const char* s, int len, int type_precision, - int type_scale, ParseResult* result) { +template +inline T StringParser::string_to_decimal(const char* s, int len, int type_precision, int type_scale, + ParseResult* result) { // Special cases: // 1) '' == Fail, an empty string fails to parse. // 2) ' # ' == #, leading and trailing white space is ignored. @@ -715,7 +670,7 @@ inline __int128 StringParser::string_to_decimal(const char* s, int len, int type int precision = 0; bool found_exponent = false; int8_t exponent = 0; - __int128 value = 0; + T value = 0; for (int i = 0; i < len; ++i) { const char& c = s[i]; if (LIKELY('0' <= c && c <= '9')) { @@ -748,7 +703,7 @@ inline __int128 StringParser::string_to_decimal(const char* s, int len, int type return 0; } *result = StringParser::PARSE_SUCCESS; - value *= get_scale_multiplier(type_scale - scale); + value *= get_scale_multiplier(type_scale - scale); return is_negative ? -value : value; } } @@ -759,7 +714,7 @@ inline __int128 StringParser::string_to_decimal(const char* s, int len, int type // Ex: 0.1e3 (which at this point would have precision == 1 and scale == 1), the // scale must be set to 0 and the value set to 100 which means a precision of 3. precision += exponent - scale; - value *= get_scale_multiplier(exponent - scale); + value *= get_scale_multiplier(exponent - scale); scale = 0; } else { // Ex: 100e-4, the scale must be set to 4 but no adjustment to the value is needed, @@ -785,10 +740,10 @@ inline __int128 StringParser::string_to_decimal(const char* s, int len, int type shift -= truncated_digit_count; } if (shift > 0) { - __int128 divisor = get_scale_multiplier(shift); + T divisor = get_scale_multiplier(shift); if (LIKELY(divisor >= 0)) { value /= divisor; - __int128 remainder = value % divisor; + T remainder = value % divisor; if ((remainder > 0 ? remainder : -remainder) >= (divisor >> 1)) { value += 1; } @@ -803,7 +758,7 @@ inline __int128 StringParser::string_to_decimal(const char* s, int len, int type } if (type_scale > scale) { - value *= get_scale_multiplier(type_scale - scale); + value *= get_scale_multiplier(type_scale - scale); } return is_negative ? -value : value; diff --git a/be/src/util/symbols_util.cpp b/be/src/util/symbols_util.cpp index 75492ec540..c5ff9b8b37 100644 --- a/be/src/util/symbols_util.cpp +++ b/be/src/util/symbols_util.cpp @@ -164,6 +164,15 @@ static void append_any_val_type(int namespace_id, const TypeDescriptor& type, case TYPE_DECIMALV2: append_mangled_token("DecimalV2Val", s); break; + case TYPE_DECIMAL32: + append_mangled_token("Decimal32Val", s); + break; + case TYPE_DECIMAL64: + append_mangled_token("Decimal64Val", s); + break; + case TYPE_DECIMAL128: + append_mangled_token("Decimal128Val", s); + break; default: DCHECK(false) << "NYI: " << type.debug_string(); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.h b/be/src/vec/aggregate_functions/aggregate_function_avg.h index 03bb87f621..eab938e4c1 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg.h +++ b/be/src/vec/aggregate_functions/aggregate_function_avg.h @@ -46,13 +46,15 @@ struct AggregateFunctionAvgData { // null is handled in AggregationNode::_get_without_key_result return static_cast(sum); } - // to keep the same result with row vesion; see AggregateFunctions::decimalv2_avg_get_value - if constexpr (std::is_same_v && std::is_same_v) { - DecimalV2Value decimal_val_count(count, 0); - DecimalV2Value decimal_val_sum(static_cast(sum)); - DecimalV2Value cal_ret = decimal_val_sum / decimal_val_count; - Decimal128 ret(cal_ret.value()); - return ret; + if (!config::enable_decimalv3) { + // to keep the same result with row vesion; see AggregateFunctions::decimalv2_avg_get_value + if constexpr (std::is_same_v && std::is_same_v) { + DecimalV2Value decimal_val_count(count, 0); + DecimalV2Value decimal_val_sum(static_cast(sum)); + DecimalV2Value cal_ret = decimal_val_sum / decimal_val_count; + Decimal128 ret(cal_ret.value()); + return ret; + } } return static_cast(sum) / count; } diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp index be4ce7ef0d..43b6717c04 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp @@ -54,8 +54,16 @@ static IAggregateFunction* create_aggregate_function_single_value(const String& return new AggregateFunctionTemplate>, false>( argument_type); } + if (which.idx == TypeIndex::Decimal32) { + return new AggregateFunctionTemplate>, false>( + argument_type); + } + if (which.idx == TypeIndex::Decimal64) { + return new AggregateFunctionTemplate>, false>( + argument_type); + } if (which.idx == TypeIndex::Decimal128) { - return new AggregateFunctionTemplate>, false>( + return new AggregateFunctionTemplate>, false>( argument_type); } return nullptr; @@ -86,4 +94,4 @@ void register_aggregate_function_minmax(AggregateFunctionSimpleFactory& factory) factory.register_function("min", create_aggregate_function_min); } -} // namespace doris::vectorized \ No newline at end of file +} // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.h b/be/src/vec/aggregate_functions/aggregate_function_min_max.h index e8371c11bb..ee91c0bbb9 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max.h +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.h @@ -124,25 +124,25 @@ public: } }; -/// For numeric values. -template <> -struct SingleValueDataFixed { +/// For decimal values. +template +struct SingleValueDataDecimal { private: - using Self = SingleValueDataFixed; + using Self = SingleValueDataDecimal; + using Type = typename NativeType::Type; bool has_value = false; /// We need to remember if at least one value has been passed. This is necessary for AggregateFunctionIf. - int128_t value; + Type value; public: bool has() const { return has_value; } void insert_result_into(IColumn& to) const { if (has()) { - DecimalV2Value decimal(value); - assert_cast&>(to).insert_data((const char*)&decimal, 0); + assert_cast&>(to).insert_data((const char*)&value, 0); } else { - assert_cast&>(to).insert_default(); + assert_cast&>(to).insert_default(); } } @@ -168,7 +168,7 @@ public: void change(const IColumn& column, size_t row_num, Arena*) { has_value = true; - value = assert_cast&>(column).get_data()[row_num]; + value = assert_cast&>(column).get_data()[row_num]; } /// Assuming to.has() @@ -178,8 +178,7 @@ public: } bool change_if_less(const IColumn& column, size_t row_num, Arena* arena) { - if (!has() || - assert_cast&>(column).get_data()[row_num] < value) { + if (!has() || assert_cast&>(column).get_data()[row_num] < value) { change(column, row_num, arena); return true; } else { @@ -197,8 +196,7 @@ public: } bool change_if_greater(const IColumn& column, size_t row_num, Arena* arena) { - if (!has() || - assert_cast&>(column).get_data()[row_num] > value) { + if (!has() || assert_cast&>(column).get_data()[row_num] > value) { change(column, row_num, arena); return true; } else { @@ -218,8 +216,7 @@ public: bool is_equal_to(const Self& to) const { return has() && to.value == value; } bool is_equal_to(const IColumn& column, size_t row_num) const { - return has() && - assert_cast&>(column).get_data()[row_num] == value; + return has() && assert_cast&>(column).get_data()[row_num] == value; } }; @@ -471,4 +468,4 @@ AggregateFunctionPtr create_aggregate_function_min(const std::string& name, const Array& parameters, const bool result_is_nullable); -} // namespace doris::vectorized \ No newline at end of file +} // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp index 8b834f4202..506c037130 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp @@ -51,8 +51,16 @@ static IAggregateFunction* create_aggregate_function_min_max_by_impl( return new AggregateFunctionTemplate>, false>( value_arg_type, key_arg_type); } + if (which.idx == TypeIndex::Decimal32) { + return new AggregateFunctionTemplate>, false>( + value_arg_type, key_arg_type); + } + if (which.idx == TypeIndex::Decimal64) { + return new AggregateFunctionTemplate>, false>( + value_arg_type, key_arg_type); + } if (which.idx == TypeIndex::Decimal128) { - return new AggregateFunctionTemplate>, false>( + return new AggregateFunctionTemplate>, false>( value_arg_type, key_arg_type); } return nullptr; @@ -91,11 +99,26 @@ static IAggregateFunction* create_aggregate_function_min_max_by(const String& na SingleValueDataFixed>( argument_types); } - if (which.idx == TypeIndex::Decimal128) { + if (which.idx == TypeIndex::Decimal128 && !config::enable_decimalv3) { return create_aggregate_function_min_max_by_impl>( argument_types); } + if (which.idx == TypeIndex::Decimal32) { + return create_aggregate_function_min_max_by_impl>( + argument_types); + } + if (which.idx == TypeIndex::Decimal64) { + return create_aggregate_function_min_max_by_impl>( + argument_types); + } + if (which.idx == TypeIndex::Decimal128) { + return create_aggregate_function_min_max_by_impl>( + argument_types); + } return nullptr; } diff --git a/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp b/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp index 31f455655e..357b09a7a2 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp @@ -43,9 +43,20 @@ static IAggregateFunction* create_function_single_value(const String& name, FOR_NUMERIC_TYPES(DISPATCH) #undef DISPATCH - if (which.is_decimal()) { - return new AggregateFunctionTemplate>>, - is_nullable>(argument_types); + if (which.is_decimal32()) { + return new AggregateFunctionTemplate< + NameData>>, is_nullable>( + argument_types); + } + if (which.is_decimal64()) { + return new AggregateFunctionTemplate< + NameData>>, is_nullable>( + argument_types); + } + if (which.is_decimal128()) { + return new AggregateFunctionTemplate< + NameData>>, is_nullable>( + argument_types); } DCHECK(false) << "with unknowed type, failed in create_aggregate_function_stddev_variance"; return nullptr; @@ -112,4 +123,4 @@ void register_aggregate_function_stddev_variance_samp(AggregateFunctionSimpleFac factory.register_function("stddev_samp", create_aggregate_function_stddev_samp, true); } -} // namespace doris::vectorized \ No newline at end of file +} // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_stddev.h b/be/src/vec/aggregate_functions/aggregate_function_stddev.h index 3531e8db49..71755f5534 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_stddev.h +++ b/be/src/vec/aggregate_functions/aggregate_function_stddev.h @@ -99,7 +99,7 @@ struct BaseData { int64_t count; }; -template +template struct BaseDatadecimal { BaseDatadecimal() : mean(0), m2(0), count(0) {} virtual ~BaseDatadecimal() = default; @@ -162,9 +162,18 @@ struct BaseDatadecimal { } void add(const IColumn* column, size_t row_num) { - DecimalV2Value source_data = DecimalV2Value(); - const auto& sources = static_cast&>(*column); - source_data = (DecimalV2Value)sources.get_data()[row_num]; + const auto& sources = static_cast&>(*column); + Field field = sources[row_num]; + auto decimal_field = field.template get>(); + int128_t value; + if (decimal_field.get_scale() > DecimalV2Value::SCALE) { + value = static_cast(decimal_field.get_value()) / + (decimal_field.get_scale_multiplier() / DecimalV2Value::ONE_BILLION); + } else { + value = static_cast(decimal_field.get_value()) * + (DecimalV2Value::ONE_BILLION / decimal_field.get_scale_multiplier()); + } + DecimalV2Value source_data = DecimalV2Value(value); DecimalV2Value new_count = DecimalV2Value(); new_count.assign_from_double(count); diff --git a/be/src/vec/columns/column.h b/be/src/vec/columns/column.h index a2a015721a..b6c8f3e67c 100644 --- a/be/src/vec/columns/column.h +++ b/be/src/vec/columns/column.h @@ -497,15 +497,18 @@ public: virtual bool is_date_type() const { return is_date; } virtual bool is_date_v2_type() const { return is_date_v2; } virtual bool is_datetime_type() const { return is_date_time; } + virtual bool is_decimalv2_type() const { return is_decimalv2; } virtual void set_date_type() { is_date = true; } virtual void set_date_v2_type() { is_date_v2 = true; } virtual void set_datetime_type() { is_date_time = true; } + virtual void set_decimalv2_type() { is_decimalv2 = true; } // todo(wb): a temporary implemention, need re-abstract here bool is_date = false; bool is_date_time = false; bool is_date_v2 = false; + bool is_decimalv2 = false; protected: /// Template is to devirtualize calls to insert_from method. diff --git a/be/src/vec/columns/column_decimal.cpp b/be/src/vec/columns/column_decimal.cpp index c50b7938be..315c1103c0 100644 --- a/be/src/vec/columns/column_decimal.cpp +++ b/be/src/vec/columns/column_decimal.cpp @@ -20,6 +20,7 @@ #include "vec/columns/column_decimal.h" +#include "common/config.h" #include "util/simd/bits.h" #include "vec/common/arena.h" #include "vec/common/assert_cast.h" @@ -156,6 +157,23 @@ void ColumnDecimal::insert_data(const char* src, size_t /*length*/) { data.emplace_back(tmp); } +template +void ColumnDecimal::insert_many_fix_len_data(const char* data_ptr, size_t num) { + if (this->is_decimalv2_type()) { + for (int i = 0; i < num; i++) { + const char* cur_ptr = data_ptr + sizeof(decimal12_t) * i; + int64_t int_value = *(int64_t*)(cur_ptr); + int32_t frac_value = *(int32_t*)(cur_ptr + sizeof(int64_t)); + DecimalV2Value decimal_val(int_value, frac_value); + this->insert_data(reinterpret_cast(&decimal_val), 0); + } + } else { + size_t old_size = data.size(); + data.resize(old_size + num); + memcpy(data.data() + old_size, data_ptr, num * sizeof(T)); + } +} + template void ColumnDecimal::insert_range_from(const IColumn& src, size_t start, size_t length) { const ColumnDecimal& src_vec = assert_cast(src); @@ -284,6 +302,21 @@ void ColumnDecimal::get_extremes(Field& min, Field& max) const { max = NearestFieldType(cur_max, scale); } +template <> +Decimal32 ColumnDecimal::get_scale_multiplier() const { + return common::exp10_i32(scale); +} + +template <> +Decimal64 ColumnDecimal::get_scale_multiplier() const { + return common::exp10_i64(scale); +} + +template <> +Decimal128 ColumnDecimal::get_scale_multiplier() const { + return common::exp10_i128(scale); +} + template class ColumnDecimal; template class ColumnDecimal; template class ColumnDecimal; diff --git a/be/src/vec/columns/column_decimal.h b/be/src/vec/columns/column_decimal.h index bdca8263d3..d28e7ae469 100644 --- a/be/src/vec/columns/column_decimal.h +++ b/be/src/vec/columns/column_decimal.h @@ -113,16 +113,7 @@ public: } } - void insert_many_fix_len_data(const char* data_ptr, size_t num) override { - for (int i = 0; i < num; i++) { - const char* cur_ptr = data_ptr + sizeof(decimal12_t) * i; - int64_t int_value = *(int64_t*)(cur_ptr); - int32_t frac_value = *(int32_t*)(cur_ptr + sizeof(int64_t)); - DecimalV2Value decimal_val(int_value, frac_value); - this->insert_data(reinterpret_cast(&decimal_val), 0); - } - } - + void insert_many_fix_len_data(const char* data_ptr, size_t num) override; void insert_data(const char* pos, size_t /*length*/) override; void insert_default() override { data.push_back(T()); } void insert(const Field& x) override { @@ -217,6 +208,10 @@ public: UInt32 get_scale() const { return scale; } + T get_scale_multiplier() const; + T get_whole_part(size_t n) const { return data[n] / get_scale_multiplier(); } + T get_fractional_part(size_t n) const { return data[n] % get_scale_multiplier(); } + protected: Container data; UInt32 scale; diff --git a/be/src/vec/columns/column_nullable.h b/be/src/vec/columns/column_nullable.h index 34d87bd0b7..168ee94d10 100644 --- a/be/src/vec/columns/column_nullable.h +++ b/be/src/vec/columns/column_nullable.h @@ -173,9 +173,11 @@ public: bool is_date_type() const override { return get_nested_column().is_date_type(); } bool is_date_v2_type() const override { return get_nested_column().is_date_v2_type(); } bool is_datetime_type() const override { return get_nested_column().is_datetime_type(); } + bool is_decimalv2_type() const { return get_nested_column().is_decimalv2_type(); } void set_date_type() override { get_nested_column().set_date_type(); } void set_date_v2_type() override { get_nested_column().set_date_v2_type(); } void set_datetime_type() override { get_nested_column().set_datetime_type(); } + void set_decimalv2_type() override { get_nested_column().set_decimalv2_type(); } bool is_nullable() const override { return true; } bool is_bitmap() const override { return get_nested_column().is_bitmap(); } diff --git a/be/src/vec/data_types/data_type_decimal.cpp b/be/src/vec/data_types/data_type_decimal.cpp index d0eae375e9..7a50eb487e 100644 --- a/be/src/vec/data_types/data_type_decimal.cpp +++ b/be/src/vec/data_types/data_type_decimal.cpp @@ -55,6 +55,16 @@ template void DataTypeDecimal::to_string(const IColumn& column, size_t row_num, BufferWritable& ostr) const { // TODO: Reduce the copy in std::string mem to ostr, like DataTypeNumber + if (config::enable_decimalv3) { + T value = assert_cast(*column.convert_to_full_column_if_const().get()) + .get_data()[row_num]; + std::ostringstream buf; + write_text(value, scale, buf); + std::string str = buf.str(); + ostr.write(str.data(), str.size()); + return; + } + DecimalV2Value value = (DecimalV2Value)assert_cast( *column.convert_to_full_column_if_const().get()) .get_data()[row_num]; @@ -66,7 +76,7 @@ template Status DataTypeDecimal::from_string(ReadBuffer& rb, IColumn* column) const { auto& column_data = static_cast(*column).get_data(); T val = 0; - if (!read_decimal_text_impl(val, rb)) { + if (!read_decimal_text_impl(val, rb, precision, scale)) { return Status::InvalidArgument("parse decimal fail, string: '{}'", std::string(rb.position(), rb.count()).c_str()); } @@ -127,7 +137,24 @@ DataTypePtr DataTypeDecimal::promote_numeric_type() const { template MutableColumnPtr DataTypeDecimal::create_column() const { - return ColumnType::create(0, scale); + if (config::enable_decimalv3) { + return ColumnType::create(0, scale); + } else { + auto col = ColumnDecimal128::create(0, scale); + col->set_decimalv2_type(); + return col; + } +} + +template +T DataTypeDecimal::parse_from_string(const std::string& str) const { + StringParser::ParseResult result = StringParser::PARSE_SUCCESS; + T value = StringParser::string_to_decimal<__int128>(str.c_str(), str.size(), precision, scale, + &result); + if (result != StringParser::PARSE_SUCCESS) { + LOG(FATAL) << "Failed to parse string of decimal"; + } + return value; } DataTypePtr create_decimal(UInt64 precision_value, UInt64 scale_value) { @@ -148,18 +175,86 @@ DataTypePtr create_decimal(UInt64 precision_value, UInt64 scale_value) { } template <> -Decimal32 DataTypeDecimal::get_scale_multiplier(UInt32 scale_) { - return common::exp10_i32(scale_); +Decimal32 DataTypeDecimal::get_scale_multiplier(UInt32 scale) { + return common::exp10_i32(scale); } template <> -Decimal64 DataTypeDecimal::get_scale_multiplier(UInt32 scale_) { - return common::exp10_i64(scale_); +Decimal64 DataTypeDecimal::get_scale_multiplier(UInt32 scale) { + return common::exp10_i64(scale); } template <> -Decimal128 DataTypeDecimal::get_scale_multiplier(UInt32 scale_) { - return common::exp10_i128(scale_); +Decimal128 DataTypeDecimal::get_scale_multiplier(UInt32 scale) { + return common::exp10_i128(scale); +} + +template +void convert_to_decimal(T* from_value, T* to_value, int32_t from_scale, int32_t to_scale, + bool* loss_accuracy) { + if (from_scale == to_scale) { + *to_value = *from_value; + return; + } + if (from_scale > to_scale) { + *to_value = + (*from_value) / static_cast(DataTypeDecimal>::get_scale_multiplier( + from_scale - to_scale)); + *loss_accuracy = + ((*from_value) % static_cast(DataTypeDecimal>::get_scale_multiplier( + from_scale - to_scale))) != 0; + } else { + if (common::mul_overflow(*from_value, + static_cast(DataTypeDecimal>::get_scale_multiplier( + to_scale - from_scale)), + *to_value)) { + LOG(FATAL) << "Decimal convert overflow"; + } + } +} + +template +typename T::NativeType max_decimal_value(UInt32 precision) { + return 0; +} +template <> +Int32 max_decimal_value(UInt32 precision) { + return 999999999 / DataTypeDecimal::get_scale_multiplier( + (UInt32)(max_decimal_precision() - precision)); +} +template <> +Int64 max_decimal_value(UInt32 precision) { + return 999999999999999999 / DataTypeDecimal::get_scale_multiplier( + (UInt64)max_decimal_precision() - precision); +} +template <> +Int128 max_decimal_value(UInt32 precision) { + return (static_cast(999999999999999999ll) * 100000000000000000ll * 1000ll + + static_cast(99999999999999999ll) * 1000ll + 999ll) / + DataTypeDecimal::get_scale_multiplier( + (UInt64)max_decimal_precision() - precision); +} + +template +typename T::NativeType min_decimal_value(UInt32 precision) { + return 0; +} +template <> +Int32 min_decimal_value(UInt32 precision) { + return -999999999 / DataTypeDecimal::get_scale_multiplier( + (UInt32)max_decimal_precision() - precision); +} +template <> +Int64 min_decimal_value(UInt32 precision) { + return -999999999999999999 / DataTypeDecimal::get_scale_multiplier( + (UInt64)max_decimal_precision() - precision); +} +template <> +Int128 min_decimal_value(UInt32 precision) { + return -(static_cast(999999999999999999ll) * 100000000000000000ll * 1000ll + + static_cast(99999999999999999ll) * 1000ll + 999ll) / + DataTypeDecimal::get_scale_multiplier( + (UInt64)max_decimal_precision() - precision); } /// Explicit template instantiations. diff --git a/be/src/vec/data_types/data_type_decimal.h b/be/src/vec/data_types/data_type_decimal.h index 7a2d9541e6..3aa85aabeb 100644 --- a/be/src/vec/data_types/data_type_decimal.h +++ b/be/src/vec/data_types/data_type_decimal.h @@ -21,6 +21,7 @@ #pragma once #include +#include "common/config.h" #include "vec/columns/column_decimal.h" #include "vec/common/arithmetic_overflow.h" #include "vec/common/typeid_cast.h" @@ -46,7 +47,25 @@ constexpr size_t max_decimal_precision() { } template <> constexpr size_t max_decimal_precision() { - return 27; + return 38; +} + +template +static constexpr typename T::NativeType max_decimal_value() { + return 0; +} +template <> +constexpr Int32 max_decimal_value() { + return 999999999; +} +template <> +constexpr Int64 max_decimal_value() { + return 999999999999999999; +} +template <> +constexpr Int128 max_decimal_value() { + return static_cast(999999999999999999ll) * 100000000000000000ll * 1000ll + + static_cast(99999999999999999ll) * 1000ll + 999ll; } DataTypePtr create_decimal(UInt64 precision, UInt64 scale); @@ -104,10 +123,6 @@ public: if (UNLIKELY(scale < 0 || static_cast(scale) > max_precision())) { LOG(FATAL) << fmt::format("Scale {} is out of bounds", scale); } - - // Now, Doris only support precision:27, scale: 9 - DCHECK(precision == 27); - DCHECK(scale == 9); } DataTypeDecimal(const DataTypeDecimal& rhs) : precision(rhs.precision), scale(rhs.scale) {} @@ -191,6 +206,8 @@ public: static T get_scale_multiplier(UInt32 scale); + T parse_from_string(const std::string& str) const; + private: const UInt32 precision; const UInt32 scale; @@ -200,26 +217,54 @@ template typename std::enable_if_t<(sizeof(T) >= sizeof(U)), const DataTypeDecimal> decimal_result_type( const DataTypeDecimal& tx, const DataTypeDecimal& ty, bool is_multiply, bool is_divide) { - return DataTypeDecimal(max_decimal_precision(), 9); + if (config::enable_decimalv3) { + UInt32 scale = (tx.get_scale() > ty.get_scale() ? tx.get_scale() : ty.get_scale()); + if (is_multiply) { + scale = tx.get_scale() + ty.get_scale(); + } else if (is_divide) { + scale = tx.get_scale(); + } + return DataTypeDecimal(max_decimal_precision(), scale); + } else { + return DataTypeDecimal(max_decimal_precision(), 9); + } } template typename std::enable_if_t<(sizeof(T) < sizeof(U)), const DataTypeDecimal> decimal_result_type( const DataTypeDecimal& tx, const DataTypeDecimal& ty, bool is_multiply, bool is_divide) { - return DataTypeDecimal(max_decimal_precision(), 9); + if (config::enable_decimalv3) { + UInt32 scale = (tx.get_scale() > ty.get_scale() ? tx.get_scale() : ty.get_scale()); + if (is_multiply) { + scale = tx.get_scale() + ty.get_scale(); + } else if (is_divide) { + scale = tx.get_scale(); + } + return DataTypeDecimal(max_decimal_precision(), scale); + } else { + return DataTypeDecimal(max_decimal_precision(), 9); + } } template const DataTypeDecimal decimal_result_type(const DataTypeDecimal& tx, const DataTypeNumber&, bool, bool) { - return DataTypeDecimal(max_decimal_precision(), 9); + if (config::enable_decimalv3) { + return DataTypeDecimal(max_decimal_precision(), tx.get_scale()); + } else { + return DataTypeDecimal(max_decimal_precision(), 9); + } } template const DataTypeDecimal decimal_result_type(const DataTypeNumber&, const DataTypeDecimal& ty, bool, bool) { - return DataTypeDecimal(max_decimal_precision(), 9); + if (config::enable_decimalv3) { + return DataTypeDecimal(max_decimal_precision(), ty.get_scale()); + } else { + return DataTypeDecimal(max_decimal_precision(), 9); + } } template @@ -290,9 +335,12 @@ convert_from_decimal(const typename FromDataType::FieldType& value, UInt32 scale using FromFieldType = typename FromDataType::FieldType; using ToFieldType = typename ToDataType::FieldType; - if constexpr (std::is_floating_point_v) + if constexpr (std::is_floating_point_v) { + if (config::enable_decimalv3) { + return static_cast(value) / FromDataType::get_scale_multiplier(scale); + } return binary_cast(value); - else { + } else { FromFieldType converted_value = convert_decimals(value, scale, 0); @@ -344,4 +392,13 @@ convert_to_decimal(const typename FromDataType::FieldType& value, UInt32 scale) } } +template +void convert_to_decimal(T* from_value, T* to_value, int32_t from_scale, int32_t to_scale, + bool* loss_accuracy); +template +typename T::NativeType max_decimal_value(UInt32 precision); + +template +typename T::NativeType min_decimal_value(UInt32 precision); + } // namespace doris::vectorized diff --git a/be/src/vec/data_types/data_type_factory.cpp b/be/src/vec/data_types/data_type_factory.cpp index 8ab9d6a933..c8be27a6ab 100644 --- a/be/src/vec/data_types/data_type_factory.cpp +++ b/be/src/vec/data_types/data_type_factory.cpp @@ -30,7 +30,8 @@ DataTypePtr DataTypeFactory::create_data_type(const doris::Field& col_desc) { DCHECK(col_desc.get_sub_field_count() == 1); nested = std::make_shared(create_data_type(*col_desc.get_sub_field(0))); } else { - nested = _create_primitive_data_type(col_desc.type()); + nested = _create_primitive_data_type(col_desc.type(), col_desc.get_precision(), + col_desc.get_scale()); } if (col_desc.is_nullable() && nested) { @@ -45,7 +46,8 @@ DataTypePtr DataTypeFactory::create_data_type(const TabletColumn& col_desc, bool DCHECK(col_desc.get_subtype_count() == 1); nested = std::make_shared(create_data_type(col_desc.get_sub_column(0))); } else { - nested = _create_primitive_data_type(col_desc.type()); + nested = + _create_primitive_data_type(col_desc.type(), col_desc.precision(), col_desc.frac()); } if ((is_nullable || col_desc.is_nullable()) && nested) { @@ -106,6 +108,11 @@ DataTypePtr DataTypeFactory::create_data_type(const TypeDescriptor& col_desc, bo case TYPE_DECIMALV2: nested = std::make_shared>(27, 9); break; + case TYPE_DECIMAL32: + case TYPE_DECIMAL64: + case TYPE_DECIMAL128: + nested = vectorized::create_decimal(col_desc.precision, col_desc.scale); + break; // Just Mock A NULL Type in Vec Exec Engine case TYPE_NULL: nested = std::make_shared(); @@ -127,7 +134,8 @@ DataTypePtr DataTypeFactory::create_data_type(const TypeDescriptor& col_desc, bo return nested; } -DataTypePtr DataTypeFactory::_create_primitive_data_type(const FieldType& type) const { +DataTypePtr DataTypeFactory::_create_primitive_data_type(const FieldType& type, int precision, + int scale) const { DataTypePtr result = nullptr; switch (type) { case OLAP_FIELD_TYPE_BOOL: @@ -177,6 +185,11 @@ DataTypePtr DataTypeFactory::_create_primitive_data_type(const FieldType& type) case OLAP_FIELD_TYPE_DECIMAL: result = std::make_shared>(27, 9); break; + case OLAP_FIELD_TYPE_DECIMAL32: + case OLAP_FIELD_TYPE_DECIMAL64: + case OLAP_FIELD_TYPE_DECIMAL128: + result = vectorized::create_decimal(precision, scale); + break; default: DCHECK(false) << "Invalid FieldType:" << (int)type; result = nullptr; diff --git a/be/src/vec/data_types/data_type_factory.hpp b/be/src/vec/data_types/data_type_factory.hpp index 59740debd3..674bb5515f 100644 --- a/be/src/vec/data_types/data_type_factory.hpp +++ b/be/src/vec/data_types/data_type_factory.hpp @@ -23,6 +23,7 @@ #include #include "arrow/type.h" +#include "common/consts.h" #include "gen_cpp/data.pb.h" #include "olap/field.h" #include "olap/tablet_schema.h" @@ -66,6 +67,12 @@ public: {"DateTime", std::make_shared()}, {"String", std::make_shared()}, {"Decimal", std::make_shared>(27, 9)}, + {"Decimal32", std::make_shared>( + BeConsts::MAX_DECIMAL32_PRECISION, 0)}, + {"Decimal64", std::make_shared>( + BeConsts::MAX_DECIMAL64_PRECISION, 0)}, + {"Decimal128", std::make_shared>( + BeConsts::MAX_DECIMAL128_PRECISION, 0)}, }; for (auto const& [key, val] : base_type_map) { @@ -89,6 +96,9 @@ public: if (entity.first->equals(*type_ptr)) { return entity.second; } + if (is_decimal(type_ptr) && type_ptr->get_type_id() == entity.first->get_type_id()) { + return entity.second; + } } return _empty_string; } @@ -107,7 +117,7 @@ public: } private: - DataTypePtr _create_primitive_data_type(const FieldType& type) const; + DataTypePtr _create_primitive_data_type(const FieldType& type, int precision, int scale) const; void register_data_type(const std::string& name, const DataTypePtr& data_type) { _data_type_map.emplace(name, data_type); diff --git a/be/src/vec/exec/join/vhash_join_node.cpp b/be/src/vec/exec/join/vhash_join_node.cpp index 7a6b04e63c..8b2a0aab64 100644 --- a/be/src/vec/exec/join/vhash_join_node.cpp +++ b/be/src/vec/exec/join/vhash_join_node.cpp @@ -1249,8 +1249,25 @@ void HashJoinNode::_hash_table_init() { break; case TYPE_LARGEINT: case TYPE_DECIMALV2: - _hash_table_variants.emplace(); + case TYPE_DECIMAL32: + case TYPE_DECIMAL64: + case TYPE_DECIMAL128: { + DataTypePtr& type_ptr = _build_expr_ctxs[0]->root()->data_type(); + TypeIndex idx = _build_expr_ctxs[0]->root()->is_nullable() + ? assert_cast(*type_ptr) + .get_nested_type() + ->get_type_id() + : type_ptr->get_type_id(); + WhichDataType which(idx); + if (which.is_decimal32()) { + _hash_table_variants.emplace(); + } else if (which.is_decimal64()) { + _hash_table_variants.emplace(); + } else { + _hash_table_variants.emplace(); + } break; + } default: _hash_table_variants.emplace(); } diff --git a/be/src/vec/exec/vaggregation_node.cpp b/be/src/vec/exec/vaggregation_node.cpp index 62e3e415b3..cbecc624ad 100644 --- a/be/src/vec/exec/vaggregation_node.cpp +++ b/be/src/vec/exec/vaggregation_node.cpp @@ -149,8 +149,24 @@ void AggregationNode::_init_hash_method(std::vector& probe_exprs) return; case TYPE_LARGEINT: case TYPE_DECIMALV2: - _agg_data.init(AggregatedDataVariants::Type::int128_key, is_nullable); + case TYPE_DECIMAL32: + case TYPE_DECIMAL64: + case TYPE_DECIMAL128: { + DataTypePtr& type_ptr = probe_exprs[0]->root()->data_type(); + TypeIndex idx = is_nullable ? assert_cast(*type_ptr) + .get_nested_type() + ->get_type_id() + : type_ptr->get_type_id(); + WhichDataType which(idx); + if (which.is_decimal32()) { + _agg_data.init(AggregatedDataVariants::Type::int32_key, is_nullable); + } else if (which.is_decimal64()) { + _agg_data.init(AggregatedDataVariants::Type::int64_key, is_nullable); + } else { + _agg_data.init(AggregatedDataVariants::Type::int128_key, is_nullable); + } return; + } default: _agg_data.init(AggregatedDataVariants::Type::serialized); } @@ -202,7 +218,7 @@ void AggregationNode::_init_hash_method(std::vector& probe_exprs) _agg_data.init(AggregatedDataVariants::Type::serialized); } } -} +} // namespace doris::vectorized Status AggregationNode::prepare(RuntimeState* state) { SCOPED_TIMER(_runtime_profile->total_time_counter()); diff --git a/be/src/vec/exec/volap_scan_node.cpp b/be/src/vec/exec/volap_scan_node.cpp index 74c9eccc76..5eadfb3162 100644 --- a/be/src/vec/exec/volap_scan_node.cpp +++ b/be/src/vec/exec/volap_scan_node.cpp @@ -28,6 +28,7 @@ #include "util/priority_thread_pool.hpp" #include "util/to_string.h" #include "vec/core/block.h" +#include "vec/data_types/data_type_decimal.h" #include "vec/exec/volap_scanner.h" #include "vec/exprs/vcompound_pred.h" #include "vec/exprs/vexpr.h" @@ -717,6 +718,30 @@ Status VOlapScanNode::normalize_conjuncts() { break; } + case TYPE_DECIMAL32: { + ColumnValueRange range(slots[slot_idx]->col_name(), + slots[slot_idx]->type().precision, + slots[slot_idx]->type().scale); + normalize_predicate(range, slots[slot_idx]); + break; + } + + case TYPE_DECIMAL64: { + ColumnValueRange range(slots[slot_idx]->col_name(), + slots[slot_idx]->type().precision, + slots[slot_idx]->type().scale); + normalize_predicate(range, slots[slot_idx]); + break; + } + + case TYPE_DECIMAL128: { + ColumnValueRange range(slots[slot_idx]->col_name(), + slots[slot_idx]->type().precision, + slots[slot_idx]->type().scale); + normalize_predicate(range, slots[slot_idx]); + break; + } + case TYPE_BOOLEAN: { ColumnValueRange range(slots[slot_idx]->col_name()); normalize_predicate(range, slots[slot_idx]); @@ -950,7 +975,9 @@ std::pair VOlapScanNode::should_push_down_eq_predicate(doris::SlotD } // get value in result pair - result_pair = std::make_pair(true, _conjunct_ctxs[conj_idx]->get_value(expr, nullptr)); + result_pair = std::make_pair( + true, _conjunct_ctxs[conj_idx]->get_value(expr, nullptr, slot->type().precision, + slot->type().scale)); return result_pair; } @@ -980,6 +1007,9 @@ Status VOlapScanNode::change_fixed_value_range(ColumnValueRange& case TYPE_INT: case TYPE_BIGINT: case TYPE_LARGEINT: + case TYPE_DECIMAL32: + case TYPE_DECIMAL64: + case TYPE_DECIMAL128: case TYPE_STRING: { func(temp_range, reinterpret_cast::CppType*>(value)); @@ -1086,7 +1116,8 @@ Status VOlapScanNode::normalize_in_and_eq_predicate(SlotDescriptor* slot, std::vector filter_conjuncts_index; for (int conj_idx = 0; conj_idx < _conjunct_ctxs.size(); ++conj_idx) { // create empty range as temp range, temp range should do intersection on range - auto temp_range = ColumnValueRange::create_empty_column_value_range(); + auto temp_range = ColumnValueRange::create_empty_column_value_range( + slot->type().precision, slot->type().scale); // 1. Normalize in conjuncts like 'where col in (v1, v2, v3)' if (TExprOpcode::FILTER_IN == _conjunct_ctxs[conj_idx]->root()->op()) { @@ -1261,7 +1292,8 @@ bool VOlapScanNode::normalize_is_null_predicate(Expr* expr, SlotDescriptor* slot return false; } - auto temp_range = ColumnValueRange::create_empty_column_value_range(); + auto temp_range = ColumnValueRange::create_empty_column_value_range(slot->type().precision, + slot->type().scale); temp_range.set_contain_null(is_null_str == "null"); range->intersection(temp_range); @@ -1321,7 +1353,8 @@ Status VOlapScanNode::normalize_noneq_binary_predicate(SlotDescriptor* slot, continue; } - void* value = _conjunct_ctxs[conj_idx]->get_value(expr, nullptr); + void* value = _conjunct_ctxs[conj_idx]->get_value( + expr, nullptr, slot->type().precision, slot->type().scale); // for case: where col > null if (value == nullptr) { continue; @@ -1360,6 +1393,9 @@ Status VOlapScanNode::normalize_noneq_binary_predicate(SlotDescriptor* slot, } case TYPE_TINYINT: case TYPE_DECIMALV2: + case TYPE_DECIMAL32: + case TYPE_DECIMAL64: + case TYPE_DECIMAL128: case TYPE_CHAR: case TYPE_VARCHAR: case TYPE_HLL: diff --git a/be/src/vec/exec/vschema_scan_node.cpp b/be/src/vec/exec/vschema_scan_node.cpp index 65f260adec..eef55ea567 100644 --- a/be/src/vec/exec/vschema_scan_node.cpp +++ b/be/src/vec/exec/vschema_scan_node.cpp @@ -358,13 +358,15 @@ Status VSchemaScanNode::write_slot_to_vectorized_column(void* slot, SlotDescript break; } - case TYPE_INT: { + case TYPE_INT: + case TYPE_DECIMAL32: { int32_t num = *reinterpret_cast(slot); reinterpret_cast*>(col_ptr)->insert_value(num); break; } - case TYPE_BIGINT: { + case TYPE_BIGINT: + case TYPE_DECIMAL64: { int64_t num = *reinterpret_cast(slot); reinterpret_cast*>(col_ptr)->insert_value(num); break; @@ -415,7 +417,8 @@ Status VSchemaScanNode::write_slot_to_vectorized_column(void* slot, SlotDescript break; } - case TYPE_DECIMALV2: { + case TYPE_DECIMALV2: + case TYPE_DECIMAL128: { __int128 num = (reinterpret_cast(slot))->value; reinterpret_cast*>(col_ptr)->insert_value( num); diff --git a/be/src/vec/exec/vset_operation_node.cpp b/be/src/vec/exec/vset_operation_node.cpp index 9dc620896e..9b644fffd7 100644 --- a/be/src/vec/exec/vset_operation_node.cpp +++ b/be/src/vec/exec/vset_operation_node.cpp @@ -163,16 +163,19 @@ void VSetOperationNode::hash_table_init() { case TYPE_INT: case TYPE_FLOAT: case TYPE_DATEV2: + case TYPE_DECIMAL32: _hash_table_variants.emplace(); break; case TYPE_BIGINT: case TYPE_DOUBLE: case TYPE_DATETIME: case TYPE_DATE: + case TYPE_DECIMAL64: _hash_table_variants.emplace(); break; case TYPE_LARGEINT: case TYPE_DECIMALV2: + case TYPE_DECIMAL128: _hash_table_variants.emplace(); break; default: diff --git a/be/src/vec/exprs/vliteral.cpp b/be/src/vec/exprs/vliteral.cpp index c1dc68b545..9936733c5c 100644 --- a/be/src/vec/exprs/vliteral.cpp +++ b/be/src/vec/exprs/vliteral.cpp @@ -22,6 +22,8 @@ #include "runtime/large_int_value.h" #include "util/string_parser.hpp" #include "vec/core/field.h" +#include "vec/data_types/data_type_decimal.h" +#include "vec/io/io_helper.h" #include "vec/runtime/vdatetime_value.h" namespace doris { @@ -122,6 +124,42 @@ void VLiteral::init(const TExprNode& node) { field = DecimalField(value.value(), value.scale()); break; } + case TYPE_DECIMAL32: { + DCHECK_EQ(node.node_type, TExprNodeType::DECIMAL_LITERAL); + DCHECK(node.__isset.decimal_literal); + DataTypePtr type_ptr = create_decimal(node.type.types[0].scalar_type.precision, + node.type.types[0].scalar_type.scale); + auto val = typeid_cast*>(type_ptr.get()) + ->parse_from_string(node.decimal_literal.value); + auto scale = + typeid_cast*>(type_ptr.get())->get_scale(); + field = DecimalField(val, scale); + break; + } + case TYPE_DECIMAL64: { + DCHECK_EQ(node.node_type, TExprNodeType::DECIMAL_LITERAL); + DCHECK(node.__isset.decimal_literal); + DataTypePtr type_ptr = create_decimal(node.type.types[0].scalar_type.precision, + node.type.types[0].scalar_type.scale); + auto val = typeid_cast*>(type_ptr.get()) + ->parse_from_string(node.decimal_literal.value); + auto scale = + typeid_cast*>(type_ptr.get())->get_scale(); + field = DecimalField(val, scale); + break; + } + case TYPE_DECIMAL128: { + DCHECK_EQ(node.node_type, TExprNodeType::DECIMAL_LITERAL); + DCHECK(node.__isset.decimal_literal); + DataTypePtr type_ptr = create_decimal(node.type.types[0].scalar_type.precision, + node.type.types[0].scalar_type.scale); + auto val = typeid_cast*>(type_ptr.get()) + ->parse_from_string(node.decimal_literal.value); + auto scale = + typeid_cast*>(type_ptr.get())->get_scale(); + field = DecimalField(val, scale); + break; + } default: { DCHECK(false) << "Invalid type: " << _type.type; break; @@ -190,6 +228,21 @@ std::string VLiteral::debug_string() const { out << value; break; } + case TYPE_DECIMAL32: { + write_text(*(reinterpret_cast(ref.data)), _type.scale, + out); + break; + } + case TYPE_DECIMAL64: { + write_text(*(reinterpret_cast(ref.data)), _type.scale, + out); + break; + } + case TYPE_DECIMAL128: { + write_text(*(reinterpret_cast(ref.data)), _type.scale, + out); + break; + } default: { out << "UNKNOWN TYPE: " << int(_type.type); break; diff --git a/be/src/vec/functions/function.h b/be/src/vec/functions/function.h index 2272430a8e..0e90f8d681 100644 --- a/be/src/vec/functions/function.h +++ b/be/src/vec/functions/function.h @@ -297,6 +297,13 @@ public: ? ((DataTypeNullable*)return_type.get())->get_nested_type() : return_type) && is_date_v2(get_return_type(arguments)->is_nullable() + ? ((DataTypeNullable*)get_return_type(arguments).get()) + ->get_nested_type() + : get_return_type(arguments))) || + (is_decimal(return_type->is_nullable() + ? ((DataTypeNullable*)return_type.get())->get_nested_type() + : return_type) && + is_decimal(get_return_type(arguments)->is_nullable() ? ((DataTypeNullable*)get_return_type(arguments).get()) ->get_nested_type() : get_return_type(arguments)))) diff --git a/be/src/vec/functions/function_binary_arithmetic.h b/be/src/vec/functions/function_binary_arithmetic.h index 3091ab2dde..f3bc32f074 100644 --- a/be/src/vec/functions/function_binary_arithmetic.h +++ b/be/src/vec/functions/function_binary_arithmetic.h @@ -66,6 +66,7 @@ struct OperationTraits { static constexpr bool is_multiply = std::is_same_v>; static constexpr bool is_division = std::is_same_v> || std::is_same_v>; + static constexpr bool is_mod = std::is_same_v>; static constexpr bool allow_decimal = std::is_same_v> || std::is_same_v> || std::is_same_v> || std::is_same_v> || @@ -212,7 +213,7 @@ struct BinaryOperationImpl { /// * no agrs scale. ScaleR = Scale1 + Scale2; /// / first arg scale. ScaleR = Scale1 (scale_a = DecimalType::get_scale()). template typename Operation, - typename ResultType, bool is_to_null_type, bool check_overflow = false> + typename ResultType, bool is_to_null_type, bool check_overflow = true> struct DecimalBinaryOperation { using OpTraits = OperationTraits; @@ -251,7 +252,26 @@ struct DecimalBinaryOperation { ArrayC& c, ResultType scale_a [[maybe_unused]], ResultType scale_b [[maybe_unused]], NullMap& null_map) { size_t size = a.size(); - + if (config::enable_decimalv3) { + if constexpr (OpTraits::is_division && IsDecimalNumber) { + for (size_t i = 0; i < size; ++i) { + c[i] = apply_scaled_div(a[i], b[i], scale_a, null_map[i]); + } + return; + } else if constexpr (OpTraits::is_mod) { + if (scale_a != 1) { + for (size_t i = 0; i < size; ++i) { + c[i] = apply_scaled_mod(a[i], b[i], scale_a, null_map[i]); + } + return; + } else if (scale_b != 1) { + for (size_t i = 0; i < size; ++i) { + c[i] = apply_scaled_mod(a[i], b[i], scale_b, null_map[i]); + } + return; + } + } + } /// default: use it if no return before for (size_t i = 0; i < size; ++i) { c[i] = apply(a[i], b[i], null_map[i]); @@ -296,6 +316,18 @@ struct DecimalBinaryOperation { c[i] = apply_scaled_div(a[i], b, scale_a, null_map[i]); } return; + } else if constexpr (OpTraits::is_mod) { + if (scale_a != 1) { + for (size_t i = 0; i < size; ++i) { + c[i] = apply_scaled_mod(a[i], b, scale_a, null_map[i]); + } + return; + } else if (scale_b != 1) { + for (size_t i = 0; i < size; ++i) { + c[i] = apply_scaled_mod(a[i], b, scale_b, null_map[i]); + } + return; + } } for (size_t i = 0; i < size; ++i) { @@ -341,6 +373,18 @@ struct DecimalBinaryOperation { c[i] = apply_scaled_div(a, b[i], scale_a, null_map[i]); } return; + } else if constexpr (OpTraits::is_mod) { + if (scale_a != 1) { + for (size_t i = 0; i < size; ++i) { + c[i] = apply_scaled_mod(a, b[i], scale_a, null_map[i]); + } + return; + } else if (scale_b != 1) { + for (size_t i = 0; i < size; ++i) { + c[i] = apply_scaled_mod(a, b[i], scale_b, null_map[i]); + } + return; + } } for (size_t i = 0; i < size; ++i) { @@ -372,6 +416,12 @@ struct DecimalBinaryOperation { } } else if constexpr (OpTraits::is_division && IsDecimalNumber) { return apply_scaled_div(a, b, scale_a, is_null); + } else if constexpr (OpTraits::is_mod) { + if (scale_a != 1) { + return apply_scaled_mod(a, b, scale_a, is_null); + } else if (scale_b != 1) { + return apply_scaled_mod(a, b, scale_b, is_null); + } } return apply(a, b, is_null); } @@ -451,6 +501,20 @@ struct DecimalBinaryOperation { private: /// there's implicit type convertion here static NativeResultType apply(NativeResultType a, NativeResultType b) { + if (config::enable_decimalv3) { + if constexpr (OpTraits::can_overflow && check_overflow) { + NativeResultType res; + // TODO handle overflow gracefully + if (Op::template apply(a, b, res)) { + LOG(WARNING) << "Decimal math overflow"; + res = max_decimal_value(); + } + return res; + } else { + return Op::template apply(a, b); + } + } + // Now, Doris only support decimal +-*/ decimal. // overflow in consider in operator DecimalV2Value l(a); @@ -463,6 +527,10 @@ private: /// null_map for divide and mod static NativeResultType apply(NativeResultType a, NativeResultType b, UInt8& is_null) { + if (config::enable_decimalv3) { + return Op::template apply(a, b, is_null); + } + DecimalV2Value l(a); DecimalV2Value r(b); auto ans = Op::template apply(l, r, is_null); @@ -491,8 +559,10 @@ private: res = Op::template apply(a, b); } + // TODO handle overflow gracefully if (overflow) { - LOG(FATAL) << "Decimal math overflow"; + LOG(WARNING) << "Decimal math overflow"; + res = max_decimal_value(); } } else { if constexpr (scale_left) { @@ -516,8 +586,10 @@ private: overflow |= common::mul_overflow(scale, scale, scale); } overflow |= common::mul_overflow(a, scale, a); + // TODO handle overflow gracefully if (overflow) { - LOG(FATAL) << "Decimal math overflow"; + LOG(WARNING) << "Decimal math overflow"; + return max_decimal_value(); } } else { if constexpr (!IsDecimalNumber) { @@ -529,6 +601,31 @@ private: return apply(a, b, is_null); } } + + template + static NativeResultType apply_scaled_mod(NativeResultType a, NativeResultType b, + NativeResultType scale, UInt8& is_null) { + if constexpr (check_overflow) { + bool overflow = false; + if constexpr (scale_left) + overflow |= common::mul_overflow(a, scale, a); + else + overflow |= common::mul_overflow(b, scale, b); + + // TODO handle overflow gracefully + if (overflow) { + LOG(WARNING) << "Decimal math overflow"; + return max_decimal_value(); + } + } else { + if constexpr (scale_left) + a *= scale; + else + b *= scale; + } + + return apply(a, b, is_null); + } }; /// Used to indicate undefined operation @@ -646,10 +743,16 @@ private: static auto get_decimal_infos(const LeftDataType& type_left, const RightDataType& type_right) { ResultDataType type = decimal_result_type(type_left, type_right, OpTraits::is_multiply, OpTraits::is_division); - typename ResultDataType::FieldType scale_a = - type.scale_factor_for(type_left, OpTraits::is_multiply); - typename ResultDataType::FieldType scale_b = - type.scale_factor_for(type_right, OpTraits::is_multiply || OpTraits::is_division); + typename ResultDataType::FieldType scale_a; + typename ResultDataType::FieldType scale_b; + if constexpr (OpTraits::is_division && IsDataTypeDecimal) { + scale_a = type_right.get_scale_multiplier(); + scale_b = 1; + } else { + scale_a = type.scale_factor_for(type_left, OpTraits::is_multiply); + scale_b = type.scale_factor_for(type_right, + OpTraits::is_multiply || OpTraits::is_division); + } return std::make_tuple(type, scale_a, scale_b); } diff --git a/be/src/vec/functions/function_cast.h b/be/src/vec/functions/function_cast.h index 1531b77a24..fed821f09f 100644 --- a/be/src/vec/functions/function_cast.h +++ b/be/src/vec/functions/function_cast.h @@ -365,8 +365,9 @@ struct NameToDateTime { static constexpr auto name = "toDateTime"; }; -template -bool try_parse_impl(typename DataType::FieldType& x, ReadBuffer& rb, const DateLUTImpl*) { +template +bool try_parse_impl(typename DataType::FieldType& x, ReadBuffer& rb, const DateLUTImpl*, + Additions additions [[maybe_unused]] = Additions()) { if constexpr (IsDateTimeType) { return try_read_datetime_text(x, rb); } @@ -393,7 +394,8 @@ bool try_parse_impl(typename DataType::FieldType& x, ReadBuffer& rb, const DateL } if constexpr (IsDataTypeDecimal) { - return try_read_decimal_text(x, rb); + UInt32 scale = additions; + return try_read_decimal_text(x, rb, DataType::max_precision(), scale); } } @@ -806,7 +808,8 @@ struct ConvertThroughParsing { typename ColVecTo::MutablePtr col_to = nullptr; if constexpr (IsDataTypeDecimal) { - col_to = ColVecTo::create(size, 9); + UInt32 scale = additions; + col_to = ColVecTo::create(size, scale); } else col_to = ColVecTo::create(size); @@ -838,9 +841,14 @@ struct ConvertThroughParsing { ReadBuffer read_buffer(&(*chars)[current_offset], string_size); - (*vec_null_map_to)[i] = - !try_parse_impl(vec_to[i], read_buffer, local_time_zone) || - !is_all_read(read_buffer); + bool parsed; + if constexpr (IsDataTypeDecimal) { + parsed = try_parse_impl(vec_to[i], read_buffer, local_time_zone, + vec_to.get_scale()); + } else { + parsed = try_parse_impl(vec_to[i], read_buffer, local_time_zone); + } + (*vec_null_map_to)[i] = !parsed || !is_all_read(read_buffer); current_offset = next_offset; } @@ -851,6 +859,16 @@ struct ConvertThroughParsing { } }; +template +struct ConvertImpl, Name> + : ConvertThroughParsing, Name> {}; +template +struct ConvertImpl, Name> + : ConvertThroughParsing, Name> {}; +template +struct ConvertImpl, Name> + : ConvertThroughParsing, Name> {}; + template class FunctionConvertFromString : public IFunction { public: @@ -869,11 +887,7 @@ public: DataTypePtr get_return_type_impl(const ColumnsWithTypeAndName& arguments) const override { DataTypePtr res; if constexpr (IsDataTypeDecimal) { - res = create_decimal(27, 9); - - if (!res) { - LOG(FATAL) << "Someting wrong with toDecimalNNOrZero() or toDecimalNNOrNull()"; - } + LOG(FATAL) << "Someting wrong with toDecimalNNOrZero() or toDecimalNNOrNull()"; } else res = std::make_shared(); @@ -1048,20 +1062,6 @@ private: from_type->get_name(), to_type->get_name()); } - if (which.is_string_or_fixed_string()) { - auto function = - FunctionConvertFromString, NameCast>::create(); - - /// Check conversion using underlying function - { function->get_return_type(ColumnsWithTypeAndName(1, {nullptr, from_type, ""})); } - - return [function](FunctionContext* context, Block& block, - const ColumnNumbers& arguments, const size_t result, - size_t input_rows_count) { - return function->execute(context, block, arguments, result, input_rows_count); - }; - } - return [type_index, precision, scale](FunctionContext* context, Block& block, const ColumnNumbers& arguments, const size_t result, size_t input_rows_count) { diff --git a/be/src/vec/io/io_helper.h b/be/src/vec/io/io_helper.h index 01c89ed155..51122b26b5 100644 --- a/be/src/vec/io/io_helper.h +++ b/be/src/vec/io/io_helper.h @@ -293,10 +293,17 @@ bool read_date_v2_text_impl(T& x, ReadBuffer& buf) { } template -bool read_decimal_text_impl(T& x, ReadBuffer& buf) { +bool read_decimal_text_impl(T& x, ReadBuffer& buf, UInt32 precision, UInt32 scale) { static_assert(IsDecimalNumber); - // TODO: open this static_assert - // static_assert(std::is_same_v); + if (config::enable_decimalv3) { + StringParser::ParseResult result = StringParser::PARSE_SUCCESS; + + x.value = StringParser::string_to_decimal( + (const char*)buf.position(), buf.count(), precision, scale, &result); + // only to match the is_all_read() check to prevent return null + buf.position() = buf.end(); + return result != StringParser::PARSE_FAILURE; + } auto dv = binary_cast(x.value); auto ans = dv.parse_from_str((const char*)buf.position(), buf.count()) == 0; @@ -335,8 +342,8 @@ bool try_read_float_text(T& x, ReadBuffer& in) { } template -bool try_read_decimal_text(T& x, ReadBuffer& in) { - return read_decimal_text_impl(x, in); +bool try_read_decimal_text(T& x, ReadBuffer& in, UInt32 precision, UInt32 scale) { + return read_decimal_text_impl(x, in, precision, scale); } template diff --git a/be/src/vec/olap/olap_data_convertor.cpp b/be/src/vec/olap/olap_data_convertor.cpp index 03a1a208cd..2bf619dece 100644 --- a/be/src/vec/olap/olap_data_convertor.cpp +++ b/be/src/vec/olap/olap_data_convertor.cpp @@ -17,6 +17,7 @@ #include "vec/olap/olap_data_convertor.h" +#include "common/consts.h" #include "olap/tablet_schema.h" #include "vec/columns/column_array.h" #include "vec/columns/column_complex.h" @@ -65,6 +66,15 @@ OlapBlockDataConvertor::create_olap_column_data_convertor(const TabletColumn& co case FieldType::OLAP_FIELD_TYPE_DECIMAL: { return std::make_unique(); } + case FieldType::OLAP_FIELD_TYPE_DECIMAL32: { + return std::make_unique>(); + } + case FieldType::OLAP_FIELD_TYPE_DECIMAL64: { + return std::make_unique>(); + } + case FieldType::OLAP_FIELD_TYPE_DECIMAL128: { + return std::make_unique>(); + } case FieldType::OLAP_FIELD_TYPE_BOOL: { return std::make_unique>(); } @@ -99,7 +109,7 @@ OlapBlockDataConvertor::create_olap_column_data_convertor(const TabletColumn& co return nullptr; } } -} +} // namespace doris::vectorized void OlapBlockDataConvertor::set_source_content(const vectorized::Block* block, size_t row_pos, size_t num_rows) { @@ -630,7 +640,7 @@ Status OlapBlockDataConvertor::OlapColumnDataConvertorDecimal::convert_to_olap() const DecimalV2Value* decimal_end = decimal_cur + _num_rows; decimal12_t* value = _values.data(); if (_nullmap) { - const UInt8* nullmap_cur = _nullmap; + const UInt8* nullmap_cur = _nullmap + _row_pos; while (decimal_cur != decimal_end) { if (!*nullmap_cur) { value->integer = decimal_cur->int_value(); diff --git a/be/src/vec/olap/olap_data_convertor.h b/be/src/vec/olap/olap_data_convertor.h index 4fad5b6f59..37bd4b1ec3 100644 --- a/be/src/vec/olap/olap_data_convertor.h +++ b/be/src/vec/olap/olap_data_convertor.h @@ -246,7 +246,7 @@ private: return Status::OK(); } - private: + protected: const T* _values = nullptr; }; @@ -340,6 +340,33 @@ private: bool from_date_to_date_v2_; }; + // decimalv3 don't need to do any convert + template + class OlapColumnDataConvertorDecimalV3 + : public OlapColumnDataConvertorSimple { + public: + using FieldType = typename T::NativeType; + OlapColumnDataConvertorDecimalV3() = default; + ~OlapColumnDataConvertorDecimalV3() override = default; + + Status convert_to_olap() override { + const vectorized::ColumnDecimal* column_data = nullptr; + if (this->_nullmap) { + auto nullable_column = assert_cast( + this->_typed_column.column.get()); + column_data = assert_cast*>( + nullable_column->get_nested_column_ptr().get()); + } else { + column_data = assert_cast*>( + this->_typed_column.column.get()); + } + + assert(column_data); + this->_values = (const FieldType*)(column_data->get_data().data()) + this->_row_pos; + return Status::OK(); + } + }; + class OlapColumnDataConvertorArray : public OlapColumnDataConvertorPaddedPODArray { public: diff --git a/be/src/vec/runtime/vfile_result_writer.cpp b/be/src/vec/runtime/vfile_result_writer.cpp index 40f8472fef..128f5c9cda 100644 --- a/be/src/vec/runtime/vfile_result_writer.cpp +++ b/be/src/vec/runtime/vfile_result_writer.cpp @@ -291,6 +291,18 @@ Status VFileResultWriter::_write_csv_file(const Block& block) { _plain_text_outstream << decimal_str; break; } + case TYPE_DECIMAL32: { + _plain_text_outstream << col.type->to_string(*col.column, i); + break; + } + case TYPE_DECIMAL64: { + _plain_text_outstream << col.type->to_string(*col.column, i); + break; + } + case TYPE_DECIMAL128: { + _plain_text_outstream << col.type->to_string(*col.column, i); + break; + } default: { // not supported type, like BITMAP, HLL, just export null _plain_text_outstream << NULL_IN_CSV; diff --git a/be/src/vec/sink/vmysql_result_writer.cpp b/be/src/vec/sink/vmysql_result_writer.cpp index 8e54e62004..a520217c25 100644 --- a/be/src/vec/sink/vmysql_result_writer.cpp +++ b/be/src/vec/sink/vmysql_result_writer.cpp @@ -24,6 +24,7 @@ #include "vec/columns/column_vector.h" #include "vec/common/assert_cast.h" #include "vec/data_types/data_type_array.h" +#include "vec/data_types/data_type_decimal.h" #include "vec/exprs/vexpr.h" #include "vec/exprs/vexpr_context.h" #include "vec/runtime/vdatetime_value.h" @@ -150,6 +151,25 @@ Status VMysqlResultWriter::_add_one_column(const ColumnPtr& column_ptr, _buffer.close_dynamic_mode(); result->result_batch.rows[i].append(_buffer.buf(), _buffer.length()); } + } else if constexpr (type == TYPE_DECIMAL32 || type == TYPE_DECIMAL64 || + type == TYPE_DECIMAL128) { + for (int i = 0; i < row_size; ++i) { + if (0 != buf_ret) { + return Status::InternalError("pack mysql buffer failed."); + } + _buffer.reset(); + + if constexpr (is_nullable) { + if (column_ptr->is_null_at(i)) { + buf_ret = _buffer.push_null(); + result->result_batch.rows[i].append(_buffer.buf(), _buffer.length()); + continue; + } + } + std::string decimal_str = nested_type_ptr->to_string(*column, i); + buf_ret = _buffer.push_string(decimal_str.c_str(), decimal_str.length()); + result->result_batch.rows[i].append(_buffer.buf(), _buffer.length()); + } } else { using ColumnType = typename PrimitiveTypeTraits::ColumnType; auto& data = assert_cast(*column).get_data(); @@ -296,12 +316,6 @@ int VMysqlResultWriter::_add_one_cell(const ColumnPtr& column_ptr, size_t row_id char buf[64]; char* pos = datetime.to_string(buf); return buffer.push_string(buf, pos - buf - 1); - } else if (which.is_decimal128()) { - auto& column_data = - static_cast&>(*column).get_data(); - DecimalV2Value decimal_val(column_data[row_idx]); - auto decimal_str = decimal_val.to_string(); - return buffer.push_string(decimal_str.c_str(), decimal_str.length()); } else if (which.is_date_v2()) { auto& column_vector = assert_cast&>(*column); auto value = column_vector[row_idx].get(); @@ -310,6 +324,38 @@ int VMysqlResultWriter::_add_one_cell(const ColumnPtr& column_ptr, size_t row_id char buf[64]; char* pos = datev2.to_string(buf); return buffer.push_string(buf, pos - buf - 1); + } else if (which.is_decimal32()) { + DataTypePtr nested_type = type; + if (type->is_nullable()) { + nested_type = assert_cast(*type).get_nested_type(); + } + auto decimal_str = assert_cast*>(nested_type.get()) + ->to_string(*column, row_idx); + return buffer.push_string(decimal_str.c_str(), decimal_str.length()); + } else if (which.is_decimal64()) { + DataTypePtr nested_type = type; + if (type->is_nullable()) { + nested_type = assert_cast(*type).get_nested_type(); + } + auto decimal_str = assert_cast*>(nested_type.get()) + ->to_string(*column, row_idx); + return buffer.push_string(decimal_str.c_str(), decimal_str.length()); + } else if (which.is_decimal128()) { + if (config::enable_decimalv3) { + DataTypePtr nested_type = type; + if (type->is_nullable()) { + nested_type = assert_cast(*type).get_nested_type(); + } + auto decimal_str = assert_cast*>(nested_type.get()) + ->to_string(*column, row_idx); + return buffer.push_string(decimal_str.c_str(), decimal_str.length()); + } else { + auto& column_data = + static_cast&>(*column).get_data(); + DecimalV2Value decimal_val(column_data[row_idx]); + auto decimal_str = decimal_val.to_string(); + return buffer.push_string(decimal_str.c_str(), decimal_str.length()); + } } else if (which.is_array()) { auto& column_array = assert_cast(*column); auto& offsets = column_array.get_offsets(); @@ -459,9 +505,49 @@ Status VMysqlResultWriter::append_block(Block& input_block) { } case TYPE_DECIMALV2: { if (type_ptr->is_nullable()) { - status = _add_one_column(column_ptr, result); + auto& nested_type = + assert_cast(*type_ptr).get_nested_type(); + status = _add_one_column(column_ptr, result, + nested_type); } else { - status = _add_one_column(column_ptr, result); + status = _add_one_column(column_ptr, result, + type_ptr); + } + break; + } + case TYPE_DECIMAL32: { + if (type_ptr->is_nullable()) { + auto& nested_type = + assert_cast(*type_ptr).get_nested_type(); + status = _add_one_column(column_ptr, result, + nested_type); + } else { + status = _add_one_column(column_ptr, result, + type_ptr); + } + break; + } + case TYPE_DECIMAL64: { + if (type_ptr->is_nullable()) { + auto& nested_type = + assert_cast(*type_ptr).get_nested_type(); + status = _add_one_column(column_ptr, result, + nested_type); + } else { + status = _add_one_column(column_ptr, result, + type_ptr); + } + break; + } + case TYPE_DECIMAL128: { + if (type_ptr->is_nullable()) { + auto& nested_type = + assert_cast(*type_ptr).get_nested_type(); + status = _add_one_column(column_ptr, result, + nested_type); + } else { + status = _add_one_column(column_ptr, result, + type_ptr); } break; } diff --git a/be/src/vec/sink/vmysql_table_writer.cpp b/be/src/vec/sink/vmysql_table_writer.cpp index 6c7d33830e..f71a7a2329 100644 --- a/be/src/vec/sink/vmysql_table_writer.cpp +++ b/be/src/vec/sink/vmysql_table_writer.cpp @@ -174,6 +174,13 @@ Status VMysqlTableWriter::insert_row(vectorized::Block& block, size_t row) { fmt::format_to(_insert_stmt_buffer, "{}", value.to_string()); break; } + case TYPE_DECIMAL32: + case TYPE_DECIMAL64: + case TYPE_DECIMAL128: { + auto val = type_ptr->to_string(*column, row); + fmt::format_to(_insert_stmt_buffer, "{}", val); + break; + } case TYPE_DATE: case TYPE_DATETIME: { int64_t int_val = assert_cast(*column).get_data()[row]; diff --git a/be/test/exprs/runtime_filter_test.cpp b/be/test/exprs/runtime_filter_test.cpp index 561845106e..6e69eb1d37 100644 --- a/be/test/exprs/runtime_filter_test.cpp +++ b/be/test/exprs/runtime_filter_test.cpp @@ -31,7 +31,7 @@ #include "runtime/runtime_state.h" namespace doris { -TTypeDesc create_type_desc(PrimitiveType type); +TTypeDesc create_type_desc(PrimitiveType type, int precision, int scale); class RuntimeFilterTest : public testing::Test { public: diff --git a/be/test/olap/rowset/segment_v2/column_reader_writer_test.cpp b/be/test/olap/rowset/segment_v2/column_reader_writer_test.cpp index 62af770e09..85b77325e4 100644 --- a/be/test/olap/rowset/segment_v2/column_reader_writer_test.cpp +++ b/be/test/olap/rowset/segment_v2/column_reader_writer_test.cpp @@ -463,10 +463,10 @@ void test_read_default_value(string value, void* result) { // read and check { TabletColumn tablet_column = create_with_default_value(value); - DefaultValueColumnIterator iter(tablet_column.has_default_value(), - tablet_column.default_value(), tablet_column.is_nullable(), - create_static_type_info_ptr(scalar_type_info), - tablet_column.length()); + DefaultValueColumnIterator iter( + tablet_column.has_default_value(), tablet_column.default_value(), + tablet_column.is_nullable(), create_static_type_info_ptr(scalar_type_info), + tablet_column.length(), tablet_column.precision(), tablet_column.frac()); ColumnIteratorOptions iter_opts; auto st = iter.init(iter_opts); EXPECT_TRUE(st.ok()); @@ -575,10 +575,10 @@ void test_v_read_default_value(string value, void* result) { // read and check { TabletColumn tablet_column = create_with_default_value(value); - DefaultValueColumnIterator iter(tablet_column.has_default_value(), - tablet_column.default_value(), tablet_column.is_nullable(), - create_static_type_info_ptr(scalar_type_info), - tablet_column.length()); + DefaultValueColumnIterator iter( + tablet_column.has_default_value(), tablet_column.default_value(), + tablet_column.is_nullable(), create_static_type_info_ptr(scalar_type_info), + tablet_column.length(), tablet_column.precision(), tablet_column.frac()); ColumnIteratorOptions iter_opts; auto st = iter.init(iter_opts); EXPECT_TRUE(st.ok()); diff --git a/fe/fe-core/src/main/cup/sql_parser.cup b/fe/fe-core/src/main/cup/sql_parser.cup index 201603af1d..ef38bfc125 100644 --- a/fe/fe-core/src/main/cup/sql_parser.cup +++ b/fe/fe-core/src/main/cup/sql_parser.cup @@ -4724,15 +4724,15 @@ type ::= | KW_CHAR {: RESULT = ScalarType.createCharType(-1); :} | KW_DECIMAL LPAREN INTEGER_LITERAL:precision RPAREN - {: RESULT = ScalarType.createDecimalV2Type(precision.intValue()); :} + {: RESULT = ScalarType.createDecimalType(precision.intValue()); :} | KW_DECIMAL LPAREN INTEGER_LITERAL:precision COMMA INTEGER_LITERAL:scale RPAREN - {: RESULT = ScalarType.createDecimalV2Type(precision.intValue(), scale.intValue()); :} + {: RESULT = ScalarType.createDecimalType(precision.intValue(), scale.intValue()); :} | KW_DECIMAL - {: RESULT = ScalarType.createDecimalV2Type(); :} + {: RESULT = ScalarType.createDecimalType(); :} | KW_DECIMAL LPAREN ident_or_text:precision RPAREN - {: RESULT = ScalarType.createDecimalV2Type(precision); :} + {: RESULT = ScalarType.createDecimalType(precision); :} | KW_DECIMAL LPAREN ident_or_text:precision COMMA ident_or_text:scale RPAREN - {: RESULT = ScalarType.createDecimalV2Type(precision, scale); :} + {: RESULT = ScalarType.createDecimalType(precision, scale); :} | KW_HLL {: ScalarType type = ScalarType.createHllType(); type.setAssignedStrLenInColDefinition(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfoBase.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfoBase.java index 69d833e86f..d9366b8d80 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfoBase.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfoBase.java @@ -171,7 +171,7 @@ public abstract class AggregateInfoBase { if (!intermediateType.isWildcardDecimal()) { slotDesc.setType(intermediateType); } else { - Preconditions.checkState(expr.getType().isDecimalV2()); + Preconditions.checkState(expr.getType().isDecimalV2() || expr.getType().isDecimalV3()); } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java index 4eb3b84879..a8abf3e97e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java @@ -24,6 +24,7 @@ import org.apache.doris.catalog.Function; import org.apache.doris.catalog.FunctionSet; import org.apache.doris.catalog.PrimitiveType; import org.apache.doris.catalog.ScalarFunction; +import org.apache.doris.catalog.ScalarType; import org.apache.doris.catalog.Type; import org.apache.doris.common.AnalysisException; import org.apache.doris.common.util.VectorizedUtil; @@ -117,6 +118,18 @@ public class ArithmeticExpr extends Expr { Operator.DIVIDE.getName(), Lists.newArrayList(Type.DECIMALV2, Type.DECIMALV2), Type.DECIMALV2, Function.NullableMode.ALWAYS_NULLABLE)); + functionSet.addBuiltin(ScalarFunction.createBuiltinOperator( + Operator.DIVIDE.getName(), + Lists.newArrayList(Type.DECIMAL32, Type.DECIMAL32), + Type.DECIMAL32, Function.NullableMode.ALWAYS_NULLABLE)); + functionSet.addBuiltin(ScalarFunction.createBuiltinOperator( + Operator.DIVIDE.getName(), + Lists.newArrayList(Type.DECIMAL64, Type.DECIMAL64), + Type.DECIMAL64, Function.NullableMode.ALWAYS_NULLABLE)); + functionSet.addBuiltin(ScalarFunction.createBuiltinOperator( + Operator.DIVIDE.getName(), + Lists.newArrayList(Type.DECIMAL128, Type.DECIMAL128), + Type.DECIMAL128, Function.NullableMode.ALWAYS_NULLABLE)); // MOD(), FACTORIAL(), BITAND(), BITOR(), BITXOR(), and BITNOT() are registered as // builtins, see palo_functions.py @@ -152,6 +165,18 @@ public class ArithmeticExpr extends Expr { Operator.DIVIDE.getName(), Lists.newArrayList(Type.DECIMALV2, Type.DECIMALV2), Type.DECIMALV2, Function.NullableMode.ALWAYS_NULLABLE)); + functionSet.addBuiltin(ScalarFunction.createVecBuiltinOperator( + Operator.DIVIDE.getName(), + Lists.newArrayList(Type.DECIMAL32, Type.DECIMAL32), + Type.DECIMAL32, Function.NullableMode.ALWAYS_NULLABLE)); + functionSet.addBuiltin(ScalarFunction.createVecBuiltinOperator( + Operator.DIVIDE.getName(), + Lists.newArrayList(Type.DECIMAL64, Type.DECIMAL64), + Type.DECIMAL64, Function.NullableMode.ALWAYS_NULLABLE)); + functionSet.addBuiltin(ScalarFunction.createVecBuiltinOperator( + Operator.DIVIDE.getName(), + Lists.newArrayList(Type.DECIMAL128, Type.DECIMAL128), + Type.DECIMAL128, Function.NullableMode.ALWAYS_NULLABLE)); functionSet.addBuiltin(ScalarFunction.createVecBuiltinOperator( Operator.MOD.getName(), @@ -165,6 +190,18 @@ public class ArithmeticExpr extends Expr { Operator.MOD.getName(), Lists.newArrayList(Type.DECIMALV2, Type.DECIMALV2), Type.DECIMALV2, Function.NullableMode.ALWAYS_NULLABLE)); + functionSet.addBuiltin(ScalarFunction.createVecBuiltinOperator( + Operator.MOD.getName(), + Lists.newArrayList(Type.DECIMAL32, Type.DECIMAL32), + Type.DECIMAL32, Function.NullableMode.ALWAYS_NULLABLE)); + functionSet.addBuiltin(ScalarFunction.createVecBuiltinOperator( + Operator.MOD.getName(), + Lists.newArrayList(Type.DECIMAL64, Type.DECIMAL64), + Type.DECIMAL64, Function.NullableMode.ALWAYS_NULLABLE)); + functionSet.addBuiltin(ScalarFunction.createVecBuiltinOperator( + Operator.MOD.getName(), + Lists.newArrayList(Type.DECIMAL128, Type.DECIMAL128), + Type.DECIMAL128, Function.NullableMode.ALWAYS_NULLABLE)); for (int i = 0; i < Type.getIntegerTypes().size(); i++) { Type t1 = Type.getIntegerTypes().get(i); @@ -236,7 +273,7 @@ public class ArithmeticExpr extends Expr { @Override protected void toThrift(TExprNode msg) { msg.node_type = TExprNodeType.ARITHMETIC_EXPR; - if (!type.isDecimalV2()) { + if (!(type.isDecimalV2() && type.isDecimalV3())) { msg.setOpcode(op.getOpcode()); msg.setOutputColumn(outputColumn); } @@ -272,6 +309,12 @@ public class ArithmeticExpr extends Expr { return Type.DOUBLE; } else if (pt1 == PrimitiveType.DECIMALV2 || pt2 == PrimitiveType.DECIMALV2) { return Type.DECIMALV2; + } else if (pt1 == PrimitiveType.DECIMAL32 || pt2 == PrimitiveType.DECIMAL32) { + return Type.DECIMAL32; + } else if (pt1 == PrimitiveType.DECIMAL64 || pt2 == PrimitiveType.DECIMAL64) { + return Type.DECIMAL64; + } else if (pt1 == PrimitiveType.DECIMAL128 || pt2 == PrimitiveType.DECIMAL128) { + return Type.DECIMAL128; } else if (pt1 == PrimitiveType.LARGEINT || pt2 == PrimitiveType.LARGEINT) { return Type.LARGEINT; } else { @@ -309,6 +352,133 @@ public class ArithmeticExpr extends Expr { } } + private void analyzeNoneDecimalOp(Type t1, Type t2) throws AnalysisException { + Type commonType; + switch (op) { + case MULTIPLY: + case ADD: + case SUBTRACT: + if (t1.isDecimalV2() || t2.isDecimalV2()) { + castBinaryOp(findCommonType(t1, t2)); + } + if (isConstant()) { + castUpperInteger(t1, t2); + } + break; + case MOD: + if (t1.isDecimalV2() || t2.isDecimalV2()) { + castBinaryOp(findCommonType(t1, t2)); + } else if ((t1.isFloatingPointType() || t2.isFloatingPointType()) && !t1.equals(t2)) { + castBinaryOp(Type.DOUBLE); + } + break; + case INT_DIVIDE: + if (!t1.isFixedPointType() || !t2.isFloatingPointType()) { + castBinaryOp(Type.BIGINT); + } + break; + case DIVIDE: + t1 = getChild(0).getType().getNumResultType(); + t2 = getChild(1).getType().getNumResultType(); + commonType = findCommonType(t1, t2); + if (commonType.getPrimitiveType() == PrimitiveType.BIGINT + || commonType.getPrimitiveType() == PrimitiveType.LARGEINT) { + commonType = Type.DOUBLE; + } + castBinaryOp(commonType); + break; + case BITAND: + case BITOR: + case BITXOR: + if (t1 == Type.BOOLEAN && t2 == Type.BOOLEAN) { + t1 = Type.TINYINT; + t2 = Type.TINYINT; + } + commonType = Type.getAssignmentCompatibleType(t1, t2, false); + if (commonType.getPrimitiveType().ordinal() > PrimitiveType.LARGEINT.ordinal()) { + commonType = Type.BIGINT; + } + type = castBinaryOp(commonType); + break; + default: + Preconditions.checkState(false, + "Unknown arithmetic operation " + op.toString() + " in: " + this.toSql()); + break; + } + } + + /** + * Convert integer type to decimal type. + */ + public static Type convertIntToDecimalV3Type(Type type) throws AnalysisException { + if (type.isLargeIntType()) { + return ScalarType.createDecimalType(ScalarType.MAX_DECIMAL128_PRECISION, 0); + } else if (type.isBigIntType()) { + return ScalarType.createDecimalType(ScalarType.MAX_DECIMAL64_PRECISION, 0); + } else if (type.isInteger32Type()) { + return ScalarType.createDecimalType(ScalarType.MAX_DECIMAL32_PRECISION, 0); + } else { + Preconditions.checkState(false, + "Implicit converting to decimal for arithmetic operations only support integer"); + return Type.INVALID; + } + } + + private void analyzeDecimalV3Op(Type t1, Type t2) throws AnalysisException { + Type t1TargetType = t1; + Type t2TargetType = t2; + switch (op) { + case MULTIPLY: + case ADD: + case SUBTRACT: + case MOD: + case DIVIDE: + if (t1.isFloatingPointType() || t2.isFloatingPointType()) { + castBinaryOp(type.DOUBLE); + break; + } + if (t1.isFixedPointType()) { + t1TargetType = convertIntToDecimalV3Type(t1); + castChild(t1TargetType, 0); + } + if (t2.isFixedPointType()) { + t2TargetType = convertIntToDecimalV3Type(t2); + castChild(t2TargetType, 1); + } + final int t1Precision = ((ScalarType) t1TargetType).getScalarPrecision(); + final int t2Precision = ((ScalarType) t2TargetType).getScalarPrecision(); + final int t1Scale = ((ScalarType) t1TargetType).getScalarScale(); + final int t2Scale = ((ScalarType) t2TargetType).getScalarScale(); + final int precision = Math.max(t1Precision, t2Precision); + int scale = Math.max(t1Scale, t2Scale); + if (op == Operator.MULTIPLY) { + scale = t1Scale + t2Scale; + } + if (op == Operator.DIVIDE) { + scale = t1Scale; + } + type = ScalarType.createWiderDecimalV3Type(precision, scale); + break; + case INT_DIVIDE: + if (!t1.isFixedPointType() || !t2.isFloatingPointType()) { + castBinaryOp(Type.BIGINT); + } + break; + case BITAND: + case BITOR: + case BITXOR: + type = castBinaryOp(Type.BIGINT); + break; + case BITNOT: + case FACTORIAL: + break; + default: + Preconditions.checkState(false, + "Unknown arithmetic operation " + op.toString() + " in: " + this.toSql()); + break; + } + } + @Override public void analyzeImpl(Analyzer analyzer) throws AnalysisException { if (VectorizedUtil.isVectorized()) { @@ -335,7 +505,6 @@ public class ArithmeticExpr extends Expr { Type t1 = getChild(0).getType(); Type t2 = getChild(1).getType(); - Type commonType; // Support null operation if (t1.isNull() || t2.isNull()) { @@ -354,63 +523,19 @@ public class ArithmeticExpr extends Expr { t2 = t2.getNumResultType(); } - switch (op) { - case MULTIPLY: - case ADD: - case SUBTRACT: - if (t1.isDecimalV2() || t2.isDecimalV2()) { - castBinaryOp(findCommonType(t1, t2)); - } - if (isConstant()) { - castUpperInteger(t1, t2); - } - break; - case MOD: - if (t1.isDecimalV2() || t2.isDecimalV2()) { - castBinaryOp(findCommonType(t1, t2)); - } else if ((t1.isFloatingPointType() || t2.isFloatingPointType()) && !t1.equals(t2)) { - castBinaryOp(Type.DOUBLE); - } - break; - case INT_DIVIDE: - if (!t1.isFixedPointType() || !t2.isFloatingPointType()) { - castBinaryOp(Type.BIGINT); - } - break; - case DIVIDE: - t1 = getChild(0).getType().getNumResultType(); - t2 = getChild(1).getType().getNumResultType(); - commonType = findCommonType(t1, t2); - if (commonType.getPrimitiveType() == PrimitiveType.BIGINT - || commonType.getPrimitiveType() == PrimitiveType.LARGEINT) { - commonType = Type.DOUBLE; - } - castBinaryOp(commonType); - break; - case BITAND: - case BITOR: - case BITXOR: - if (t1 == Type.BOOLEAN && t2 == Type.BOOLEAN) { - t1 = Type.TINYINT; - t2 = Type.TINYINT; - } - commonType = Type.getAssignmentCompatibleType(t1, t2, false); - if (commonType.getPrimitiveType().ordinal() > PrimitiveType.LARGEINT.ordinal()) { - commonType = Type.BIGINT; - } - type = castBinaryOp(commonType); - break; - default: - Preconditions.checkState(false, - "Unknown arithmetic operation " + op.toString() + " in: " + this.toSql()); - break; + if (t1.isDecimalV3() || t2.isDecimalV3()) { + analyzeDecimalV3Op(t1, t2); + } else { + analyzeNoneDecimalOp(t1, t2); } fn = getBuiltinFunction(op.name, collectChildReturnTypes(), Function.CompareMode.IS_IDENTICAL); if (fn == null) { Preconditions.checkState(false, String.format( "No match for vec function '%s' with operand types %s and %s", toSql(), t1, t2)); } - type = fn.getReturnType(); + if (!type.isValid()) { + type = fn.getReturnType(); + } } else { // bitnot is the only unary op, deal with it here if (op == Operator.BITNOT) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/BinaryPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/BinaryPredicate.java index d43143e3db..e11ebe4d4b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/BinaryPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/BinaryPredicate.java @@ -336,6 +336,9 @@ public class BinaryPredicate extends Predicate implements Writable { if (t1 == PrimitiveType.BIGINT && t2 == PrimitiveType.BIGINT) { return Type.getAssignmentCompatibleType(getChild(0).getType(), getChild(1).getType(), false); } + if (t1.isDecimalV3Type() || t2.isDecimalV3Type()) { + return Type.getAssignmentCompatibleType(getChild(0).getType(), getChild(1).getType(), false); + } if ((t1 == PrimitiveType.BIGINT || t1 == PrimitiveType.DECIMALV2) && (t2 == PrimitiveType.BIGINT || t2 == PrimitiveType.DECIMALV2)) { return Type.DECIMALV2; diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/BuiltinAggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/BuiltinAggregateFunction.java index 13a58967e6..8bd6c50d42 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/BuiltinAggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/BuiltinAggregateFunction.java @@ -66,8 +66,8 @@ public class BuiltinAggregateFunction extends Function { } @Override - public TFunction toThrift() { - TFunction fn = super.toThrift(); + public TFunction toThrift(Type realReturnType, Type[] realArgTypes) { + TFunction fn = super.toThrift(realReturnType, realArgTypes); // TODO: for now, just put the op_ enum as the id. if (op == BuiltinAggregateFunction.Operator.FIRST_VALUE_REWRITE) { fn.setId(0); diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java index e4f8bc1ea0..1ada8c8e4c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java @@ -138,7 +138,8 @@ public class CastExpr extends Expr { private static boolean disableRegisterCastingFunction(Type fromType, Type toType) { // Disable casting from boolean to decimal or datetime or date - if (fromType.isBoolean() && (toType.equals(Type.DECIMALV2) || toType.isDateType())) { + if (fromType.isBoolean() && (toType.equals(Type.DECIMALV2) || toType.isDecimalV3() + || toType.isDateType())) { return true; } @@ -159,8 +160,8 @@ public class CastExpr extends Expr { if (toType.isNull() || disableRegisterCastingFunction(fromType, toType)) { continue; } - String beClass = toType.isDecimalV2() - || fromType.isDecimalV2() ? "DecimalV2Operators" : "CastFunctions"; + String beClass = toType.isDecimalV2() || fromType.isDecimalV2() + ? "DecimalV2Operators" : "CastFunctions"; if (fromType.isTime()) { beClass = "TimeOperators"; } @@ -264,8 +265,14 @@ public class CastExpr extends Expr { // this cast may result in loss of precision, but the user requested it if (childType.matchesType(type)) { - noOp = true; - return; + // For types which has precision and scale, we also need to check quality between precisions and scales + if (!PrimitiveType.typeWithPrecision.contains( + type.getPrimitiveType()) || ((((ScalarType) type).decimalPrecision() + == ((ScalarType) childType).decimalPrecision()) && (((ScalarType) type).decimalScale() + == ((ScalarType) childType).decimalScale()))) { + noOp = true; + return; + } } // select stmt will make BE coredump when its castExpr is like cast(int as array<>), // it is necessary to check if it is castable before creating fn. @@ -387,7 +394,7 @@ public class CastExpr extends Expr { return new IntLiteral(value.getLongValue(), type); } else if (type.isLargeIntType()) { return new LargeIntLiteral(value.getStringValue()); - } else if (type.isDecimalV2()) { + } else if (type.isDecimalV2() || type.isDecimalV3()) { return new DecimalLiteral(value.getStringValue()); } else if (type.isFloatingPointType()) { @@ -456,6 +463,9 @@ public class CastExpr extends Expr { ScalarType newTargetType = null; switch (primitiveType) { case DECIMALV2: + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: // normal decimal if (targetType.getPrecision() != 0) { newTargetType = targetType; @@ -465,7 +475,7 @@ public class CastExpr extends Expr { int scale = getDigital(targetType.getScalarScaleStr(), parameters, inputParamsExprs); if (precision != -1 && scale != -1) { newTargetType = ScalarType.createType(primitiveType, 0, precision, scale); - } else if (precision != -1 && scale == -1) { + } else if (precision != -1) { newTargetType = ScalarType.createType(primitiveType, 0, precision, ScalarType.DEFAULT_SCALE); } break; diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/ColumnDef.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/ColumnDef.java index 75b3290030..afae03e99f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/ColumnDef.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/ColumnDef.java @@ -334,6 +334,9 @@ public class ColumnDef { new FloatLiteral(defaultValue); break; case DECIMALV2: + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: DecimalLiteral decimalLiteral = new DecimalLiteral(defaultValue); decimalLiteral.checkPrecisionAndScale(scalarType.getScalarPrecision(), scalarType.getScalarScale()); break; diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java index d5f7253d6c..94857ca13a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java @@ -528,6 +528,9 @@ public class CreateFunctionStmt extends DdlStmt { .put(PrimitiveType.DATETIMEV2, Sets.newHashSet(LocalDateTime.class)) .put(PrimitiveType.LARGEINT, Sets.newHashSet(BigInteger.class)) .put(PrimitiveType.DECIMALV2, Sets.newHashSet(BigDecimal.class)) + .put(PrimitiveType.DECIMAL32, Sets.newHashSet(BigDecimal.class)) + .put(PrimitiveType.DECIMAL64, Sets.newHashSet(BigDecimal.class)) + .put(PrimitiveType.DECIMAL128, Sets.newHashSet(BigDecimal.class)) .build(); private void checkUdfType(Class clazz, Method method, Type expType, Class pType, String pname) @@ -633,11 +636,24 @@ public class CreateFunctionStmt extends DdlStmt { typeBuilder.setId(Types.PGenericType.TypeId.DATETIMEV2); break; case DECIMALV2: + case DECIMAL128: typeBuilder.setId(Types.PGenericType.TypeId.DECIMAL128) .getDecimalTypeBuilder() .setPrecision(((ScalarType) arg).getScalarPrecision()) .setScale(((ScalarType) arg).getScalarScale()); break; + case DECIMAL32: + typeBuilder.setId(Types.PGenericType.TypeId.DECIMAL32) + .getDecimalTypeBuilder() + .setPrecision(((ScalarType) arg).getScalarPrecision()) + .setScale(((ScalarType) arg).getScalarScale()); + break; + case DECIMAL64: + typeBuilder.setId(Types.PGenericType.TypeId.DECIMAL64) + .getDecimalTypeBuilder() + .setPrecision(((ScalarType) arg).getScalarPrecision()) + .setScale(((ScalarType) arg).getScalarScale()); + break; case LARGEINT: typeBuilder.setId(Types.PGenericType.TypeId.INT128); break; diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java index 3e5bf9abc7..1c013cf5f1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java @@ -18,6 +18,7 @@ package org.apache.doris.analysis; import org.apache.doris.catalog.PrimitiveType; +import org.apache.doris.catalog.ScalarType; import org.apache.doris.catalog.Type; import org.apache.doris.common.AnalysisException; import org.apache.doris.common.NotImplementedException; @@ -72,9 +73,35 @@ public class DecimalLiteral extends LiteralExpr { return new DecimalLiteral(this); } + /** + * Get precision and scale of java BigDecimal. + * The precision is the number of digits in the unscaled value for BigDecimal. + * The unscaled value of BigDecimal computes this * 10^this.scale(). + * If zero or positive, the scale is the number of digits to the right of the decimal point. + * If negative, the unscaled value of the number is multiplied by ten to the power of the negation of the scale. + * There are two scenarios that do not meet the limit: 0 < P and 0 <= S <= P + * case1: S >= 0 and S > P. i.e. BigDecimal(0.01234), precision = 4, scale = 5 + * case2: S < 0. i.e. BigDecimal(2000), precision = 1, scale = -3 + */ + public static int getBigDecimalPrecision(BigDecimal decimal) { + int scale = decimal.scale(); + int precision = decimal.precision(); + if (scale < 0) { + return Math.abs(scale) + precision; + } else { + return Math.max(scale, precision); + } + } + + public static int getBigDecimalScale(BigDecimal decimal) { + return Math.max(0, decimal.scale()); + } + private void init(BigDecimal value) { this.value = value; - type = Type.DECIMALV2; + int precision = getBigDecimalPrecision(this.value); + int scale = getBigDecimalScale(this.value); + type = ScalarType.createDecimalType(precision, scale); } public BigDecimal getValue() { @@ -132,6 +159,9 @@ public class DecimalLiteral extends LiteralExpr { buffer.putLong(value.longValue()); break; case DECIMALV2: + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: buffer = ByteBuffer.allocate(12); buffer.order(ByteOrder.LITTLE_ENDIAN); @@ -232,16 +262,32 @@ public class DecimalLiteral extends LiteralExpr { } public void roundCeiling() { - value = value.setScale(0, RoundingMode.CEILING); + roundCeiling(0); } public void roundFloor() { - value = value.setScale(0, RoundingMode.FLOOR); + roundFloor(0); + } + + public void roundCeiling(int newScale) { + value = value.setScale(newScale, RoundingMode.CEILING); + type = ScalarType.createDecimalType(((ScalarType) type) + .getPrimitiveType(), ((ScalarType) type).getScalarPrecision(), newScale); + } + + public void roundFloor(int newScale) { + value = value.setScale(newScale, RoundingMode.FLOOR); + type = ScalarType.createDecimalType(((ScalarType) type) + .getPrimitiveType(), ((ScalarType) type).getScalarPrecision(), newScale); } @Override protected Expr uncheckedCastTo(Type targetType) throws AnalysisException { - if (targetType.isDecimalV2()) { + if (targetType.isDecimalV2() && type.isDecimalV2()) { + return this; + } else if (targetType.isDecimalV3() && type.isDecimalV3() + && (((ScalarType) targetType).decimalPrecision() == value.precision()) + && (((ScalarType) targetType).decimalScale() == value.precision())) { return this; } else if (targetType.isFloatingPointType()) { return new FloatLiteral(value.doubleValue(), targetType); diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java index 383a75ba31..3e0b479f36 100755 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java @@ -934,7 +934,7 @@ public abstract class Expr extends TreeNode implements ParseNode, Cloneabl msg.type = type.toThrift(); msg.num_children = children.size(); if (fn != null) { - msg.setFn(fn.toThrift()); + msg.setFn(fn.toThrift(type, collectChildReturnTypes())); if (fn.hasVarArgs()) { msg.setVarargStartIdx(fn.getNumArgs() - 1); } @@ -1295,7 +1295,14 @@ public abstract class Expr extends TreeNode implements ParseNode, Cloneabl public Expr checkTypeCompatibility(Type targetType) throws AnalysisException { if (targetType.getPrimitiveType() != PrimitiveType.ARRAY && targetType.getPrimitiveType() == type.getPrimitiveType()) { - return this; + if (targetType.isDecimalV2() && type.isDecimalV2()) { + return this; + } else if (!PrimitiveType.typeWithPrecision.contains(type.getPrimitiveType())) { + return this; + } else if (((ScalarType) targetType).decimalScale() == ((ScalarType) type).decimalScale() + && ((ScalarType) targetType).decimalPrecision() == ((ScalarType) type).decimalPrecision()) { + return this; + } } // bitmap must match exactly if (targetType.getPrimitiveType() == PrimitiveType.BITMAP) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/ExpressionFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/ExpressionFunctions.java index 4325f1168b..670abc3f61 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/ExpressionFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/ExpressionFunctions.java @@ -233,7 +233,7 @@ public enum ExpressionFunctions { exprs = new IntLiteral[args.size()]; } else if (argType.isDateType()) { exprs = new DateLiteral[args.size()]; - } else if (argType.isDecimalV2()) { + } else if (argType.isDecimalV2() || argType.isDecimalV3()) { exprs = new DecimalLiteral[args.size()]; } else if (argType.isFloatingPointType()) { exprs = new FloatLiteral[args.size()]; diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FloatLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FloatLiteral.java index 37b405390f..4e00a8ea25 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FloatLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FloatLiteral.java @@ -173,7 +173,7 @@ public class FloatLiteral extends LiteralExpr { @Override protected Expr uncheckedCastTo(Type targetType) throws AnalysisException { - if (!(targetType.isFloatingPointType() || targetType.isDecimalV2())) { + if (!(targetType.isFloatingPointType() || targetType.isDecimalV2() || targetType.isDecimalV3())) { return super.uncheckedCastTo(targetType); } if (targetType.isFloatingPointType()) { @@ -183,7 +183,7 @@ public class FloatLiteral extends LiteralExpr { return floatLiteral; } return this; - } else if (targetType.isDecimalV2()) { + } else if (targetType.isDecimalV2() || targetType.isDecimalV3()) { // the double constructor does an exact translation, use valueOf() instead. return new DecimalLiteral(BigDecimal.valueOf(value)); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index 1ffd2060db..2505a2501d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -63,7 +63,28 @@ import java.util.Set; // TODO: for aggregations, we need to unify the code paths for builtins and UDAs. public class FunctionCallExpr extends Expr { + private static final ImmutableSet STDDEV_FUNCTION_SET = + new ImmutableSortedSet.Builder(String.CASE_INSENSITIVE_ORDER) + .add("stddev").add("stddev_val").add("stddev_samp").add("stddev_pop") + .add("variance").add("variance_pop").add("variance_pop").add("var_samp").add("var_pop").build(); + private static final ImmutableSet DECIMAL_SAME_TYPE_SET = + new ImmutableSortedSet.Builder(String.CASE_INSENSITIVE_ORDER) + .add("min").add("max").add("lead").add("lag") + .add("first_value").add("last_value").add("abs") + .add("positive").add("negative").build(); + private static final ImmutableSet DECIMAL_WIDER_TYPE_SET = + new ImmutableSortedSet.Builder(String.CASE_INSENSITIVE_ORDER) + .add("sum").add("avg").add("multi_distinct_sum").build(); + private static final ImmutableSet DECIMAL_FUNCTION_SET = + new ImmutableSortedSet.Builder<>(String.CASE_INSENSITIVE_ORDER) + .addAll(DECIMAL_SAME_TYPE_SET) + .addAll(DECIMAL_WIDER_TYPE_SET) + .addAll(STDDEV_FUNCTION_SET).build(); + private static final int STDDEV_DECIMAL_SCALE = 9; + private static final String ELEMENT_EXTRACT_FN_NAME = "%element_extract%"; + private static final Logger LOG = LogManager.getLogger(FunctionCallExpr.class); + private FunctionName fnName; // private BuiltinAggregateFunction.Operator aggOp; private FunctionParams fnParams; @@ -81,12 +102,6 @@ public class FunctionCallExpr extends Expr { // resetAnalysisState() which is used during expr substitution. private boolean isMergeAggFn; - private static final ImmutableSet STDDEV_FUNCTION_SET = - new ImmutableSortedSet.Builder(String.CASE_INSENSITIVE_ORDER) - .add("stddev").add("stddev_val").add("stddev_samp") - .add("variance").add("variance_pop").add("variance_pop").add("var_samp").add("var_pop").build(); - private static final String ELEMENT_EXTRACT_FN_NAME = "%element_extract%"; - // use to record the num of json_object parameters private int originChildSize; // Save the functionCallExpr in the original statement @@ -202,7 +217,7 @@ public class FunctionCallExpr extends Expr { sb.append("1"); } else if (type.isFixedPointType()) { sb.append("2"); - } else if (type.isFloatingPointType() || type.isDecimalV2()) { + } else if (type.isFloatingPointType() || type.isDecimalV2() || type.isDecimalV3()) { sb.append("3"); } else if (type.isTime()) { sb.append("4"); @@ -1059,6 +1074,20 @@ public class FunctionCallExpr extends Expr { } else { this.type = fn.getReturnType(); } + + // DECIMAL need to pass precision and scale to be + if (DECIMAL_FUNCTION_SET.contains(fn.getFunctionName().getFunction()) + && (this.type.isDecimalV2() || this.type.isDecimalV3())) { + if (DECIMAL_SAME_TYPE_SET.contains(fnName.getFunction())) { + this.type = argTypes[0]; + } else if (DECIMAL_WIDER_TYPE_SET.contains(fnName.getFunction())) { + this.type = ScalarType.createDecimalType(ScalarType.MAX_DECIMAL128_PRECISION, + ((ScalarType) argTypes[0]).getScalarScale()); + } else if (STDDEV_FUNCTION_SET.contains(fnName.getFunction())) { + // for all stddev function, use decimal(38,9) as computing result + this.type = ScalarType.createDecimalType(ScalarType.MAX_DECIMAL128_PRECISION, STDDEV_DECIMAL_SCALE); + } + } // rewrite return type if is nested type function analyzeNestedFunction(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/IndexDef.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/IndexDef.java index fd67d085c5..fca80403c1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/IndexDef.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/IndexDef.java @@ -133,8 +133,8 @@ public class IndexDef { if (indexType == IndexType.BITMAP) { String indexColName = column.getName(); PrimitiveType colType = column.getDataType(); - if (!(colType.isDateType() || colType.isDecimalV2Type() || colType.isFixedPointType() - || colType.isStringType() || colType == PrimitiveType.BOOLEAN)) { + if (!(colType.isDateType() || colType.isDecimalV2Type() || colType.isDecimalV3Type() + || colType.isFixedPointType() || colType.isStringType() || colType == PrimitiveType.BOOLEAN)) { throw new AnalysisException(colType + " is not supported in bitmap index. " + "invalid column: " + indexColName); } else if ((keysType == KeysType.AGG_KEYS && !column.isKey())) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/LargeIntLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/LargeIntLiteral.java index 33f13c74e1..9b344caa42 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/LargeIntLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/LargeIntLiteral.java @@ -190,7 +190,7 @@ public class LargeIntLiteral extends LiteralExpr { protected Expr uncheckedCastTo(Type targetType) throws AnalysisException { if (targetType.isFloatingPointType()) { return new FloatLiteral(new Double(value.doubleValue()), targetType); - } else if (targetType.isDecimalV2()) { + } else if (targetType.isDecimalV2() || targetType.isDecimalV3()) { return new DecimalLiteral(new BigDecimal(value)); } else if (targetType.isNumericType()) { try { diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/LiteralExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/LiteralExpr.java index 0c41bf3a15..6cce229ad6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/LiteralExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/LiteralExpr.java @@ -69,6 +69,9 @@ public abstract class LiteralExpr extends Expr implements Comparable 9) { - throw new AnalysisException("Scale of decimal must between 0 and 9." - + " Scale was set to: " + scale + "."); + throw new AnalysisException( + "Scale of decimal must between 0 and 9." + " Scale was set to: " + scale + "."); } - // scale < precision - if (scale >= precision) { - throw new AnalysisException("Scale of decimal must be smaller than precision." - + " Scale is " + scale + " and precision is " + precision); + break; + } + case DECIMAL32: { + int decimal32Precision = scalarType.decimalPrecision(); + int decimal32Scale = scalarType.decimalScale(); + if (decimal32Precision < 1 || decimal32Precision > ScalarType.MAX_DECIMAL32_PRECISION) { + throw new AnalysisException("Precision of decimal must between 1 and 9." + + " Precision was set to: " + decimal32Precision + "."); + } + // scale >= 0 + if (decimal32Scale < 0) { + throw new AnalysisException( + "Scale of decimal must not be less than 0." + " Scale was set to: " + decimal32Scale + "."); + } + break; + } + case DECIMAL64: { + int decimal64Precision = scalarType.decimalPrecision(); + int decimal64Scale = scalarType.decimalScale(); + if (decimal64Precision < 1 || decimal64Precision > ScalarType.MAX_DECIMAL64_PRECISION) { + throw new AnalysisException("Precision of decimal64 must between 1 and 18." + + " Precision was set to: " + decimal64Precision + "."); + } + // scale >= 0 + if (decimal64Scale < 0) { + throw new AnalysisException( + "Scale of decimal must not be less than 0." + " Scale was set to: " + decimal64Scale + "."); + } + break; + } + case DECIMAL128: { + int decimal128Precision = scalarType.decimalPrecision(); + int decimal128Scale = scalarType.decimalScale(); + if (decimal128Precision < 1 || decimal128Precision > ScalarType.MAX_DECIMAL128_PRECISION) { + throw new AnalysisException("Precision of decimal128 must between 1 and 38." + + " Precision was set to: " + decimal128Precision + "."); + } + // scale >= 0 + if (decimal128Scale < 0) { + throw new AnalysisException("Scale of decimal must not be less than 0." + " Scale was set to: " + + decimal128Scale + "."); } break; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java index d0b5e28733..72957c0eff 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java @@ -552,8 +552,8 @@ public class AggregateFunction extends Function { } @Override - public TFunction toThrift() { - TFunction fn = super.toThrift(); + public TFunction toThrift(Type realReturnType, Type[] realArgTypes) { + TFunction fn = super.toThrift(realReturnType, realArgTypes); TAggregateFunction aggFn = new TAggregateFunction(); aggFn.setIsAnalyticOnlyFn(isAnalyticFn && !isAggregateFn); aggFn.setUpdateFnSymbol(updateFnSymbol); diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Column.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/Column.java index 25cd47f628..8e9bbe4ac9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Column.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Column.java @@ -463,8 +463,8 @@ public class Column implements Writable { } // now we support convert decimal to varchar type - if (getDataType() == PrimitiveType.DECIMALV2 && (other.getDataType() == PrimitiveType.VARCHAR - || other.getDataType() == PrimitiveType.STRING)) { + if ((getDataType() == PrimitiveType.DECIMALV2 || getDataType().isDecimalV3Type()) + && (other.getDataType() == PrimitiveType.VARCHAR || other.getDataType() == PrimitiveType.STRING)) { return; } } @@ -678,6 +678,9 @@ public class Column implements Writable { sb.append(String.format(typeStringMap.get(dataType), getStrLen())); break; case DECIMALV2: + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: sb.append(String.format(typeStringMap.get(dataType), getPrecision(), getScale())); break; case ARRAY: diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/ColumnType.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/ColumnType.java index ae82bfeaa8..872322ed4d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/ColumnType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/ColumnType.java @@ -109,6 +109,27 @@ public abstract class ColumnType { schemaChangeMatrix[PrimitiveType.DECIMALV2.ordinal()][PrimitiveType.VARCHAR.ordinal()] = true; schemaChangeMatrix[PrimitiveType.DECIMALV2.ordinal()][PrimitiveType.STRING.ordinal()] = true; + schemaChangeMatrix[PrimitiveType.DECIMALV2.ordinal()][PrimitiveType.DECIMAL32.ordinal()] = true; + schemaChangeMatrix[PrimitiveType.DECIMALV2.ordinal()][PrimitiveType.DECIMAL64.ordinal()] = true; + schemaChangeMatrix[PrimitiveType.DECIMALV2.ordinal()][PrimitiveType.DECIMAL128.ordinal()] = true; + + schemaChangeMatrix[PrimitiveType.DECIMAL32.ordinal()][PrimitiveType.VARCHAR.ordinal()] = true; + schemaChangeMatrix[PrimitiveType.DECIMAL32.ordinal()][PrimitiveType.STRING.ordinal()] = true; + schemaChangeMatrix[PrimitiveType.DECIMAL32.ordinal()][PrimitiveType.DECIMALV2.ordinal()] = true; + schemaChangeMatrix[PrimitiveType.DECIMAL32.ordinal()][PrimitiveType.DECIMAL64.ordinal()] = true; + schemaChangeMatrix[PrimitiveType.DECIMAL32.ordinal()][PrimitiveType.DECIMAL128.ordinal()] = true; + + schemaChangeMatrix[PrimitiveType.DECIMAL64.ordinal()][PrimitiveType.VARCHAR.ordinal()] = true; + schemaChangeMatrix[PrimitiveType.DECIMAL64.ordinal()][PrimitiveType.STRING.ordinal()] = true; + schemaChangeMatrix[PrimitiveType.DECIMAL64.ordinal()][PrimitiveType.DECIMAL32.ordinal()] = true; + schemaChangeMatrix[PrimitiveType.DECIMAL64.ordinal()][PrimitiveType.DECIMALV2.ordinal()] = true; + schemaChangeMatrix[PrimitiveType.DECIMAL64.ordinal()][PrimitiveType.DECIMAL128.ordinal()] = true; + + schemaChangeMatrix[PrimitiveType.DECIMAL128.ordinal()][PrimitiveType.VARCHAR.ordinal()] = true; + schemaChangeMatrix[PrimitiveType.DECIMAL128.ordinal()][PrimitiveType.STRING.ordinal()] = true; + schemaChangeMatrix[PrimitiveType.DECIMAL128.ordinal()][PrimitiveType.DECIMAL32.ordinal()] = true; + schemaChangeMatrix[PrimitiveType.DECIMAL128.ordinal()][PrimitiveType.DECIMAL64.ordinal()] = true; + schemaChangeMatrix[PrimitiveType.DECIMAL128.ordinal()][PrimitiveType.DECIMALV2.ordinal()] = true; schemaChangeMatrix[PrimitiveType.DATETIME.ordinal()][PrimitiveType.DATE.ordinal()] = true; schemaChangeMatrix[PrimitiveType.DATE.ordinal()][PrimitiveType.DATETIME.ordinal()] = true; diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java index 93c218d0b4..500e709495 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java @@ -452,7 +452,7 @@ public class Function implements Writable { } } - public TFunction toThrift() { + public TFunction toThrift(Type realReturnType, Type[] realArgTypes) { TFunction fn = new TFunction(); fn.setSignature(signatureString()); fn.setName(name.toThrift()); @@ -460,8 +460,14 @@ public class Function implements Writable { if (location != null) { fn.setHdfsLocation(location.getLocation()); } - fn.setArgTypes(Type.toThrift(argTypes)); - fn.setRetType(getReturnType().toThrift()); + fn.setArgTypes(Type.toThrift(Lists.newArrayList(argTypes), Lists.newArrayList(realArgTypes))); + // For types with different precisions and scales, return type only indicates a type with default + // precision and scale so we need to transform it to the correct type. + if (PrimitiveType.typeWithPrecision.contains(realReturnType.getPrimitiveType())) { + fn.setRetType(realReturnType.toThrift()); + } else { + fn.setRetType(getReturnType().toThrift()); + } fn.setHasVarArgs(hasVarArgs); // TODO: Comment field is missing? // fn.setComment(comment) @@ -512,6 +518,12 @@ public class Function implements Writable { return "datetime_val"; case DECIMALV2: return "decimalv2_val"; + case DECIMAL32: + return "decimal32_val"; + case DECIMAL64: + return "decimal64_val"; + case DECIMAL128: + return "decimal128_val"; default: Preconditions.checkState(false, t.toString()); return ""; @@ -554,6 +566,12 @@ public class Function implements Writable { return "DateTimeVal"; case DECIMALV2: return "DecimalV2Val"; + case DECIMAL32: + return "Decimal32Val"; + case DECIMAL64: + return "Decimal64Val"; + case DECIMAL128: + return "Decimal128Val"; default: Preconditions.checkState(false, t.toString()); return ""; diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java index 13a4d938ea..8a1a4b0c81 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -252,6 +252,9 @@ public class FunctionSet { .put(Type.DOUBLE, Type.DOUBLE) .put(Type.LARGEINT, Type.LARGEINT) .put(Type.DECIMALV2, Type.DECIMALV2) + .put(Type.DECIMAL32, Type.DECIMAL32) + .put(Type.DECIMAL64, Type.DECIMAL64) + .put(Type.DECIMAL128, Type.DECIMAL128) .build(); private static final Map MULTI_DISTINCT_INIT_SYMBOL = @@ -366,6 +369,9 @@ public class FunctionSet { .put(Type.FLOAT, Type.DOUBLE) .put(Type.DOUBLE, Type.DOUBLE) .put(Type.DECIMALV2, Type.DECIMALV2) + .put(Type.DECIMAL32, Type.DECIMAL32) + .put(Type.DECIMAL64, Type.DECIMAL64) + .put(Type.DECIMAL128, Type.DECIMAL128) .build(); private static final Map STDDEV_INIT_SYMBOL = @@ -1521,6 +1527,45 @@ public class FunctionSet { null, prefix + "", false, true, true, true)); + } else if (t.equals(Type.DECIMAL32)) { + // vectorized + addBuiltin(AggregateFunction.createBuiltin("multi_distinct_count", Lists.newArrayList(t), + Type.BIGINT, + Type.DECIMAL32, + prefix + "", + prefix + "", + prefix + "", + prefix + "", + null, + null, + prefix + "", + false, true, true, true)); + } else if (t.equals(Type.DECIMAL64)) { + // vectorized + addBuiltin(AggregateFunction.createBuiltin("multi_distinct_count", Lists.newArrayList(t), + Type.BIGINT, + Type.DECIMAL64, + prefix + "", + prefix + "", + prefix + "", + prefix + "", + null, + null, + prefix + "", + false, true, true, true)); + } else if (t.equals(Type.DECIMAL128)) { + // vectorized + addBuiltin(AggregateFunction.createBuiltin("multi_distinct_count", Lists.newArrayList(t), + Type.BIGINT, + Type.DECIMAL128, + prefix + "", + prefix + "", + prefix + "", + prefix + "", + null, + null, + prefix + "", + false, true, true, true)); } // sum in multi distinct @@ -1573,6 +1618,43 @@ public class FunctionSet { null, prefix + "", false, true, true, true)); + } else if (t.equals(Type.DECIMAL32)) { + // vectorized + addBuiltin(AggregateFunction.createBuiltin("multi_distinct_sum", Lists.newArrayList(t), + MULTI_DISTINCT_SUM_RETURN_TYPE.get(t), + Type.DECIMAL32, + prefix + "", + prefix + "", + prefix + "", + prefix + "", + null, + null, + prefix + "", + false, true, true, true)); + } else if (t.equals(Type.DECIMAL64)) { + addBuiltin(AggregateFunction.createBuiltin("multi_distinct_sum", Lists.newArrayList(t), + MULTI_DISTINCT_SUM_RETURN_TYPE.get(t), + Type.DECIMAL64, + prefix + "", + prefix + "", + prefix + "", + prefix + "", + null, + null, + prefix + "", + false, true, true, true)); + } else if (t.equals(Type.DECIMAL128)) { + addBuiltin(AggregateFunction.createBuiltin("multi_distinct_sum", Lists.newArrayList(t), + MULTI_DISTINCT_SUM_RETURN_TYPE.get(t), + Type.DECIMAL128, + prefix + "", + prefix + "", + prefix + "", + prefix + "", + null, + null, + prefix + "", + false, true, true, true)); } // Min String minMaxSerializeOrFinalize = t.isStringType() ? stringValSerializeOrFinalize : null; @@ -2019,6 +2101,27 @@ public class FunctionSet { null, null, prefix + "10sum_removeIN9doris_udf12DecimalV2ValES3_EEvPNS2_15FunctionContextERKT_PT0_", null, false, true, false, true)); + addBuiltin(AggregateFunction.createBuiltin(name, + Lists.newArrayList(Type.DECIMAL32), Type.DECIMAL32, Type.DECIMAL32, initNull, + prefix + "3sumIN9doris_udf12DecimalV2ValES3_EEvPNS2_15FunctionContextERKT_PT0_", + prefix + "3sumIN9doris_udf12DecimalV2ValES3_EEvPNS2_15FunctionContextERKT_PT0_", + null, null, + prefix + "10sum_removeIN9doris_udf12DecimalV2ValES3_EEvPNS2_15FunctionContextERKT_PT0_", + null, false, true, false, true)); + addBuiltin(AggregateFunction.createBuiltin(name, + Lists.newArrayList(Type.DECIMAL64), Type.DECIMAL64, Type.DECIMAL64, initNull, + prefix + "3sumIN9doris_udf12DecimalV2ValES3_EEvPNS2_15FunctionContextERKT_PT0_", + prefix + "3sumIN9doris_udf12DecimalV2ValES3_EEvPNS2_15FunctionContextERKT_PT0_", + null, null, + prefix + "10sum_removeIN9doris_udf12DecimalV2ValES3_EEvPNS2_15FunctionContextERKT_PT0_", + null, false, true, false, true)); + addBuiltin(AggregateFunction.createBuiltin(name, + Lists.newArrayList(Type.DECIMAL128), Type.DECIMAL128, Type.DECIMAL128, initNull, + prefix + "3sumIN9doris_udf12DecimalV2ValES3_EEvPNS2_15FunctionContextERKT_PT0_", + prefix + "3sumIN9doris_udf12DecimalV2ValES3_EEvPNS2_15FunctionContextERKT_PT0_", + null, null, + prefix + "10sum_removeIN9doris_udf12DecimalV2ValES3_EEvPNS2_15FunctionContextERKT_PT0_", + null, false, true, false, true)); addBuiltin(AggregateFunction.createBuiltin(name, Lists.newArrayList(Type.LARGEINT), Type.LARGEINT, Type.LARGEINT, initNull, prefix + "3sumIN9doris_udf11LargeIntValES3_EEvPNS2_15FunctionContextERKT_PT0_", @@ -2224,7 +2327,8 @@ public class FunctionSet { // collect_list Type[] arraySubTypes = {Type.BOOLEAN, Type.SMALLINT, Type.TINYINT, Type.INT, Type.BIGINT, Type.LARGEINT, - Type.FLOAT, Type.DOUBLE, Type.DATE, Type.DATETIME, Type.DECIMALV2, Type.VARCHAR, Type.STRING}; + Type.FLOAT, Type.DOUBLE, Type.DATE, Type.DATETIME, Type.DECIMALV2, Type.DECIMAL32, Type.DECIMAL64, + Type.DECIMAL128, Type.VARCHAR, Type.STRING}; for (Type t : arraySubTypes) { addBuiltin(AggregateFunction.createBuiltin(COLLECT_LIST, Lists.newArrayList(t), new ArrayType(t), t, "", "", "", "", "", true, false, true, true)); @@ -2318,6 +2422,36 @@ public class FunctionSet { prefix + "20decimalv2_avg_removeEPN9doris_udf15FunctionContextERKNS1_12DecimalV2ValEPNS1_9StringValE", prefix + "22decimalv2_avg_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE", false, true, false, true)); + addBuiltin(AggregateFunction.createBuiltin("avg", + Lists.newArrayList(Type.DECIMAL32), Type.DECIMAL32, Type.DECIMAL32, + prefix + "18decimalv2_avg_initEPN9doris_udf15FunctionContextEPNS1_9StringValE", + prefix + "20decimalv2_avg_updateEPN9doris_udf15FunctionContextERKNS1_12DecimalV2ValEPNS1_9StringValE", + prefix + "19decimalv2_avg_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_", + prefix + "23decimalv2_avg_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + prefix + "23decimalv2_avg_get_valueEPN9doris_udf15FunctionContextERKNS1_9StringValE", + prefix + "20decimalv2_avg_removeEPN9doris_udf15FunctionContextERKNS1_12DecimalV2ValEPNS1_9StringValE", + prefix + "22decimalv2_avg_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + false, true, false, true)); + addBuiltin(AggregateFunction.createBuiltin("avg", + Lists.newArrayList(Type.DECIMAL64), Type.DECIMAL64, Type.DECIMAL64, + prefix + "18decimalv2_avg_initEPN9doris_udf15FunctionContextEPNS1_9StringValE", + prefix + "20decimalv2_avg_updateEPN9doris_udf15FunctionContextERKNS1_12DecimalV2ValEPNS1_9StringValE", + prefix + "19decimalv2_avg_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_", + prefix + "23decimalv2_avg_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + prefix + "23decimalv2_avg_get_valueEPN9doris_udf15FunctionContextERKNS1_9StringValE", + prefix + "20decimalv2_avg_removeEPN9doris_udf15FunctionContextERKNS1_12DecimalV2ValEPNS1_9StringValE", + prefix + "22decimalv2_avg_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + false, true, false, true)); + addBuiltin(AggregateFunction.createBuiltin("avg", + Lists.newArrayList(Type.DECIMAL128), Type.DECIMAL128, Type.DECIMAL128, + prefix + "18decimalv2_avg_initEPN9doris_udf15FunctionContextEPNS1_9StringValE", + prefix + "20decimalv2_avg_updateEPN9doris_udf15FunctionContextERKNS1_12DecimalV2ValEPNS1_9StringValE", + prefix + "19decimalv2_avg_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_", + prefix + "23decimalv2_avg_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + prefix + "23decimalv2_avg_get_valueEPN9doris_udf15FunctionContextERKNS1_9StringValE", + prefix + "20decimalv2_avg_removeEPN9doris_udf15FunctionContextERKNS1_12DecimalV2ValEPNS1_9StringValE", + prefix + "22decimalv2_avg_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + false, true, false, true)); addBuiltin(AggregateFunction.createBuiltin("avg", Lists.newArrayList(Type.DATE), Type.DATE, Type.DATE, prefix + "8avg_initEPN9doris_udf15FunctionContextEPNS1_9StringValE", diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/HiveMetaStoreClientHelper.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/HiveMetaStoreClientHelper.java index 4fabab3742..22974ddea2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/HiveMetaStoreClientHelper.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/HiveMetaStoreClientHelper.java @@ -680,6 +680,9 @@ public class HiveMetaStoreClientHelper { return TypeInfoFactory.floatTypeInfo; case DOUBLE: return TypeInfoFactory.doubleTypeInfo; + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: case DECIMALV2: return TypeInfoFactory.decimalTypeInfo; case DATE: @@ -799,7 +802,7 @@ public class HiveMetaStoreClientHelper { if (match.find()) { scale = Integer.parseInt(match.group(1)); } - return ScalarType.createDecimalV2Type(precision, scale); + return ScalarType.createDecimalType(precision, scale); } // TODO: Handle unsupported types. LOG.warn("Hive type {} may not supported yet, will use STRING instead.", hiveType); diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/PrimitiveType.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/PrimitiveType.java index eef63d23ea..5cf46e53ca 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/PrimitiveType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/PrimitiveType.java @@ -22,6 +22,7 @@ import org.apache.doris.mysql.MysqlColType; import org.apache.doris.thrift.TPrimitiveType; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSetMultimap; import com.google.common.collect.Lists; @@ -50,6 +51,9 @@ public enum PrimitiveType { VARCHAR("VARCHAR", 16, TPrimitiveType.VARCHAR), DECIMALV2("DECIMALV2", 16, TPrimitiveType.DECIMALV2), + DECIMAL32("DECIMAL32", 4, TPrimitiveType.DECIMAL32), + DECIMAL64("DECIMAL64", 8, TPrimitiveType.DECIMAL64), + DECIMAL128("DECIMAL128", 16, TPrimitiveType.DECIMAL128), TIME("TIME", 8, TPrimitiveType.TIME), // these following types are stored as object binary in BE. HLL("HLL", 16, TPrimitiveType.HLL), @@ -76,6 +80,16 @@ public enum PrimitiveType { private static final int STRING_INDEX_LEN = 20; private static final int DECIMAL_INDEX_LEN = 12; + public static ImmutableSet typeWithPrecision; + + static { + ImmutableSet.Builder builder = ImmutableSet.builder(); + builder.add(DECIMAL32); + builder.add(DECIMAL64); + builder.add(DECIMAL128); + typeWithPrecision = builder.build(); + } + private static ImmutableSetMultimap implicitCastMap; static { ImmutableSetMultimap.Builder builder = ImmutableSetMultimap.builder(); @@ -93,6 +107,9 @@ public enum PrimitiveType { builder.put(NULL_TYPE, DATEV2); builder.put(NULL_TYPE, DATETIMEV2); builder.put(NULL_TYPE, DECIMALV2); + builder.put(NULL_TYPE, DECIMAL32); + builder.put(NULL_TYPE, DECIMAL64); + builder.put(NULL_TYPE, DECIMAL128); builder.put(NULL_TYPE, CHAR); builder.put(NULL_TYPE, VARCHAR); builder.put(NULL_TYPE, STRING); @@ -113,6 +130,9 @@ public enum PrimitiveType { builder.put(BOOLEAN, DATEV2); builder.put(BOOLEAN, DATETIMEV2); builder.put(BOOLEAN, DECIMALV2); + builder.put(BOOLEAN, DECIMAL32); + builder.put(BOOLEAN, DECIMAL64); + builder.put(BOOLEAN, DECIMAL128); builder.put(BOOLEAN, VARCHAR); builder.put(BOOLEAN, STRING); // Tinyint @@ -129,6 +149,9 @@ public enum PrimitiveType { builder.put(TINYINT, DATEV2); builder.put(TINYINT, DATETIMEV2); builder.put(TINYINT, DECIMALV2); + builder.put(TINYINT, DECIMAL32); + builder.put(TINYINT, DECIMAL64); + builder.put(TINYINT, DECIMAL128); builder.put(TINYINT, VARCHAR); builder.put(TINYINT, STRING); // Smallint @@ -145,6 +168,9 @@ public enum PrimitiveType { builder.put(SMALLINT, DATEV2); builder.put(SMALLINT, DATETIMEV2); builder.put(SMALLINT, DECIMALV2); + builder.put(SMALLINT, DECIMAL32); + builder.put(SMALLINT, DECIMAL64); + builder.put(SMALLINT, DECIMAL128); builder.put(SMALLINT, VARCHAR); builder.put(SMALLINT, STRING); // Int @@ -161,6 +187,9 @@ public enum PrimitiveType { builder.put(INT, DATEV2); builder.put(INT, DATETIMEV2); builder.put(INT, DECIMALV2); + builder.put(INT, DECIMAL32); + builder.put(INT, DECIMAL64); + builder.put(INT, DECIMAL128); builder.put(INT, VARCHAR); builder.put(INT, STRING); // Bigint @@ -177,6 +206,9 @@ public enum PrimitiveType { builder.put(BIGINT, DATEV2); builder.put(BIGINT, DATETIMEV2); builder.put(BIGINT, DECIMALV2); + builder.put(BIGINT, DECIMAL32); + builder.put(BIGINT, DECIMAL64); + builder.put(BIGINT, DECIMAL128); builder.put(BIGINT, VARCHAR); builder.put(BIGINT, STRING); // Largeint @@ -193,6 +225,9 @@ public enum PrimitiveType { builder.put(LARGEINT, DATEV2); builder.put(LARGEINT, DATETIMEV2); builder.put(LARGEINT, DECIMALV2); + builder.put(LARGEINT, DECIMAL32); + builder.put(LARGEINT, DECIMAL64); + builder.put(LARGEINT, DECIMAL128); builder.put(LARGEINT, VARCHAR); builder.put(LARGEINT, STRING); // Float @@ -209,6 +244,9 @@ public enum PrimitiveType { builder.put(FLOAT, DATEV2); builder.put(FLOAT, DATETIMEV2); builder.put(FLOAT, DECIMALV2); + builder.put(FLOAT, DECIMAL32); + builder.put(FLOAT, DECIMAL64); + builder.put(FLOAT, DECIMAL128); builder.put(FLOAT, VARCHAR); builder.put(FLOAT, STRING); // Double @@ -225,6 +263,9 @@ public enum PrimitiveType { builder.put(DOUBLE, DATEV2); builder.put(DOUBLE, DATETIMEV2); builder.put(DOUBLE, DECIMALV2); + builder.put(DOUBLE, DECIMAL32); + builder.put(DOUBLE, DECIMAL64); + builder.put(DOUBLE, DECIMAL128); builder.put(DOUBLE, VARCHAR); builder.put(DOUBLE, STRING); // Date @@ -241,6 +282,9 @@ public enum PrimitiveType { builder.put(DATE, DATEV2); builder.put(DATE, DATETIMEV2); builder.put(DATE, DECIMALV2); + builder.put(DATE, DECIMAL32); + builder.put(DATE, DECIMAL64); + builder.put(DATE, DECIMAL128); builder.put(DATE, VARCHAR); builder.put(DATE, STRING); // Datetime @@ -257,6 +301,9 @@ public enum PrimitiveType { builder.put(DATETIME, DATEV2); builder.put(DATETIME, DATETIMEV2); builder.put(DATETIME, DECIMALV2); + builder.put(DATETIME, DECIMAL32); + builder.put(DATETIME, DECIMAL64); + builder.put(DATETIME, DECIMAL128); builder.put(DATETIME, VARCHAR); builder.put(DATETIME, STRING); // DateV2 @@ -273,6 +320,9 @@ public enum PrimitiveType { builder.put(DATEV2, DATEV2); builder.put(DATEV2, DATETIMEV2); builder.put(DATEV2, DECIMALV2); + builder.put(DATEV2, DECIMAL32); + builder.put(DATEV2, DECIMAL64); + builder.put(DATEV2, DECIMAL128); builder.put(DATEV2, VARCHAR); builder.put(DATEV2, STRING); // DatetimeV2 @@ -289,6 +339,9 @@ public enum PrimitiveType { builder.put(DATETIMEV2, DATEV2); builder.put(DATETIMEV2, DATETIMEV2); builder.put(DATETIMEV2, DECIMALV2); + builder.put(DATETIMEV2, DECIMAL32); + builder.put(DATETIMEV2, DECIMAL64); + builder.put(DATETIMEV2, DECIMAL128); builder.put(DATETIMEV2, VARCHAR); builder.put(DATETIMEV2, STRING); // Char @@ -306,6 +359,9 @@ public enum PrimitiveType { builder.put(CHAR, DATEV2); builder.put(CHAR, DATETIMEV2); builder.put(CHAR, DECIMALV2); + builder.put(CHAR, DECIMAL32); + builder.put(CHAR, DECIMAL64); + builder.put(CHAR, DECIMAL128); builder.put(CHAR, VARCHAR); builder.put(CHAR, STRING); // Varchar @@ -322,6 +378,9 @@ public enum PrimitiveType { builder.put(VARCHAR, DATEV2); builder.put(VARCHAR, DATETIMEV2); builder.put(VARCHAR, DECIMALV2); + builder.put(VARCHAR, DECIMAL32); + builder.put(VARCHAR, DECIMAL64); + builder.put(VARCHAR, DECIMAL128); builder.put(VARCHAR, VARCHAR); builder.put(VARCHAR, STRING); @@ -339,6 +398,9 @@ public enum PrimitiveType { builder.put(STRING, DATEV2); builder.put(STRING, DATETIMEV2); builder.put(STRING, DECIMALV2); + builder.put(STRING, DECIMAL32); + builder.put(STRING, DECIMAL64); + builder.put(STRING, DECIMAL128); builder.put(STRING, VARCHAR); builder.put(STRING, STRING); @@ -352,9 +414,57 @@ public enum PrimitiveType { builder.put(DECIMALV2, FLOAT); builder.put(DECIMALV2, DOUBLE); builder.put(DECIMALV2, DECIMALV2); + builder.put(DECIMALV2, DECIMAL32); + builder.put(DECIMALV2, DECIMAL64); + builder.put(DECIMALV2, DECIMAL128); builder.put(DECIMALV2, VARCHAR); builder.put(DECIMALV2, STRING); + builder.put(DECIMAL32, BOOLEAN); + builder.put(DECIMAL32, TINYINT); + builder.put(DECIMAL32, SMALLINT); + builder.put(DECIMAL32, INT); + builder.put(DECIMAL32, BIGINT); + builder.put(DECIMAL32, LARGEINT); + builder.put(DECIMAL32, FLOAT); + builder.put(DECIMAL32, DOUBLE); + builder.put(DECIMAL32, DECIMALV2); + builder.put(DECIMAL32, DECIMAL32); + builder.put(DECIMAL32, DECIMAL64); + builder.put(DECIMAL32, DECIMAL128); + builder.put(DECIMAL32, VARCHAR); + builder.put(DECIMAL32, STRING); + + builder.put(DECIMAL64, BOOLEAN); + builder.put(DECIMAL64, TINYINT); + builder.put(DECIMAL64, SMALLINT); + builder.put(DECIMAL64, INT); + builder.put(DECIMAL64, BIGINT); + builder.put(DECIMAL64, LARGEINT); + builder.put(DECIMAL64, FLOAT); + builder.put(DECIMAL64, DOUBLE); + builder.put(DECIMAL64, DECIMALV2); + builder.put(DECIMAL64, DECIMAL32); + builder.put(DECIMAL64, DECIMAL64); + builder.put(DECIMAL64, DECIMAL128); + builder.put(DECIMAL64, VARCHAR); + builder.put(DECIMAL64, STRING); + + builder.put(DECIMAL128, BOOLEAN); + builder.put(DECIMAL128, TINYINT); + builder.put(DECIMAL128, SMALLINT); + builder.put(DECIMAL128, INT); + builder.put(DECIMAL128, BIGINT); + builder.put(DECIMAL128, LARGEINT); + builder.put(DECIMAL128, FLOAT); + builder.put(DECIMAL128, DOUBLE); + builder.put(DECIMAL128, DECIMALV2); + builder.put(DECIMAL128, DECIMAL32); + builder.put(DECIMAL128, DECIMAL64); + builder.put(DECIMAL128, DECIMAL128); + builder.put(DECIMAL128, VARCHAR); + builder.put(DECIMAL128, STRING); + // HLL builder.put(HLL, HLL); @@ -398,6 +508,9 @@ public enum PrimitiveType { numericTypes.add(FLOAT); numericTypes.add(DOUBLE); numericTypes.add(DECIMALV2); + numericTypes.add(DECIMAL32); + numericTypes.add(DECIMAL64); + numericTypes.add(DECIMAL128); supportedTypes = Lists.newArrayList(); supportedTypes.add(NULL_TYPE); @@ -420,6 +533,9 @@ public enum PrimitiveType { supportedTypes.add(DATETIMEV2); supportedTypes.add(TIMEV2); supportedTypes.add(DECIMALV2); + supportedTypes.add(DECIMAL32); + supportedTypes.add(DECIMAL64); + supportedTypes.add(DECIMAL128); supportedTypes.add(BITMAP); supportedTypes.add(ARRAY); supportedTypes.add(MAP); @@ -479,6 +595,9 @@ public enum PrimitiveType { compatibilityMatrix[NULL_TYPE.ordinal()][VARCHAR.ordinal()] = VARCHAR; compatibilityMatrix[NULL_TYPE.ordinal()][STRING.ordinal()] = STRING; compatibilityMatrix[NULL_TYPE.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; + compatibilityMatrix[NULL_TYPE.ordinal()][DECIMAL32.ordinal()] = DECIMAL32; + compatibilityMatrix[NULL_TYPE.ordinal()][DECIMAL64.ordinal()] = DECIMAL64; + compatibilityMatrix[NULL_TYPE.ordinal()][DECIMAL128.ordinal()] = DECIMAL128; compatibilityMatrix[NULL_TYPE.ordinal()][TIME.ordinal()] = TIME; compatibilityMatrix[NULL_TYPE.ordinal()][TIMEV2.ordinal()] = TIMEV2; compatibilityMatrix[NULL_TYPE.ordinal()][BITMAP.ordinal()] = BITMAP; @@ -500,6 +619,9 @@ public enum PrimitiveType { compatibilityMatrix[BOOLEAN.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[BOOLEAN.ordinal()][STRING.ordinal()] = INVALID_TYPE; compatibilityMatrix[BOOLEAN.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; + compatibilityMatrix[BOOLEAN.ordinal()][DECIMAL32.ordinal()] = DECIMAL32; + compatibilityMatrix[BOOLEAN.ordinal()][DECIMAL64.ordinal()] = DECIMAL64; + compatibilityMatrix[BOOLEAN.ordinal()][DECIMAL128.ordinal()] = DECIMAL128; compatibilityMatrix[BOOLEAN.ordinal()][TIME.ordinal()] = TIME; compatibilityMatrix[BOOLEAN.ordinal()][TIMEV2.ordinal()] = TIMEV2; @@ -518,6 +640,9 @@ public enum PrimitiveType { compatibilityMatrix[TINYINT.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[TINYINT.ordinal()][STRING.ordinal()] = INVALID_TYPE; compatibilityMatrix[TINYINT.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; + compatibilityMatrix[TINYINT.ordinal()][DECIMAL32.ordinal()] = DECIMAL32; + compatibilityMatrix[TINYINT.ordinal()][DECIMAL64.ordinal()] = DECIMAL64; + compatibilityMatrix[TINYINT.ordinal()][DECIMAL128.ordinal()] = DECIMAL128; compatibilityMatrix[TINYINT.ordinal()][TIME.ordinal()] = TIME; compatibilityMatrix[TINYINT.ordinal()][TIMEV2.ordinal()] = TIMEV2; @@ -535,6 +660,9 @@ public enum PrimitiveType { compatibilityMatrix[SMALLINT.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[SMALLINT.ordinal()][STRING.ordinal()] = INVALID_TYPE; compatibilityMatrix[SMALLINT.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; + compatibilityMatrix[SMALLINT.ordinal()][DECIMAL32.ordinal()] = DECIMAL32; + compatibilityMatrix[SMALLINT.ordinal()][DECIMAL64.ordinal()] = DECIMAL64; + compatibilityMatrix[SMALLINT.ordinal()][DECIMAL128.ordinal()] = DECIMAL128; compatibilityMatrix[SMALLINT.ordinal()][TIME.ordinal()] = TIME; compatibilityMatrix[SMALLINT.ordinal()][TIMEV2.ordinal()] = TIMEV2; @@ -551,6 +679,9 @@ public enum PrimitiveType { compatibilityMatrix[INT.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[INT.ordinal()][STRING.ordinal()] = INVALID_TYPE; compatibilityMatrix[INT.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; + compatibilityMatrix[INT.ordinal()][DECIMAL32.ordinal()] = DECIMAL32; + compatibilityMatrix[INT.ordinal()][DECIMAL64.ordinal()] = DECIMAL64; + compatibilityMatrix[INT.ordinal()][DECIMAL128.ordinal()] = DECIMAL128; compatibilityMatrix[INT.ordinal()][TIME.ordinal()] = TIME; compatibilityMatrix[INT.ordinal()][TIMEV2.ordinal()] = TIMEV2; @@ -566,6 +697,9 @@ public enum PrimitiveType { compatibilityMatrix[BIGINT.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[BIGINT.ordinal()][STRING.ordinal()] = INVALID_TYPE; compatibilityMatrix[BIGINT.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; + compatibilityMatrix[BIGINT.ordinal()][DECIMAL32.ordinal()] = DECIMAL32; + compatibilityMatrix[BIGINT.ordinal()][DECIMAL64.ordinal()] = DECIMAL64; + compatibilityMatrix[BIGINT.ordinal()][DECIMAL128.ordinal()] = DECIMAL128; compatibilityMatrix[BIGINT.ordinal()][TIME.ordinal()] = TIME; compatibilityMatrix[BIGINT.ordinal()][TIMEV2.ordinal()] = TIMEV2; @@ -580,6 +714,9 @@ public enum PrimitiveType { compatibilityMatrix[LARGEINT.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[LARGEINT.ordinal()][STRING.ordinal()] = INVALID_TYPE; compatibilityMatrix[LARGEINT.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; + compatibilityMatrix[LARGEINT.ordinal()][DECIMAL32.ordinal()] = DECIMAL32; + compatibilityMatrix[LARGEINT.ordinal()][DECIMAL64.ordinal()] = DECIMAL64; + compatibilityMatrix[LARGEINT.ordinal()][DECIMAL128.ordinal()] = DECIMAL128; compatibilityMatrix[LARGEINT.ordinal()][TIME.ordinal()] = TIME; compatibilityMatrix[LARGEINT.ordinal()][TIMEV2.ordinal()] = TIMEV2; @@ -593,6 +730,9 @@ public enum PrimitiveType { compatibilityMatrix[FLOAT.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[FLOAT.ordinal()][STRING.ordinal()] = INVALID_TYPE; compatibilityMatrix[FLOAT.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; + compatibilityMatrix[FLOAT.ordinal()][DECIMAL32.ordinal()] = DECIMAL32; + compatibilityMatrix[FLOAT.ordinal()][DECIMAL64.ordinal()] = DECIMAL64; + compatibilityMatrix[FLOAT.ordinal()][DECIMAL128.ordinal()] = DECIMAL128; compatibilityMatrix[FLOAT.ordinal()][TIME.ordinal()] = TIME; compatibilityMatrix[FLOAT.ordinal()][TIMEV2.ordinal()] = TIMEV2; @@ -605,6 +745,9 @@ public enum PrimitiveType { compatibilityMatrix[DOUBLE.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[DOUBLE.ordinal()][STRING.ordinal()] = INVALID_TYPE; compatibilityMatrix[DOUBLE.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; + compatibilityMatrix[DOUBLE.ordinal()][DECIMAL32.ordinal()] = DECIMAL32; + compatibilityMatrix[DOUBLE.ordinal()][DECIMAL64.ordinal()] = DECIMAL64; + compatibilityMatrix[DOUBLE.ordinal()][DECIMAL128.ordinal()] = DECIMAL128; compatibilityMatrix[DOUBLE.ordinal()][TIME.ordinal()] = TIME; compatibilityMatrix[DOUBLE.ordinal()][TIMEV2.ordinal()] = TIMEV2; @@ -616,6 +759,9 @@ public enum PrimitiveType { compatibilityMatrix[DATE.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATE.ordinal()][STRING.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATE.ordinal()][DECIMALV2.ordinal()] = INVALID_TYPE; + compatibilityMatrix[DATE.ordinal()][DECIMAL32.ordinal()] = INVALID_TYPE; + compatibilityMatrix[DATE.ordinal()][DECIMAL64.ordinal()] = INVALID_TYPE; + compatibilityMatrix[DATE.ordinal()][DECIMAL128.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATE.ordinal()][TIME.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATE.ordinal()][TIMEV2.ordinal()] = INVALID_TYPE; @@ -627,6 +773,9 @@ public enum PrimitiveType { compatibilityMatrix[DATEV2.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATEV2.ordinal()][STRING.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATEV2.ordinal()][DECIMALV2.ordinal()] = INVALID_TYPE; + compatibilityMatrix[DATEV2.ordinal()][DECIMAL32.ordinal()] = INVALID_TYPE; + compatibilityMatrix[DATEV2.ordinal()][DECIMAL64.ordinal()] = INVALID_TYPE; + compatibilityMatrix[DATEV2.ordinal()][DECIMAL128.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATEV2.ordinal()][TIME.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATEV2.ordinal()][TIMEV2.ordinal()] = INVALID_TYPE; @@ -636,6 +785,9 @@ public enum PrimitiveType { compatibilityMatrix[DATETIME.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATETIME.ordinal()][STRING.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATETIME.ordinal()][DECIMALV2.ordinal()] = INVALID_TYPE; + compatibilityMatrix[DATETIME.ordinal()][DECIMAL32.ordinal()] = INVALID_TYPE; + compatibilityMatrix[DATETIME.ordinal()][DECIMAL64.ordinal()] = INVALID_TYPE; + compatibilityMatrix[DATETIME.ordinal()][DECIMAL128.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATETIME.ordinal()][TIME.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATETIME.ordinal()][TIMEV2.ordinal()] = INVALID_TYPE; @@ -645,6 +797,9 @@ public enum PrimitiveType { compatibilityMatrix[DATETIMEV2.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATETIMEV2.ordinal()][STRING.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATETIMEV2.ordinal()][DECIMALV2.ordinal()] = INVALID_TYPE; + compatibilityMatrix[DATETIMEV2.ordinal()][DECIMAL32.ordinal()] = INVALID_TYPE; + compatibilityMatrix[DATETIMEV2.ordinal()][DECIMAL64.ordinal()] = INVALID_TYPE; + compatibilityMatrix[DATETIMEV2.ordinal()][DECIMAL128.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATETIMEV2.ordinal()][TIME.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATETIMEV2.ordinal()][TIMEV2.ordinal()] = INVALID_TYPE; @@ -652,24 +807,57 @@ public enum PrimitiveType { compatibilityMatrix[CHAR.ordinal()][VARCHAR.ordinal()] = VARCHAR; compatibilityMatrix[CHAR.ordinal()][STRING.ordinal()] = STRING; compatibilityMatrix[CHAR.ordinal()][DECIMALV2.ordinal()] = INVALID_TYPE; + compatibilityMatrix[CHAR.ordinal()][DECIMAL32.ordinal()] = INVALID_TYPE; + compatibilityMatrix[CHAR.ordinal()][DECIMAL64.ordinal()] = INVALID_TYPE; + compatibilityMatrix[CHAR.ordinal()][DECIMAL128.ordinal()] = INVALID_TYPE; compatibilityMatrix[CHAR.ordinal()][TIME.ordinal()] = INVALID_TYPE; compatibilityMatrix[CHAR.ordinal()][TIMEV2.ordinal()] = INVALID_TYPE; compatibilityMatrix[VARCHAR.ordinal()][VARCHAR.ordinal()] = VARCHAR; compatibilityMatrix[VARCHAR.ordinal()][STRING.ordinal()] = STRING; compatibilityMatrix[VARCHAR.ordinal()][DECIMALV2.ordinal()] = INVALID_TYPE; + compatibilityMatrix[VARCHAR.ordinal()][DECIMAL32.ordinal()] = INVALID_TYPE; + compatibilityMatrix[VARCHAR.ordinal()][DECIMAL64.ordinal()] = INVALID_TYPE; + compatibilityMatrix[VARCHAR.ordinal()][DECIMAL128.ordinal()] = INVALID_TYPE; compatibilityMatrix[VARCHAR.ordinal()][TIME.ordinal()] = INVALID_TYPE; compatibilityMatrix[VARCHAR.ordinal()][TIMEV2.ordinal()] = INVALID_TYPE; compatibilityMatrix[STRING.ordinal()][STRING.ordinal()] = STRING; compatibilityMatrix[STRING.ordinal()][DECIMALV2.ordinal()] = INVALID_TYPE; + compatibilityMatrix[STRING.ordinal()][DECIMAL32.ordinal()] = INVALID_TYPE; + compatibilityMatrix[STRING.ordinal()][DECIMAL64.ordinal()] = INVALID_TYPE; + compatibilityMatrix[STRING.ordinal()][DECIMAL128.ordinal()] = INVALID_TYPE; compatibilityMatrix[STRING.ordinal()][TIME.ordinal()] = INVALID_TYPE; compatibilityMatrix[STRING.ordinal()][TIMEV2.ordinal()] = INVALID_TYPE; compatibilityMatrix[DECIMALV2.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; + compatibilityMatrix[DECIMALV2.ordinal()][DECIMAL32.ordinal()] = DECIMALV2; + compatibilityMatrix[DECIMALV2.ordinal()][DECIMAL64.ordinal()] = DECIMALV2; + compatibilityMatrix[DECIMALV2.ordinal()][DECIMAL128.ordinal()] = DECIMAL128; compatibilityMatrix[DECIMALV2.ordinal()][TIME.ordinal()] = INVALID_TYPE; compatibilityMatrix[DECIMALV2.ordinal()][TIMEV2.ordinal()] = INVALID_TYPE; + compatibilityMatrix[DECIMAL32.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; + compatibilityMatrix[DECIMAL32.ordinal()][DECIMAL32.ordinal()] = DECIMAL32; + compatibilityMatrix[DECIMAL32.ordinal()][DECIMAL64.ordinal()] = DECIMAL64; + compatibilityMatrix[DECIMAL32.ordinal()][DECIMAL128.ordinal()] = DECIMAL128; + compatibilityMatrix[DECIMAL32.ordinal()][TIME.ordinal()] = INVALID_TYPE; + compatibilityMatrix[DECIMAL32.ordinal()][TIMEV2.ordinal()] = INVALID_TYPE; + + compatibilityMatrix[DECIMAL64.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; + compatibilityMatrix[DECIMAL64.ordinal()][DECIMAL32.ordinal()] = DECIMAL64; + compatibilityMatrix[DECIMAL64.ordinal()][DECIMAL64.ordinal()] = DECIMAL64; + compatibilityMatrix[DECIMAL64.ordinal()][DECIMAL128.ordinal()] = DECIMAL128; + compatibilityMatrix[DECIMAL64.ordinal()][TIME.ordinal()] = INVALID_TYPE; + compatibilityMatrix[DECIMAL64.ordinal()][TIMEV2.ordinal()] = INVALID_TYPE; + + compatibilityMatrix[DECIMAL128.ordinal()][DECIMALV2.ordinal()] = DECIMAL128; + compatibilityMatrix[DECIMAL128.ordinal()][DECIMAL32.ordinal()] = DECIMAL128; + compatibilityMatrix[DECIMAL128.ordinal()][DECIMAL64.ordinal()] = DECIMAL128; + compatibilityMatrix[DECIMAL128.ordinal()][DECIMAL128.ordinal()] = DECIMAL128; + compatibilityMatrix[DECIMAL128.ordinal()][TIME.ordinal()] = INVALID_TYPE; + compatibilityMatrix[DECIMAL128.ordinal()][TIMEV2.ordinal()] = INVALID_TYPE; + compatibilityMatrix[HLL.ordinal()][HLL.ordinal()] = HLL; compatibilityMatrix[HLL.ordinal()][TIME.ordinal()] = INVALID_TYPE; compatibilityMatrix[HLL.ordinal()][TIMEV2.ordinal()] = INVALID_TYPE; @@ -748,6 +936,12 @@ public enum PrimitiveType { return BINARY; case DECIMALV2: return DECIMALV2; + case DECIMAL32: + return DECIMAL32; + case DECIMAL64: + return DECIMAL64; + case DECIMAL128: + return DECIMAL128; case TIME: return TIME; case TIMEV2: @@ -844,8 +1038,12 @@ public enum PrimitiveType { return this == DECIMALV2; } + public boolean isDecimalV3Type() { + return this == DECIMAL32 || this == DECIMAL64 || this == DECIMAL128; + } + public boolean isNumericType() { - return isFixedPointType() || isFloatingPointType() || isDecimalV2Type(); + return isFixedPointType() || isFloatingPointType() || isDecimalV2Type() || isDecimalV3Type(); } public boolean isValid() { @@ -913,6 +1111,9 @@ public enum PrimitiveType { } } case DECIMALV2: + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: return MysqlColType.MYSQL_TYPE_NEWDECIMAL; case STRING: return MysqlColType.MYSQL_TYPE_BLOB; @@ -939,6 +1140,12 @@ public enum PrimitiveType { return STRING_INDEX_LEN; case DECIMALV2: return DECIMAL_INDEX_LEN; + case DECIMAL32: + return 4; + case DECIMAL64: + return 8; + case DECIMAL128: + return 16; default: return this.getSlotSize(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarFunction.java index cb4055e949..e607ea5ad2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarFunction.java @@ -171,6 +171,9 @@ public class ScalarFunction extends Function { beFn += "_datetime_val"; break; case DECIMALV2: + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: beFn += "_decimalv2_val"; usesDecimalV2 = true; break; @@ -246,6 +249,9 @@ public class ScalarFunction extends Function { beFn.append("_datetime_val"); break; case DECIMALV2: + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: beFn.append("_decimalv2_val"); usesDecimalV2 = true; break; @@ -378,8 +384,8 @@ public class ScalarFunction extends Function { } @Override - public TFunction toThrift() { - TFunction fn = super.toThrift(); + public TFunction toThrift(Type realReturnType, Type[] realArgTypes) { + TFunction fn = super.toThrift(realReturnType, realArgTypes); fn.setScalarFn(new TScalarFunction()); fn.getScalarFn().setSymbol(symbolName); if (prepareFnSymbol != null) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarType.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarType.java index 2cced324a5..26ba564769 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarType.java @@ -75,6 +75,9 @@ public class ScalarType extends Type { // Hive, mysql, sql server standard. public static final int MAX_PRECISION = 38; + public static final int MAX_DECIMAL32_PRECISION = 9; + public static final int MAX_DECIMAL64_PRECISION = 18; + public static final int MAX_DECIMAL128_PRECISION = 38; private static final Logger LOG = LogManager.getLogger(ScalarType.class); @SerializedName(value = "type") @@ -118,8 +121,11 @@ public class ScalarType extends Type { return createVarcharType(len); case STRING: return createStringType(); + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: case DECIMALV2: - return createDecimalV2Type(precision, scale); + return createDecimalType(precision, scale); default: return createType(type); } @@ -169,6 +175,12 @@ public class ScalarType extends Type { return TIMEV2; case TIME: return TIME; + case DECIMAL32: + return DEFAULT_DECIMAL32; + case DECIMAL64: + return DEFAULT_DECIMAL64; + case DECIMAL128: + return DEFAULT_DECIMAL128; case DECIMALV2: return DEFAULT_DECIMALV2; case LARGEINT: @@ -227,7 +239,10 @@ public class ScalarType extends Type { return TIME; case "DECIMAL": case "DECIMALV2": - return (ScalarType) createDecimalV2Type(); + case "DECIMAL32": + case "DECIMAL64": + case "DECIMAL128": + return (ScalarType) createDecimalType(); case "LARGEINT": return LARGEINT; default: @@ -255,39 +270,63 @@ public class ScalarType extends Type { return type; } - public static ScalarType createDecimalV2Type() { - return DEFAULT_DECIMALV2; + public static ScalarType createDecimalType() { + if (Config.enable_decimalv3) { + return DEFAULT_DECIMALV3; + } else { + return DEFAULT_DECIMALV2; + } } - public static ScalarType createDecimalV2Type(int precision) { - return createDecimalV2Type(precision, DEFAULT_SCALE); + public static ScalarType createDecimalType(int precision) { + return createDecimalType(precision, DEFAULT_SCALE); } - public static ScalarType createDecimalV2Type(int precision, int scale) { - // Preconditions.checkState(precision >= 0); // Enforced by parser - // Preconditions.checkState(scale >= 0); // Enforced by parser. - ScalarType type = new ScalarType(PrimitiveType.DECIMALV2); + public static ScalarType createDecimalType(int precision, int scale) { + ScalarType type = new ScalarType(getSuitableDecimalType(precision)); type.precision = precision; type.scale = scale; return type; } - public static ScalarType createDecimalV2Type(String precisionStr) { - ScalarType type = new ScalarType(PrimitiveType.DECIMALV2); + public static ScalarType createDecimalType(PrimitiveType primitiveType, int precision, int scale) { + ScalarType type = new ScalarType(primitiveType); + type.precision = precision; + type.scale = scale; + return type; + } + + public static ScalarType createDecimalType(String precisionStr) { + int precision = Integer.parseInt(precisionStr); + ScalarType type = new ScalarType(getSuitableDecimalType(precision)); type.precisionStr = precisionStr; type.scaleStr = null; return type; } - public static ScalarType createDecimalV2Type(String precisionStr, String scaleStr) { + public static ScalarType createDecimalType(String precisionStr, String scaleStr) { ScalarType type = new ScalarType(PrimitiveType.DECIMALV2); type.precisionStr = precisionStr; type.scaleStr = scaleStr; return type; } - public static ScalarType createDecimalV2TypeInternal(int precision, int scale) { - ScalarType type = new ScalarType(PrimitiveType.DECIMALV2); + public static PrimitiveType getSuitableDecimalType(int precision) { + if (Config.enable_decimalv3) { + if (precision <= MAX_DECIMAL32_PRECISION) { + return PrimitiveType.DECIMAL32; + } else if (precision <= MAX_DECIMAL64_PRECISION) { + return PrimitiveType.DECIMAL64; + } else { + return PrimitiveType.DECIMAL128; + } + } else { + return PrimitiveType.DECIMALV2; + } + } + + public static ScalarType createDecimalTypeInternal(int precision, int scale) { + ScalarType type = new ScalarType(getSuitableDecimalType(precision)); type.precision = Math.min(precision, MAX_PRECISION); type.scale = Math.min(type.precision, scale); return type; @@ -340,6 +379,22 @@ public class ScalarType extends Type { return type; } + /** + * create a wider decimal type. + */ + public static ScalarType createWiderDecimalV3Type(int precision, int scale) { + ScalarType type = new ScalarType(PrimitiveType.DECIMALV2); + if (precision <= MAX_DECIMAL32_PRECISION) { + type.precision = MAX_DECIMAL32_PRECISION; + } else if (precision <= MAX_DECIMAL64_PRECISION) { + type.precision = MAX_DECIMAL64_PRECISION; + } else { + type.precision = MAX_DECIMAL128_PRECISION; + } + type.scale = scale; + return type; + } + public static ScalarType createVarcharType(int len) { // length checked in analysis ScalarType type = new ScalarType(PrimitiveType.VARCHAR); @@ -385,7 +440,7 @@ public class ScalarType extends Type { return "CHAR(*)"; } return "CHAR(" + len + ")"; - } else if (type == PrimitiveType.DECIMALV2) { + } else if (type == PrimitiveType.DECIMALV2 || type.isDecimalV3Type()) { if (isWildcardDecimal()) { return "DECIMAL(*,*)"; } @@ -424,6 +479,9 @@ public class ScalarType extends Type { } break; case DECIMALV2: + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: if (Strings.isNullOrEmpty(precisionStr)) { stringBuilder.append("decimal").append("(").append(precision) .append(", ").append(scale).append(")"); @@ -496,7 +554,10 @@ public class ScalarType extends Type { scalarType.setLen(len); break; } - case DECIMALV2: { + case DECIMALV2: + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: { scalarType.setScale(scale); scalarType.setPrecision(precision); break; @@ -509,13 +570,15 @@ public class ScalarType extends Type { public int decimalPrecision() { Preconditions.checkState(type == PrimitiveType.DECIMALV2 || type == PrimitiveType.DATETIMEV2 - || type == PrimitiveType.TIMEV2); + || type == PrimitiveType.TIMEV2 || type == PrimitiveType.DECIMAL32 + || type == PrimitiveType.DECIMAL64 || type == PrimitiveType.DECIMAL128); return precision; } public int decimalScale() { Preconditions.checkState(type == PrimitiveType.DECIMALV2 || type == PrimitiveType.DATETIMEV2 - || type == PrimitiveType.TIMEV2); + || type == PrimitiveType.TIMEV2 || type == PrimitiveType.DECIMAL32 + || type == PrimitiveType.DECIMAL64 || type == PrimitiveType.DECIMAL128); return scale; } @@ -568,7 +631,7 @@ public class ScalarType extends Type { @Override public boolean isWildcardDecimal() { - return (type == PrimitiveType.DECIMALV2) + return (type.isDecimalV2Type() || type.isDecimalV3Type()) && precision == -1 && scale == -1; } @@ -588,7 +651,7 @@ public class ScalarType extends Type { || type == PrimitiveType.SMALLINT || type == PrimitiveType.INT || type == PrimitiveType.BIGINT || type == PrimitiveType.FLOAT || type == PrimitiveType.DOUBLE || type == PrimitiveType.DATE - || type == PrimitiveType.DATETIME || type == PrimitiveType.DECIMALV2 + || type == PrimitiveType.DATETIME || type == PrimitiveType.DECIMALV2 || type.isDecimalV3Type() || type == PrimitiveType.CHAR || type == PrimitiveType.DATEV2 || type == PrimitiveType.DATETIMEV2 || type == PrimitiveType.TIMEV2; } @@ -648,9 +711,16 @@ public class ScalarType extends Type { Preconditions.checkState(!isWildcardDecimal()); return true; } + if (isDecimalV3() && scalarType.isWildcardDecimal()) { + Preconditions.checkState(!isWildcardDecimal()); + return true; + } if (isDecimalV2() && scalarType.isDecimalV2()) { return true; } + if (isDecimalV3() && scalarType.isDecimalV3()) { + return true; + } return false; } @@ -681,7 +751,8 @@ public class ScalarType extends Type { if (type == PrimitiveType.VARCHAR) { return len == other.len; } - if (type == PrimitiveType.DECIMALV2 || type == PrimitiveType.DATETIMEV2 || type == PrimitiveType.TIMEV2) { + if (type.isDecimalV2Type() || type.isDecimalV3Type() + || type == PrimitiveType.DATETIMEV2 || type == PrimitiveType.TIMEV2) { return precision == other.precision && scale == other.scale; } return true; @@ -696,7 +767,13 @@ public class ScalarType extends Type { } else if (isNull()) { return ScalarType.NULL; } else if (isDecimalV2()) { - return createDecimalV2TypeInternal(MAX_PRECISION, scale); + return createDecimalTypeInternal(MAX_PRECISION, scale); + } else if (getPrimitiveType() == PrimitiveType.DECIMAL32) { + return createDecimalTypeInternal(MAX_DECIMAL32_PRECISION, scale); + } else if (getPrimitiveType() == PrimitiveType.DECIMAL64) { + return createDecimalTypeInternal(MAX_DECIMAL64_PRECISION, scale); + } else if (getPrimitiveType() == PrimitiveType.DECIMAL128) { + return createDecimalTypeInternal(MAX_DECIMAL128_PRECISION, scale); } else if (isLargeIntType()) { return ScalarType.LARGEINT; } else if (isDatetimeV2()) { @@ -713,7 +790,13 @@ public class ScalarType extends Type { if (type == PrimitiveType.DOUBLE || type == PrimitiveType.BIGINT || isNull()) { return this; } else if (type == PrimitiveType.DECIMALV2) { - return createDecimalV2TypeInternal(MAX_PRECISION, scale); + return createDecimalTypeInternal(MAX_PRECISION, scale); + } else if (type == PrimitiveType.DECIMAL32) { + return createDecimalTypeInternal(MAX_DECIMAL64_PRECISION, scale); + } else if (type == PrimitiveType.DECIMAL64) { + return createDecimalTypeInternal(MAX_DECIMAL128_PRECISION, scale); + } else if (type == PrimitiveType.DECIMAL128) { + return createDecimalTypeInternal(MAX_DECIMAL128_PRECISION, scale); } else if (type == PrimitiveType.DATETIMEV2) { return createDatetimeV2Type(6); } else if (type == PrimitiveType.TIMEV2) { @@ -733,17 +816,17 @@ public class ScalarType extends Type { case DECIMALV2: return this; case TINYINT: - return createDecimalV2Type(3); + return createDecimalType(3); case SMALLINT: - return createDecimalV2Type(5); + return createDecimalType(5); case INT: - return createDecimalV2Type(10); + return createDecimalType(10); case BIGINT: - return createDecimalV2Type(19); + return createDecimalType(19); case FLOAT: - return createDecimalV2TypeInternal(MAX_PRECISION, 9); + return createDecimalTypeInternal(MAX_PRECISION, 9); case DOUBLE: - return createDecimalV2TypeInternal(MAX_PRECISION, 17); + return createDecimalTypeInternal(MAX_PRECISION, 17); default: return ScalarType.INVALID; } @@ -756,8 +839,8 @@ public class ScalarType extends Type { * the decimal point must be greater or equal. */ public boolean isSupertypeOf(ScalarType o) { - Preconditions.checkState(isDecimalV2()); - Preconditions.checkState(o.isDecimalV2()); + Preconditions.checkState(isDecimalV2() || isDecimalV3()); + Preconditions.checkState(o.isDecimalV2() || o.isDecimalV3()); if (isWildcardDecimal()) { return true; } @@ -823,6 +906,21 @@ public class ScalarType extends Type { return INVALID; } + if ((t1.isDecimalV2() && t2.isDateV2()) + || (t1.isDateV2() && t2.isDecimalV2())) { + return INVALID; + } + + if ((t1.isDecimalV3() && t2.isDate()) + || (t1.isDate() && t2.isDecimalV3())) { + return INVALID; + } + + if ((t1.isDecimalV3() && t2.isDateV2()) + || (t1.isDateV2() && t2.isDecimalV3())) { + return INVALID; + } + if (t1.isDecimalV2() || t2.isDecimalV2()) { return DECIMALV2; } @@ -865,6 +963,7 @@ public class ScalarType extends Type { return 2; case INT: case FLOAT: + case DECIMAL32: return 4; case BIGINT: case TIME: @@ -872,9 +971,11 @@ public class ScalarType extends Type { // TODO(Gabriel): unify execution engine and storage engine case TIMEV2: case DATETIMEV2: + case DECIMAL64: return 8; case LARGEINT: case DECIMALV2: + case DECIMAL128: return 16; case DOUBLE: return 12; @@ -904,7 +1005,8 @@ public class ScalarType extends Type { if (type == PrimitiveType.CHAR || type == PrimitiveType.VARCHAR || type == PrimitiveType.HLL) { thrift.setLen(len); } - if (type == PrimitiveType.DECIMALV2 || type == PrimitiveType.DATETIMEV2 || type == PrimitiveType.TIMEV2) { + if (type == PrimitiveType.DECIMALV2 || type.isDecimalV3Type() + || type == PrimitiveType.DATETIMEV2 || type == PrimitiveType.TIMEV2) { thrift.setPrecision(precision); thrift.setScale(scale); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Type.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/Type.java index 98cbbb83d5..4e97454fe6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Type.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Type.java @@ -69,12 +69,28 @@ public abstract class Type { public static final ScalarType TIMEV2 = new ScalarType(PrimitiveType.TIMEV2); public static final ScalarType TIME = new ScalarType(PrimitiveType.TIME); public static final ScalarType STRING = new ScalarType(PrimitiveType.STRING); - public static final ScalarType DEFAULT_DECIMALV2 = - ScalarType.createDecimalV2Type(ScalarType.DEFAULT_PRECISION, ScalarType.DEFAULT_SCALE); + public static final ScalarType DEFAULT_DECIMALV2 = ScalarType.createDecimalType(PrimitiveType.DECIMALV2, + ScalarType.DEFAULT_PRECISION, ScalarType.DEFAULT_SCALE); + + public static final ScalarType DEFAULT_DECIMAL32 = + ScalarType.createDecimalType(PrimitiveType.DECIMAL32, ScalarType.MAX_DECIMAL32_PRECISION, + ScalarType.DEFAULT_SCALE); + + public static final ScalarType DEFAULT_DECIMAL64 = + ScalarType.createDecimalType(PrimitiveType.DECIMAL64, ScalarType.MAX_DECIMAL64_PRECISION, + ScalarType.DEFAULT_SCALE); + + public static final ScalarType DEFAULT_DECIMAL128 = + ScalarType.createDecimalType(PrimitiveType.DECIMAL128, ScalarType.MAX_DECIMAL128_PRECISION, + ScalarType.DEFAULT_SCALE); + public static final ScalarType DEFAULT_DECIMALV3 = DEFAULT_DECIMAL32; public static final ScalarType DEFAULT_DATETIMEV2 = ScalarType.createDatetimeV2Type(0); public static final ScalarType DATETIMEV2 = DEFAULT_DATETIMEV2; public static final ScalarType DEFAULT_TIMEV2 = ScalarType.createTimeV2Type(0); public static final ScalarType DECIMALV2 = DEFAULT_DECIMALV2; + public static final ScalarType DECIMAL32 = DEFAULT_DECIMAL32; + public static final ScalarType DECIMAL64 = DEFAULT_DECIMAL64; + public static final ScalarType DECIMAL128 = DEFAULT_DECIMAL128; // (ScalarType) ScalarType.createDecimalTypeInternal(-1, -1); public static final ScalarType DEFAULT_VARCHAR = ScalarType.createVarcharType(-1); public static final ScalarType VARCHAR = ScalarType.createVarcharType(-1); @@ -110,6 +126,9 @@ public abstract class Type { numericTypes.add(FLOAT); numericTypes.add(DOUBLE); numericTypes.add(DECIMALV2); + numericTypes.add(DECIMAL32); + numericTypes.add(DECIMAL64); + numericTypes.add(DECIMAL128); supportedTypes = Lists.newArrayList(); supportedTypes.add(NULL); @@ -131,6 +150,9 @@ public abstract class Type { supportedTypes.add(DATEV2); supportedTypes.add(DATETIMEV2); supportedTypes.add(DECIMALV2); + supportedTypes.add(DECIMAL32); + supportedTypes.add(DECIMAL64); + supportedTypes.add(DECIMAL128); supportedTypes.add(TIME); supportedTypes.add(TIMEV2); supportedTypes.add(STRING); @@ -195,6 +217,11 @@ public abstract class Type { return isScalarType(PrimitiveType.DECIMALV2); } + public boolean isDecimalV3() { + return isScalarType(PrimitiveType.DECIMAL32) || isScalarType(PrimitiveType.DECIMAL64) + || isScalarType(PrimitiveType.DECIMAL128); + } + public boolean isDatetimeV2() { return isScalarType(PrimitiveType.DATETIMEV2); } @@ -284,6 +311,10 @@ public abstract class Type { || isScalarType(PrimitiveType.INT); } + public boolean isBigIntType() { + return isScalarType(PrimitiveType.BIGINT); + } + public boolean isLargeIntType() { return isScalarType(PrimitiveType.LARGEINT); } @@ -294,7 +325,7 @@ public abstract class Type { } public boolean isNumericType() { - return isFixedPointType() || isFloatingPointType() || isDecimalV2(); + return isFixedPointType() || isFloatingPointType() || isDecimalV2() || isDecimalV3(); } public boolean isNativeType() { @@ -478,6 +509,12 @@ public abstract class Type { return DOUBLE; case DECIMALV2: return DECIMALV2; + case DECIMAL32: + return DECIMAL32; + case DECIMAL64: + return DECIMAL64; + case DECIMAL128: + return DECIMAL128; default: return INVALID; } @@ -574,6 +611,12 @@ public abstract class Type { return Type.TIMEV2; case DECIMALV2: return Type.DECIMALV2; + case DECIMAL32: + return Type.DECIMAL32; + case DECIMAL64: + return Type.DECIMAL64; + case DECIMAL128: + return Type.DECIMAL128; case CHAR: return Type.CHAR; case VARCHAR: @@ -609,6 +652,18 @@ public abstract class Type { return result; } + public static List toThrift(ArrayList types, ArrayList realTypes) { + ArrayList result = Lists.newArrayList(); + for (int i = 0; i < types.size(); i++) { + if (PrimitiveType.typeWithPrecision.contains(realTypes.get(i).getPrimitiveType())) { + result.add(realTypes.get(i).toThrift()); + } else { + result.add(types.get(i).toThrift()); + } + } + return result; + } + public static Type fromThrift(TTypeDesc thrift) { Preconditions.checkState(thrift.types.size() > 0); Pair t = fromThrift(thrift, 0); @@ -637,10 +692,13 @@ public abstract class Type { type = ScalarType.createVarcharType(scalarType.getLen()); } else if (scalarType.getType() == TPrimitiveType.HLL) { type = ScalarType.createHllType(); - } else if (scalarType.getType() == TPrimitiveType.DECIMALV2) { + } else if (scalarType.getType() == TPrimitiveType.DECIMALV2 + || scalarType.getType() == TPrimitiveType.DECIMAL32 + || scalarType.getType() == TPrimitiveType.DECIMAL64 + || scalarType.getType() == TPrimitiveType.DECIMAL128) { Preconditions.checkState(scalarType.isSetPrecision() && scalarType.isSetPrecision()); - type = ScalarType.createDecimalV2Type(scalarType.getPrecision(), + type = ScalarType.createDecimalType(scalarType.getPrecision(), scalarType.getScale()); } else if (scalarType.getType() == TPrimitiveType.DATETIMEV2) { Preconditions.checkState(scalarType.isSetPrecision() @@ -746,7 +804,7 @@ public abstract class Type { } if (isNumericType()) { int size = getPrecision() + 1; // +1 for minus symbol - if (isScalarType(PrimitiveType.DECIMALV2)) { + if (isScalarType(PrimitiveType.DECIMALV2) || isDecimalV3()) { size += 1; // +1 for decimal point } return size; @@ -789,6 +847,9 @@ public abstract class Type { case DOUBLE: return 15; case DECIMALV2: + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: case DATETIMEV2: case TIMEV2: return t.decimalPrecision(); @@ -822,6 +883,9 @@ public abstract class Type { case DATETIMEV2: case TIMEV2: case DECIMALV2: + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: return t.decimalScale(); default: return null; @@ -853,6 +917,9 @@ public abstract class Type { case FLOAT: case DOUBLE: case DECIMALV2: + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: return 10; default: // everything else (including boolean and string) is null @@ -915,6 +982,9 @@ public abstract class Type { compatibilityMatrix[BOOLEAN.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[BOOLEAN.ordinal()][STRING.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[BOOLEAN.ordinal()][QUANTILE_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[BOOLEAN.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[BOOLEAN.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[BOOLEAN.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; // TINYINT @@ -933,6 +1003,9 @@ public abstract class Type { compatibilityMatrix[TINYINT.ordinal()][CHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[TINYINT.ordinal()][VARCHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[TINYINT.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[TINYINT.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL32; + compatibilityMatrix[TINYINT.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL64; + compatibilityMatrix[TINYINT.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; compatibilityMatrix[TINYINT.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[TINYINT.ordinal()][TIME.ordinal()] = PrimitiveType.DOUBLE; compatibilityMatrix[TINYINT.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -954,6 +1027,9 @@ public abstract class Type { compatibilityMatrix[SMALLINT.ordinal()][CHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[SMALLINT.ordinal()][VARCHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[SMALLINT.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[SMALLINT.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL32; + compatibilityMatrix[SMALLINT.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL64; + compatibilityMatrix[SMALLINT.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; compatibilityMatrix[SMALLINT.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[SMALLINT.ordinal()][TIME.ordinal()] = PrimitiveType.DOUBLE; compatibilityMatrix[SMALLINT.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -978,6 +1054,9 @@ public abstract class Type { compatibilityMatrix[INT.ordinal()][CHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[INT.ordinal()][VARCHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[INT.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[INT.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL32; + compatibilityMatrix[INT.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL64; + compatibilityMatrix[INT.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; compatibilityMatrix[INT.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[INT.ordinal()][TIME.ordinal()] = PrimitiveType.DOUBLE; compatibilityMatrix[INT.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1004,6 +1083,9 @@ public abstract class Type { compatibilityMatrix[BIGINT.ordinal()][CHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[BIGINT.ordinal()][VARCHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[BIGINT.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[BIGINT.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL64; + compatibilityMatrix[BIGINT.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL64; + compatibilityMatrix[BIGINT.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; compatibilityMatrix[BIGINT.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[BIGINT.ordinal()][TIME.ordinal()] = PrimitiveType.DOUBLE; compatibilityMatrix[BIGINT.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1021,6 +1103,9 @@ public abstract class Type { compatibilityMatrix[LARGEINT.ordinal()][CHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[LARGEINT.ordinal()][VARCHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[LARGEINT.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.DECIMALV2; + compatibilityMatrix[LARGEINT.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[LARGEINT.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[LARGEINT.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; compatibilityMatrix[LARGEINT.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[LARGEINT.ordinal()][TIME.ordinal()] = PrimitiveType.DOUBLE; compatibilityMatrix[LARGEINT.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1037,6 +1122,9 @@ public abstract class Type { compatibilityMatrix[FLOAT.ordinal()][CHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[FLOAT.ordinal()][VARCHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[FLOAT.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[FLOAT.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[FLOAT.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[FLOAT.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; compatibilityMatrix[FLOAT.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[FLOAT.ordinal()][TIME.ordinal()] = PrimitiveType.DOUBLE; compatibilityMatrix[FLOAT.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1049,6 +1137,9 @@ public abstract class Type { compatibilityMatrix[DOUBLE.ordinal()][CHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DOUBLE.ordinal()][VARCHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DOUBLE.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DOUBLE.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[DOUBLE.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[DOUBLE.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; compatibilityMatrix[DOUBLE.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DOUBLE.ordinal()][TIME.ordinal()] = PrimitiveType.DOUBLE; compatibilityMatrix[DOUBLE.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1065,6 +1156,9 @@ public abstract class Type { compatibilityMatrix[DATE.ordinal()][CHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATE.ordinal()][VARCHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATE.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.DECIMALV2; + compatibilityMatrix[DATE.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL32; + compatibilityMatrix[DATE.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL64; + compatibilityMatrix[DATE.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; compatibilityMatrix[DATE.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATE.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATE.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1079,6 +1173,9 @@ public abstract class Type { compatibilityMatrix[DATEV2.ordinal()][CHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATEV2.ordinal()][VARCHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATEV2.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.DECIMALV2; + compatibilityMatrix[DATEV2.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL32; + compatibilityMatrix[DATEV2.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL64; + compatibilityMatrix[DATEV2.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; compatibilityMatrix[DATEV2.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATEV2.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATEV2.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1092,6 +1189,9 @@ public abstract class Type { compatibilityMatrix[DATETIME.ordinal()][DATEV2.ordinal()] = PrimitiveType.DATETIMEV2; compatibilityMatrix[DATETIME.ordinal()][VARCHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATETIME.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.DECIMALV2; + compatibilityMatrix[DATETIME.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL64; + compatibilityMatrix[DATETIME.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL64; + compatibilityMatrix[DATETIME.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; compatibilityMatrix[DATETIME.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATETIME.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATETIME.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1105,6 +1205,9 @@ public abstract class Type { compatibilityMatrix[DATETIMEV2.ordinal()][DATEV2.ordinal()] = PrimitiveType.DATETIMEV2; compatibilityMatrix[DATETIMEV2.ordinal()][VARCHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATETIMEV2.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.DECIMALV2; + compatibilityMatrix[DATETIMEV2.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL64; + compatibilityMatrix[DATETIMEV2.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL64; + compatibilityMatrix[DATETIMEV2.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; compatibilityMatrix[DATETIMEV2.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATETIMEV2.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATETIMEV2.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1118,6 +1221,9 @@ public abstract class Type { compatibilityMatrix[CHAR.ordinal()][DATEV2.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[CHAR.ordinal()][DATETIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[CHAR.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[CHAR.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[CHAR.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[CHAR.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[CHAR.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[CHAR.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[CHAR.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1127,6 +1233,9 @@ public abstract class Type { // VARCHAR compatibilityMatrix[VARCHAR.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[VARCHAR.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[VARCHAR.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[VARCHAR.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[VARCHAR.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[VARCHAR.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[VARCHAR.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1144,6 +1253,9 @@ public abstract class Type { compatibilityMatrix[STRING.ordinal()][DATETIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[STRING.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[STRING.ordinal()][QUANTILE_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[STRING.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[STRING.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[STRING.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; // DECIMALV2 @@ -1155,6 +1267,48 @@ public abstract class Type { compatibilityMatrix[DECIMALV2.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DECIMALV2.ordinal()][STRING.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DECIMALV2.ordinal()][QUANTILE_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMALV2.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMALV2.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMALV2.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; + + // DECIMAL32 + compatibilityMatrix[DECIMAL32.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL32.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL32.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL32.ordinal()][DATEV2.ordinal()] = PrimitiveType.DECIMAL32; + compatibilityMatrix[DECIMAL32.ordinal()][DATETIMEV2.ordinal()] = PrimitiveType.DECIMAL32; + compatibilityMatrix[DECIMAL32.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL32.ordinal()][STRING.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL32.ordinal()][QUANTILE_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL32.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL32.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL64; + compatibilityMatrix[DECIMAL32.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; + + // DECIMAL64 + compatibilityMatrix[DECIMAL64.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL64.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL64.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL64.ordinal()][DATEV2.ordinal()] = PrimitiveType.DECIMAL64; + compatibilityMatrix[DECIMAL64.ordinal()][DATETIMEV2.ordinal()] = PrimitiveType.DECIMAL64; + compatibilityMatrix[DECIMAL64.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL64.ordinal()][STRING.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL64.ordinal()][QUANTILE_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL64.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL64.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL64; + compatibilityMatrix[DECIMAL64.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; + + // DECIMAL128 + compatibilityMatrix[DECIMAL128.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL128.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL128.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL128.ordinal()][DATEV2.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[DECIMAL128.ordinal()][DATETIMEV2.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[DECIMAL128.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL128.ordinal()][STRING.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL128.ordinal()][QUANTILE_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL128.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL128.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[DECIMAL128.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL128; // HLL @@ -1165,6 +1319,9 @@ public abstract class Type { compatibilityMatrix[HLL.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[HLL.ordinal()][STRING.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[HLL.ordinal()][QUANTILE_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[HLL.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[HLL.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[HLL.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; // BITMAP @@ -1174,18 +1331,31 @@ public abstract class Type { compatibilityMatrix[BITMAP.ordinal()][QUANTILE_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[BITMAP.ordinal()][DATEV2.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[BITMAP.ordinal()][DATETIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[BITMAP.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[BITMAP.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[BITMAP.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; //QUANTILE_STATE compatibilityMatrix[QUANTILE_STATE.ordinal()][STRING.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[QUANTILE_STATE.ordinal()][DATEV2.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[QUANTILE_STATE.ordinal()][DATETIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[QUANTILE_STATE.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[QUANTILE_STATE.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[QUANTILE_STATE.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; // TIME why here not??? compatibilityMatrix[TIME.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[TIME.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[TIME.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[TIME.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[TIME.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[TIMEV2.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[TIMEV2.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[TIMEV2.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[TIMEV2.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[TIMEV2.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; // Check all of the necessary entries that should be filled. // ignore binary and all @@ -1238,6 +1408,12 @@ public abstract class Type { return DEFAULT_TIMEV2; case DECIMALV2: return DECIMALV2; + case DECIMAL32: + return DECIMAL32; + case DECIMAL64: + return DECIMAL64; + case DECIMAL128: + return DECIMAL128; case STRING: return STRING; default: @@ -1279,6 +1455,22 @@ public abstract class Type { if (t1ResultType == PrimitiveType.BIGINT && t2ResultType == PrimitiveType.BIGINT) { return getAssignmentCompatibleType(t1, t2, false); } + if (t1ResultType.isDecimalV3Type() && t2ResultType.isDecimalV3Type()) { + int resultPrecision = Math.max(t1.getPrecision(), t2.getPrecision()); + PrimitiveType resultDecimalType; + if (resultPrecision <= ScalarType.MAX_DECIMAL32_PRECISION) { + resultDecimalType = PrimitiveType.DECIMAL32; + } else if (resultPrecision <= ScalarType.MAX_DECIMAL64_PRECISION) { + resultDecimalType = PrimitiveType.DECIMAL64; + } else { + resultDecimalType = PrimitiveType.DECIMAL128; + } + return ScalarType.createDecimalType(resultDecimalType, resultPrecision, + Math.max(((ScalarType) t1).getScalarScale(), ((ScalarType) t2).getScalarScale())); + } + if (t1ResultType.isDecimalV3Type() || t2ResultType.isDecimalV3Type()) { + return getAssignmentCompatibleType(t1, t2, false); + } if ((t1ResultType == PrimitiveType.BIGINT || t1ResultType == PrimitiveType.DECIMALV2) && (t2ResultType == PrimitiveType.BIGINT @@ -1366,6 +1558,12 @@ public abstract class Type { return Type.DEFAULT_TIMEV2; case DECIMALV2: return Type.DECIMALV2; + case DECIMAL32: + return Type.DECIMAL32; + case DECIMAL64: + return Type.DECIMAL64; + case DECIMAL128: + return Type.DECIMAL128; default: return Type.INVALID; diff --git a/fe/fe-core/src/main/java/org/apache/doris/common/Config.java b/fe/fe-core/src/main/java/org/apache/doris/common/Config.java index 4f9231516d..198a240a80 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/common/Config.java +++ b/fe/fe-core/src/main/java/org/apache/doris/common/Config.java @@ -1670,6 +1670,13 @@ public class Config extends ConfigBase { @ConfField(mutable = true, masterOnly = false) public static long file_scan_node_split_num = 128; + /* + * If set to TRUE, the precision of decimal will be broaden to [1, 38]. + * Decimalv3 of storage layer needs to be enabled first. + */ + @ConfField + public static boolean enable_decimalv3 = false; + /** * If set to TRUE, FE will: * 1. divide BE into high load and low load(no mid load) to force triggering tablet scheduling; diff --git a/fe/fe-core/src/main/java/org/apache/doris/common/util/Util.java b/fe/fe-core/src/main/java/org/apache/doris/common/util/Util.java index 2d433e3642..d8ea432dd9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/common/util/Util.java +++ b/fe/fe-core/src/main/java/org/apache/doris/common/util/Util.java @@ -76,6 +76,9 @@ public class Util { TYPE_STRING_MAP.put(PrimitiveType.VARCHAR, "varchar(%d)"); TYPE_STRING_MAP.put(PrimitiveType.STRING, "string"); TYPE_STRING_MAP.put(PrimitiveType.DECIMALV2, "decimal(%d,%d)"); + TYPE_STRING_MAP.put(PrimitiveType.DECIMAL32, "decimal(%d,%d)"); + TYPE_STRING_MAP.put(PrimitiveType.DECIMAL64, "decimal(%d,%d)"); + TYPE_STRING_MAP.put(PrimitiveType.DECIMAL128, "decimal(%d,%d)"); TYPE_STRING_MAP.put(PrimitiveType.HLL, "varchar(%d)"); TYPE_STRING_MAP.put(PrimitiveType.BOOLEAN, "bool"); TYPE_STRING_MAP.put(PrimitiveType.BITMAP, "bitmap"); diff --git a/fe/fe-core/src/main/java/org/apache/doris/datasource/InternalDataSource.java b/fe/fe-core/src/main/java/org/apache/doris/datasource/InternalDataSource.java index 634bca7079..0dc170154a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/datasource/InternalDataSource.java +++ b/fe/fe-core/src/main/java/org/apache/doris/datasource/InternalDataSource.java @@ -1132,7 +1132,10 @@ public class InternalDataSource implements DataSourceIf { if (resultType.isStringType() && resultType.getLength() < 0) { typeDef = new TypeDef(Type.STRING); } else if (resultType.isDecimalV2() && resultType.equals(ScalarType.DECIMALV2)) { - typeDef = new TypeDef(ScalarType.createDecimalV2Type(27, 9)); + typeDef = new TypeDef(ScalarType.createDecimalType(27, 9)); + } else if (resultType.isDecimalV3()) { + typeDef = new TypeDef(ScalarType.createDecimalType(resultType.getPrecision(), + ((ScalarType) resultType).getScalarScale())); } else { typeDef = new TypeDef(resultExpr.getType()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/external/hive/util/HiveUtil.java b/fe/fe-core/src/main/java/org/apache/doris/external/hive/util/HiveUtil.java index 57c83e6392..3ac1806dbf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/external/hive/util/HiveUtil.java +++ b/fe/fe-core/src/main/java/org/apache/doris/external/hive/util/HiveUtil.java @@ -22,6 +22,7 @@ import org.apache.doris.catalog.ArrayType; import org.apache.doris.catalog.Column; import org.apache.doris.catalog.Type; import org.apache.doris.common.AnalysisException; +import org.apache.doris.common.Config; import org.apache.doris.common.UserException; import com.google.common.collect.Lists; @@ -157,7 +158,7 @@ public final class HiveUtil { case TIMESTAMP: return DateLiteral.getDefaultDateType(Type.DATETIME); case DECIMAL: - return Type.DECIMALV2; + return Config.enable_decimalv3 ? Type.DECIMAL128 : Type.DECIMALV2; default: throw new UnsupportedOperationException("Unsupported type: " + primitiveTypeInfo.getPrimitiveCategory()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/external/iceberg/util/DorisTypeToType.java b/fe/fe-core/src/main/java/org/apache/doris/external/iceberg/util/DorisTypeToType.java index a5d2701ed4..531de582ea 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/external/iceberg/util/DorisTypeToType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/external/iceberg/util/DorisTypeToType.java @@ -99,7 +99,8 @@ public class DorisTypeToType extends DorisTypeVisitor { } else if (atomic.getPrimitiveType().equals(PrimitiveType.TIME) || atomic.getPrimitiveType().equals(PrimitiveType.TIMEV2)) { return Types.TimeType.get(); - } else if (atomic.getPrimitiveType().equals(PrimitiveType.DECIMALV2)) { + } else if (atomic.getPrimitiveType().equals(PrimitiveType.DECIMALV2) + || atomic.getPrimitiveType().isDecimalV3Type()) { return Types.DecimalType.of( ((ScalarType) atomic).getScalarPrecision(), ((ScalarType) atomic).getScalarScale()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/external/iceberg/util/TypeToDorisType.java b/fe/fe-core/src/main/java/org/apache/doris/external/iceberg/util/TypeToDorisType.java index c2eda3779e..43fe10aa03 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/external/iceberg/util/TypeToDorisType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/external/iceberg/util/TypeToDorisType.java @@ -77,7 +77,7 @@ public class TypeToDorisType extends TypeUtil.SchemaVisitor { return Type.DOUBLE; case DECIMAL: Types.DecimalType decimal = (Types.DecimalType) primitive; - return ScalarType.createDecimalV2Type(decimal.precision(), decimal.scale()); + return ScalarType.createDecimalType(decimal.precision(), decimal.scale()); case DATE: return DateLiteral.getDefaultDateType(Type.DATE); case TIMESTAMP: diff --git a/fe/fe-core/src/main/java/org/apache/doris/httpv2/rest/TableSchemaAction.java b/fe/fe-core/src/main/java/org/apache/doris/httpv2/rest/TableSchemaAction.java index beeca5739c..5b56b81e2d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/httpv2/rest/TableSchemaAction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/httpv2/rest/TableSchemaAction.java @@ -79,7 +79,7 @@ public class TableSchemaAction extends RestBaseController { Map baseInfo = new HashMap<>(2); Type colType = column.getOriginType(); PrimitiveType primitiveType = colType.getPrimitiveType(); - if (primitiveType == PrimitiveType.DECIMALV2) { + if (primitiveType == PrimitiveType.DECIMALV2 || primitiveType.isDecimalV3Type()) { ScalarType scalarType = (ScalarType) colType; baseInfo.put("precision", scalarType.getPrecision() + ""); baseInfo.put("scale", scalarType.getScalarScale() + ""); diff --git a/fe/fe-core/src/main/java/org/apache/doris/load/loadv2/SparkLoadPendingTask.java b/fe/fe-core/src/main/java/org/apache/doris/load/loadv2/SparkLoadPendingTask.java index ad56376c03..f543fbe3ca 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/load/loadv2/SparkLoadPendingTask.java +++ b/fe/fe-core/src/main/java/org/apache/doris/load/loadv2/SparkLoadPendingTask.java @@ -314,7 +314,7 @@ public class SparkLoadPendingTask extends LoadTask { // decimal precision scale int precision = 0; int scale = 0; - if (type.isDecimalV2Type()) { + if (type.isDecimalV2Type() || type.isDecimalV3Type()) { precision = column.getPrecision(); scale = column.getScale(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/OriginalPlanner.java b/fe/fe-core/src/main/java/org/apache/doris/planner/OriginalPlanner.java index b7bcc6b331..8397acc594 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/OriginalPlanner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/OriginalPlanner.java @@ -93,11 +93,13 @@ public class OriginalPlanner extends Planner { for (Expr expr : outputExprs) { List slotList = Lists.newArrayList(); expr.getIds(null, slotList); - if (PrimitiveType.DECIMALV2 != expr.getType().getPrimitiveType()) { + if ((!expr.getType().getPrimitiveType().isDecimalV2Type() + && expr.getType().getPrimitiveType().isDecimalV3Type())) { continue; } - if (PrimitiveType.DECIMALV2 != slotDesc.getType().getPrimitiveType()) { + if (!slotDesc.getType().getPrimitiveType().isDecimalV2Type() + && !slotDesc.getType().getPrimitiveType().isDecimalV3Type()) { continue; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/cache/PartitionRange.java b/fe/fe-core/src/main/java/org/apache/doris/qe/cache/PartitionRange.java index a482184a4e..1c5b5bb4a0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/cache/PartitionRange.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/cache/PartitionRange.java @@ -181,6 +181,9 @@ public class PartitionRange { case FLOAT: case DOUBLE: case DECIMALV2: + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: case CHAR: case VARCHAR: case STRING: diff --git a/fe/fe-core/src/main/java/org/apache/doris/rewrite/FEFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/rewrite/FEFunctions.java index 02d0b25922..e1d2ca5b06 100755 --- a/fe/fe-core/src/main/java/org/apache/doris/rewrite/FEFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/rewrite/FEFunctions.java @@ -602,6 +602,33 @@ public class FEFunctions { return new DecimalLiteral(result); } + @FEFunction(name = "add", argTypes = { "DECIMAL32", "DECIMAL32" }, returnType = "DECIMAL32") + public static DecimalLiteral addDecimal32(LiteralExpr first, LiteralExpr second) throws AnalysisException { + BigDecimal left = new BigDecimal(first.getStringValue()); + BigDecimal right = new BigDecimal(second.getStringValue()); + + BigDecimal result = left.add(right); + return new DecimalLiteral(result); + } + + @FEFunction(name = "add", argTypes = { "DECIMAL64", "DECIMAL64" }, returnType = "DECIMAL64") + public static DecimalLiteral addDecimal64(LiteralExpr first, LiteralExpr second) throws AnalysisException { + BigDecimal left = new BigDecimal(first.getStringValue()); + BigDecimal right = new BigDecimal(second.getStringValue()); + + BigDecimal result = left.add(right); + return new DecimalLiteral(result); + } + + @FEFunction(name = "add", argTypes = { "DECIMAL128", "DECIMAL128" }, returnType = "DECIMAL128") + public static DecimalLiteral addDecimal128(LiteralExpr first, LiteralExpr second) throws AnalysisException { + BigDecimal left = new BigDecimal(first.getStringValue()); + BigDecimal right = new BigDecimal(second.getStringValue()); + + BigDecimal result = left.add(right); + return new DecimalLiteral(result); + } + @FEFunction(name = "add", argTypes = { "LARGEINT", "LARGEINT" }, returnType = "LARGEINT") public static LargeIntLiteral addBigInt(LiteralExpr first, LiteralExpr second) throws AnalysisException { BigInteger left = new BigInteger(first.getStringValue()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRule.java b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRule.java index a18797b657..2f8b9603e9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRule.java @@ -27,9 +27,12 @@ import org.apache.doris.analysis.Expr; import org.apache.doris.analysis.IntLiteral; import org.apache.doris.analysis.LiteralExpr; import org.apache.doris.analysis.SlotRef; +import org.apache.doris.catalog.ScalarType; import org.apache.doris.catalog.Type; import org.apache.doris.common.AnalysisException; +import java.math.BigDecimal; + /** * Rewrite binary predicate. */ @@ -113,6 +116,56 @@ public class RewriteBinaryPredicatesRule implements ExprRewriteRule { } } + private Expr rewriteDecimalLiteral(Expr expr) { + BinaryPredicate.Operator op = ((BinaryPredicate) expr).getOp(); + Expr expr0 = expr.getChild(0); + Expr expr1 = expr.getChild(1); + if (expr1.getType().isDecimalV3() && expr1 instanceof DecimalLiteral) { + DecimalLiteral literal = (DecimalLiteral) expr1; + if (expr0.getType().isDecimalV3() + && ((ScalarType) expr0.getType()).getScalarScale() + < ((ScalarType) expr1.getType()).getScalarScale()) { + switch (op) { + case EQ: { + BigDecimal originValue = literal.getValue(); + literal.roundCeiling(((ScalarType) expr0.getType()).getScalarScale()); + if (literal.getValue().equals(originValue)) { + expr.setChild(1, literal); + return expr; + } else { + return new BoolLiteral(false); + } + } + case NE: { + BigDecimal originValue = literal.getValue(); + literal.roundCeiling(((ScalarType) expr0.getType()).getScalarScale()); + if (literal.getValue().equals(originValue)) { + expr.setChild(1, literal); + return expr; + } else { + return new BoolLiteral(true); + } + } + case GT: + case LE: { + literal.roundFloor(((ScalarType) expr0.getType()).getScalarScale()); + expr.setChild(1, literal); + return expr; + } + case LT: + case GE: { + literal.roundCeiling(((ScalarType) expr0.getType()).getScalarScale()); + expr.setChild(1, literal); + return expr; + } + default: + return expr; + } + } + } + return expr; + } + @Override public Expr apply(Expr expr, Analyzer analyzer, ExprRewriter.ClauseType clauseType) throws AnalysisException { if (!(expr instanceof BinaryPredicate)) { @@ -121,10 +174,11 @@ public class RewriteBinaryPredicatesRule implements ExprRewriteRule { BinaryPredicate.Operator op = ((BinaryPredicate) expr).getOp(); Expr expr0 = expr.getChild(0); Expr expr1 = expr.getChild(1); - if (expr0 instanceof CastExpr && expr0.getType() == Type.DECIMALV2 && expr0.getChild(0) instanceof SlotRef - && expr0.getChild(0).getType().getResultType() == Type.BIGINT && expr1 instanceof DecimalLiteral) { + if (expr0 instanceof CastExpr && (expr0.getType().isDecimalV2() || expr0.getType().isDecimalV3()) + && expr0.getChild(0) instanceof SlotRef && expr0.getChild(0).getType().getResultType() + == Type.BIGINT && expr1 instanceof DecimalLiteral) { return rewriteBigintSlotRefCompareDecimalLiteral(expr0, (DecimalLiteral) expr1, op); } - return expr; + return rewriteDecimalLiteral(expr); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteInPredicateRule.java b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteInPredicateRule.java index b37377e72a..f6f6b9bc3f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteInPredicateRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteInPredicateRule.java @@ -26,6 +26,7 @@ import org.apache.doris.analysis.SlotRef; import org.apache.doris.analysis.Subquery; import org.apache.doris.catalog.Type; import org.apache.doris.common.AnalysisException; +import org.apache.doris.common.Config; import org.apache.doris.rewrite.ExprRewriter.ClauseType; import com.google.common.collect.Lists; @@ -82,7 +83,8 @@ public class RewriteInPredicateRule implements ExprRewriteRule { // cannot be directly converted to LargeIntLiteral, so it is converted to decimal first. if (childExpr.getType().getPrimitiveType().isCharFamily() || childExpr.getType().isFloatingPointType()) { try { - childExpr = (LiteralExpr) childExpr.castTo(Type.DECIMALV2); + childExpr = (LiteralExpr) childExpr.castTo(Config.enable_decimalv3 + ? Type.DECIMAL32 : Type.DECIMALV2); } catch (AnalysisException e) { continue; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/statistics/ColumnStats.java b/fe/fe-core/src/main/java/org/apache/doris/statistics/ColumnStats.java index bb81e443ab..eab8bed080 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/statistics/ColumnStats.java +++ b/fe/fe-core/src/main/java/org/apache/doris/statistics/ColumnStats.java @@ -149,6 +149,9 @@ public class ColumnStats { case DOUBLE: return new FloatLiteral(columnValue); case DECIMALV2: + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: DecimalLiteral decimalLiteral = new DecimalLiteral(columnValue); decimalLiteral.checkPrecisionAndScale(scalarType.getScalarPrecision(), scalarType.getScalarScale()); return decimalLiteral; diff --git a/fe/fe-core/src/main/java/org/apache/doris/task/HadoopLoadPendingTask.java b/fe/fe-core/src/main/java/org/apache/doris/task/HadoopLoadPendingTask.java index 7e5b110403..00ee88a39a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/task/HadoopLoadPendingTask.java +++ b/fe/fe-core/src/main/java/org/apache/doris/task/HadoopLoadPendingTask.java @@ -553,6 +553,9 @@ public class HadoopLoadPendingTask extends LoadPendingTask { columnType = "QUANTILE_STATE"; break; case DECIMALV2: + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: columnType = "DECIMAL"; break; default: @@ -581,7 +584,7 @@ public class HadoopLoadPendingTask extends LoadPendingTask { } // decimal precision scale - if (type == PrimitiveType.DECIMALV2) { + if (type.isDecimalV2Type() || type.isDecimalV3Type()) { dppColumn.put("precision", column.getPrecision()); dppColumn.put("scale", column.getScale()); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/ArithmeticExprTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/ArithmeticExprTest.java new file mode 100644 index 0000000000..d985a8a72c --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/ArithmeticExprTest.java @@ -0,0 +1,175 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.analysis; + +import org.apache.doris.catalog.PrimitiveType; +import org.apache.doris.catalog.ScalarType; +import org.apache.doris.common.AnalysisException; +import org.apache.doris.common.util.VectorizedUtil; +import org.apache.doris.datasource.InternalDataSource; + +import mockit.Expectations; +import mockit.Mocked; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; + +public class ArithmeticExprTest { + private static final String internalCtl = InternalDataSource.INTERNAL_DS_NAME; + + @Test + public void testDecimalArithmetic(@Mocked VectorizedUtil vectorizedUtil) { + Expr lhsExpr = new SlotRef(new TableName(internalCtl, "db", "table"), "c0"); + Expr rhsExpr = new SlotRef(new TableName(internalCtl, "db", "table"), "c1"); + ScalarType t1; + ScalarType t2; + ScalarType res; + ArithmeticExpr arithmeticExpr; + boolean hasException = false; + new Expectations() { + { + vectorizedUtil.isVectorized(); + result = true; + } + }; + List operators = Arrays.asList(ArithmeticExpr.Operator.ADD, + ArithmeticExpr.Operator.SUBTRACT, ArithmeticExpr.Operator.MOD, + ArithmeticExpr.Operator.MULTIPLY, ArithmeticExpr.Operator.DIVIDE); + try { + for (ArithmeticExpr.Operator operator : operators) { + t1 = ScalarType.createDecimalType(9, 4); + t2 = ScalarType.createDecimalType(19, 6); + lhsExpr.setType(t1); + rhsExpr.setType(t2); + arithmeticExpr = new ArithmeticExpr(operator, lhsExpr, rhsExpr); + res = ScalarType.createDecimalType(38, 6); + if (operator == ArithmeticExpr.Operator.MULTIPLY) { + res = ScalarType.createDecimalType(38, 10); + } + if (operator == ArithmeticExpr.Operator.DIVIDE) { + res = ScalarType.createDecimalType(38, 4); + } + arithmeticExpr.analyzeImpl(null); + Assert.assertEquals(arithmeticExpr.type, res); + + t1 = ScalarType.createDecimalType(9, 4); + t2 = ScalarType.createDecimalType(18, 5); + lhsExpr.setType(t1); + rhsExpr.setType(t2); + arithmeticExpr = new ArithmeticExpr(operator, lhsExpr, rhsExpr); + res = ScalarType.createDecimalType(18, 5); + if (operator == ArithmeticExpr.Operator.MULTIPLY) { + res = ScalarType.createDecimalType(18, 9); + } + if (operator == ArithmeticExpr.Operator.DIVIDE) { + res = ScalarType.createDecimalType(18, 4); + } + arithmeticExpr.analyzeImpl(null); + Assert.assertEquals(arithmeticExpr.type, res); + + t1 = ScalarType.createDecimalType(9, 4); + t2 = ScalarType.createType(PrimitiveType.BIGINT); + lhsExpr.setType(t1); + rhsExpr.setType(t2); + arithmeticExpr = new ArithmeticExpr(operator, lhsExpr, rhsExpr); + res = ScalarType.createDecimalType(18, 4); + arithmeticExpr.analyzeImpl(null); + Assert.assertEquals(arithmeticExpr.type, res); + + t1 = ScalarType.createDecimalType(9, 4); + t2 = ScalarType.createType(PrimitiveType.LARGEINT); + lhsExpr.setType(t1); + rhsExpr.setType(t2); + arithmeticExpr = new ArithmeticExpr(operator, lhsExpr, rhsExpr); + res = ScalarType.createDecimalType(38, 4); + arithmeticExpr.analyzeImpl(null); + Assert.assertEquals(arithmeticExpr.type, res); + + t1 = ScalarType.createDecimalType(9, 4); + t2 = ScalarType.createType(PrimitiveType.INT); + lhsExpr.setType(t1); + rhsExpr.setType(t2); + arithmeticExpr = new ArithmeticExpr(operator, lhsExpr, rhsExpr); + res = ScalarType.createDecimalType(9, 4); + arithmeticExpr.analyzeImpl(null); + Assert.assertEquals(arithmeticExpr.type, res); + + t1 = ScalarType.createDecimalType(9, 4); + t2 = ScalarType.createType(PrimitiveType.FLOAT); + lhsExpr.setType(t1); + rhsExpr.setType(t2); + arithmeticExpr = new ArithmeticExpr(operator, lhsExpr, rhsExpr); + res = ScalarType.createType(PrimitiveType.DOUBLE); + arithmeticExpr.analyzeImpl(null); + Assert.assertEquals(arithmeticExpr.type, res); + + t1 = ScalarType.createDecimalType(9, 4); + t2 = ScalarType.createType(PrimitiveType.DOUBLE); + lhsExpr.setType(t1); + rhsExpr.setType(t2); + arithmeticExpr = new ArithmeticExpr(operator, lhsExpr, rhsExpr); + res = ScalarType.createType(PrimitiveType.DOUBLE); + arithmeticExpr.analyzeImpl(null); + Assert.assertEquals(arithmeticExpr.type, res); + } + } catch (AnalysisException e) { + e.printStackTrace(); + hasException = true; + } + Assert.assertFalse(hasException); + } + + @Test + public void testDecimalBitOperation(@Mocked VectorizedUtil vectorizedUtil) { + Expr lhsExpr = new SlotRef(new TableName(internalCtl, "db", "table"), "c0"); + Expr rhsExpr = new SlotRef(new TableName(internalCtl, "db", "table"), "c1"); + ScalarType t1; + ScalarType t2; + ScalarType res; + ArithmeticExpr arithmeticExpr; + boolean hasException = false; + new Expectations() { + { + vectorizedUtil.isVectorized(); + result = true; + } + }; + List operators = Arrays.asList(ArithmeticExpr.Operator.BITAND, + ArithmeticExpr.Operator.BITOR, ArithmeticExpr.Operator.BITXOR); + try { + for (ArithmeticExpr.Operator operator : operators) { + t1 = ScalarType.createDecimalType(9, 4); + t2 = ScalarType.createDecimalType(19, 6); + lhsExpr.setType(t1); + rhsExpr.setType(t2); + arithmeticExpr = new ArithmeticExpr(operator, lhsExpr, rhsExpr); + res = ScalarType.createType(PrimitiveType.BIGINT); + arithmeticExpr.analyzeImpl(null); + Assert.assertEquals(arithmeticExpr.type, res); + Assert.assertTrue(arithmeticExpr.getChild(0) instanceof CastExpr); + Assert.assertTrue(arithmeticExpr.getChild(1) instanceof CastExpr); + } + } catch (AnalysisException e) { + e.printStackTrace(); + hasException = true; + } + Assert.assertFalse(hasException); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/CreateMaterializedViewStmtTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/CreateMaterializedViewStmtTest.java index 0ab19f5d1e..ce4a9c029d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/analysis/CreateMaterializedViewStmtTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/CreateMaterializedViewStmtTest.java @@ -1255,7 +1255,7 @@ public class CreateMaterializedViewStmtTest { slotDescriptor2.getColumn(); result = column2; column2.getOriginType(); - result = ScalarType.createDecimalV2Type(10, 1); + result = ScalarType.createDecimalType(10, 1); } }; MVColumnItem mvColumnItem2 = Deencapsulation.invoke(createMaterializedViewStmt, "buildMVColumnItem", functionCallExpr2); diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/DecimalLiteralTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/DecimalLiteralTest.java index 4d4f0f3c24..3fa2dfe4d1 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/analysis/DecimalLiteralTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/DecimalLiteralTest.java @@ -18,6 +18,7 @@ package org.apache.doris.analysis; import org.apache.doris.catalog.PrimitiveType; +import org.apache.doris.catalog.ScalarType; import org.apache.doris.catalog.Type; import org.apache.doris.common.AnalysisException; @@ -48,4 +49,35 @@ public class DecimalLiteralTest { Assert.assertEquals(1, literal.compareLiteral(new NullLiteral())); } + + @Test + public void testTypePrecision() { + BigDecimal decimal = new BigDecimal("-123456789123456789.123456789"); + DecimalLiteral literal = new DecimalLiteral(decimal); + int precision = ((ScalarType) literal.getType()).getScalarPrecision(); + int scale = ((ScalarType) literal.getType()).getScalarScale(); + Assert.assertEquals(27, precision); + Assert.assertEquals(9, scale); + + decimal = new BigDecimal("-0.00123"); + literal = new DecimalLiteral(decimal); + precision = ((ScalarType) literal.getType()).getScalarPrecision(); + scale = ((ScalarType) literal.getType()).getScalarScale(); + Assert.assertEquals(5, precision); + Assert.assertEquals(5, scale); + + decimal = new BigDecimal("20000"); + literal = new DecimalLiteral(decimal); + precision = ((ScalarType) literal.getType()).getScalarPrecision(); + scale = ((ScalarType) literal.getType()).getScalarScale(); + Assert.assertEquals(5, precision); + Assert.assertEquals(0, scale); + + decimal = new BigDecimal("0.123"); + literal = new DecimalLiteral(decimal); + precision = ((ScalarType) literal.getType()).getScalarPrecision(); + scale = ((ScalarType) literal.getType()).getScalarScale(); + Assert.assertEquals(3, precision); + Assert.assertEquals(3, scale); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/FunctionCallExprTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/FunctionCallExprTest.java new file mode 100644 index 0000000000..6d8cbe7433 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/FunctionCallExprTest.java @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.analysis; + +import org.apache.doris.catalog.ScalarType; +import org.apache.doris.catalog.Type; +import org.apache.doris.common.AnalysisException; +import org.apache.doris.datasource.InternalDataSource; + +import com.google.common.collect.ImmutableList; +import mockit.Mock; +import mockit.MockUp; +import mockit.Mocked; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Arrays; + +public class FunctionCallExprTest { + private static final String internalCtl = InternalDataSource.INTERNAL_DS_NAME; + + @Test + public void testDecimalFunction(@Mocked Analyzer analyzer) throws AnalysisException { + new MockUp(SlotRef.class) { + @Mock + public void analyzeImpl(Analyzer analyzer) throws AnalysisException { + return; + } + }; + Expr argExpr = new SlotRef(new TableName(internalCtl, "db", "table"), "c0"); + FunctionCallExpr functionCallExpr; + boolean hasException = false; + Type res; + ImmutableList sameTypeFunction = ImmutableList.builder() + .add("min").add("max").add("lead").add("lag") + .add("first_value").add("last_value").add("abs") + .add("positive").add("negative").build(); + ImmutableList widerTypeFunction = ImmutableList.builder() + .add("sum").add("avg").add("multi_distinct_sum").build(); + try { + for (String func : sameTypeFunction) { + Type argType = ScalarType.createDecimalType(9, 4); + argExpr.setType(argType); + functionCallExpr = new FunctionCallExpr(func, Arrays.asList(argExpr)); + functionCallExpr.setIsAnalyticFnCall(true); + res = ScalarType.createDecimalType(9, 4); + functionCallExpr.analyzeImpl(analyzer); + Assert.assertEquals(functionCallExpr.type, res); + } + + for (String func : widerTypeFunction) { + Type argType = ScalarType.createDecimalType(9, 4); + argExpr.setType(argType); + functionCallExpr = new FunctionCallExpr(func, Arrays.asList(argExpr)); + res = ScalarType.createDecimalType(38, 4); + functionCallExpr.analyzeImpl(analyzer); + Assert.assertEquals(functionCallExpr.type, res); + } + + } catch (AnalysisException e) { + e.printStackTrace(); + hasException = true; + } + Assert.assertFalse(hasException); + } + +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/backup/CatalogMocker.java b/fe/fe-core/src/test/java/org/apache/doris/backup/CatalogMocker.java index 94e6f7cbe9..a04377cd8c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/backup/CatalogMocker.java +++ b/fe/fe-core/src/test/java/org/apache/doris/backup/CatalogMocker.java @@ -154,7 +154,7 @@ public class CatalogMocker { Column k5 = new Column("k5", ScalarType.createType(PrimitiveType.LARGEINT), true, null, "", "key5"); Column k6 = new Column("k6", ScalarType.createType(PrimitiveType.DATE), true, null, "", "key6"); Column k7 = new Column("k7", ScalarType.createType(PrimitiveType.DATETIME), true, null, "", "key7"); - Column k8 = new Column("k8", ScalarType.createDecimalV2Type(10, 3), true, null, "", "key8"); + Column k8 = new Column("k8", ScalarType.createDecimalType(10, 3), true, null, "", "key8"); k1.setIsKey(true); k2.setIsKey(true); k3.setIsKey(true); diff --git a/fe/fe-core/src/test/java/org/apache/doris/catalog/ColumnGsonSerializationTest.java b/fe/fe-core/src/test/java/org/apache/doris/catalog/ColumnGsonSerializationTest.java index 9b0bf33a52..fae2d6b274 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/catalog/ColumnGsonSerializationTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/catalog/ColumnGsonSerializationTest.java @@ -96,7 +96,7 @@ public class ColumnGsonSerializationTest { Column c1 = new Column("c1", Type.fromPrimitiveType(PrimitiveType.BIGINT), true, null, true, "1", "abc"); Column c2 = new Column("c2", ScalarType.createType(PrimitiveType.VARCHAR, 32, -1, -1), true, null, true, "cmy", ""); - Column c3 = new Column("c3", ScalarType.createDecimalV2Type(27, 9), false, AggregateType.SUM, false, "1.1", "decimalv2"); + Column c3 = new Column("c3", ScalarType.createDecimalType(27, 9), false, AggregateType.SUM, false, "1.1", "decimalv2"); ColumnList columnList = new ColumnList(); columnList.columns.add(c1); diff --git a/fe/fe-core/src/test/java/org/apache/doris/catalog/ColumnTypeTest.java b/fe/fe-core/src/test/java/org/apache/doris/catalog/ColumnTypeTest.java index 9130a540f6..98cb7b3b92 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/catalog/ColumnTypeTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/catalog/ColumnTypeTest.java @@ -217,10 +217,10 @@ public class ColumnTypeTest { ScalarType type2 = ScalarType.createType(PrimitiveType.BIGINT); ColumnType.write(dos, type2); - ScalarType type3 = ScalarType.createDecimalV2Type(1, 1); + ScalarType type3 = ScalarType.createDecimalType(1, 1); ColumnType.write(dos, type3); - ScalarType type4 = ScalarType.createDecimalV2Type(1, 1); + ScalarType type4 = ScalarType.createDecimalType(1, 1); ColumnType.write(dos, type4); // 2. Read objects from file diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java index 616416ef3a..537aa9aa0a 100644 --- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java +++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java @@ -76,7 +76,7 @@ public class UdfUtils { } else if (scalarType.getType() == TPrimitiveType.DECIMALV2) { Preconditions.checkState(scalarType.isSetPrecision() && scalarType.isSetScale()); - type = ScalarType.createDecimalV2Type(scalarType.getPrecision(), + type = ScalarType.createDecimalType(scalarType.getPrecision(), scalarType.getScale()); } else { type = ScalarType.createType( diff --git a/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/ColumnParser.java b/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/ColumnParser.java index fff375b6a5..2d15ab88a3 100644 --- a/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/ColumnParser.java +++ b/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/ColumnParser.java @@ -63,7 +63,10 @@ public abstract class ColumnParser implements Serializable { || columnType.equalsIgnoreCase("BITMAP") || columnType.equalsIgnoreCase("HLL")) { return new StringParser(etlColumn); - } else if (columnType.equalsIgnoreCase("DECIMALV2")) { + } else if (columnType.equalsIgnoreCase("DECIMALV2") + || columnType.equalsIgnoreCase("DECIMAL32") + || columnType.equalsIgnoreCase("DECIMAL64") + || columnType.equalsIgnoreCase("DECIMAL128")) { return new DecimalParser(etlColumn); } else if (columnType.equalsIgnoreCase("LARGEINT")) { return new LargeIntParser(); diff --git a/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/DppUtils.java b/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/DppUtils.java index caf79c1a2b..c53e1ae087 100644 --- a/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/DppUtils.java +++ b/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/DppUtils.java @@ -98,6 +98,9 @@ public class DppUtils { case "OBJECT": return String.class; case "DECIMALV2": + case "DECIMAL32": + case "DECIMAL64": + case "DECIMAL128": return BigDecimal.valueOf(column.precision, column.scale).getClass(); default: return String.class; @@ -147,6 +150,9 @@ public class DppUtils { dataType = regardDistinctColumnAsBinary ? DataTypes.BinaryType : DataTypes.StringType; break; case "DECIMALV2": + case "DECIMAL32": + case "DECIMAL64": + case "DECIMAL128": dataType = DecimalType.apply(column.precision, column.scale); break; default: diff --git a/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/SparkDpp.java b/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/SparkDpp.java index ed1cbe73a7..ab7e791a1c 100644 --- a/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/SparkDpp.java +++ b/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/SparkDpp.java @@ -379,6 +379,9 @@ public final class SparkDpp implements java.io.Serializable { switch (etlColumn.columnType.toUpperCase()) { case "DECIMALV2": + case "DECIMAL32": + case "DECIMAL64": + case "DECIMAL128": // TODO(wb): support decimal round; see be DecimalV2Value::round DecimalParser decimalParser = (DecimalParser) columnParser; BigDecimal srcBigDecimal = (BigDecimal) srcValue; diff --git a/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/SparkRDDAggregator.java b/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/SparkRDDAggregator.java index 9e3f31dfc7..3e374d9938 100644 --- a/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/SparkRDDAggregator.java +++ b/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/SparkRDDAggregator.java @@ -73,6 +73,9 @@ public abstract class SparkRDDAggregator implements Serializable { case "float": case "double": case "decimalv2": + case "decimal32": + case "decimal64": + case "decimal128": case "date": case "datetime": case "datev2": @@ -96,6 +99,9 @@ public abstract class SparkRDDAggregator implements Serializable { case "float": case "double": case "decimalv2": + case "decimal32": + case "decimal64": + case "decimal128": case "date": case "datetime": case "datev2": @@ -127,6 +133,9 @@ public abstract class SparkRDDAggregator implements Serializable { case "largeint": return new LargeIntSumAggregator(); case "decimalv2": + case "decimal32": + case "decimal64": + case "decimal128": return new BigDecimalSumAggregator(); default: throw new SparkDppException( diff --git a/gensrc/proto/internal_service.proto b/gensrc/proto/internal_service.proto index f7013880d4..28fb5c4aec 100644 --- a/gensrc/proto/internal_service.proto +++ b/gensrc/proto/internal_service.proto @@ -359,6 +359,8 @@ message PColumnValue { optional int64 longVal = 3; optional double doubleVal = 4; optional bytes stringVal = 5; + optional int32 precision = 6; + optional int32 scale = 7; } // TODO: CHECK ALL TYPE @@ -379,6 +381,9 @@ enum PColumnType { COLUMN_TYPE_DECIMALV2 = 13; COLUMN_TYPE_STRING = 14; COLUMN_TYPE_DATEV2 = 15; + COLUMN_TYPE_DECIMAL32 = 16; + COLUMN_TYPE_DECIMAL64 = 17; + COLUMN_TYPE_DECIMAL128 = 18; } message PMinMaxFilter { diff --git a/gensrc/thrift/Types.thrift b/gensrc/thrift/Types.thrift index a7212e0476..3913f30376 100644 --- a/gensrc/thrift/Types.thrift +++ b/gensrc/thrift/Types.thrift @@ -87,6 +87,9 @@ enum TPrimitiveType { DATEV2, DATETIMEV2, TIMEV2, + DECIMAL32, + DECIMAL64, + DECIMAL128, } enum TTypeNodeType {