diff --git a/be/src/vec/columns/column.h b/be/src/vec/columns/column.h index b0de27bb6e..c7e9e6f1e1 100644 --- a/be/src/vec/columns/column.h +++ b/be/src/vec/columns/column.h @@ -215,6 +215,7 @@ public: /// Appends range of elements from other column with the same type. /// Could be used to concatenate columns. + /// TODO: we need `insert_range_from_const` for every column type. virtual void insert_range_from(const IColumn& src, size_t start, size_t length) = 0; /// Appends one element from other column with the same type multiple times. diff --git a/be/src/vec/columns/column_const.cpp b/be/src/vec/columns/column_const.cpp index 7b7fc4d753..96d0c013b1 100644 --- a/be/src/vec/columns/column_const.cpp +++ b/be/src/vec/columns/column_const.cpp @@ -20,6 +20,8 @@ #include "vec/columns/column_const.h" +#include +#include #include #include "gutil/port.h" @@ -209,4 +211,37 @@ std::pair check_column_const_set_readability(const IColumn& c } return result; } + +std::pair unpack_if_const(const ColumnPtr& ptr) noexcept { + if (is_column_const(*ptr)) { + return std::make_pair( + std::cref(static_cast(*ptr).get_data_column_ptr()), true); + } + return std::make_pair(std::cref(ptr), false); +} + +void default_preprocess_parameter_columns(ColumnPtr* columns, const bool* col_const, + const std::initializer_list& parameters, + Block& block, const ColumnNumbers& arg_indexes) noexcept { + if (std::all_of(parameters.begin(), parameters.end(), + [&](size_t const_index) -> bool { return col_const[const_index]; })) { + // every parameter column is const: take the nested data columns directly, no expansion needed + for (auto index : parameters) { + columns[index] = static_cast( + *block.get_by_position(arg_indexes[index]).column) + .get_data_column_ptr(); + } + } else { + // mixed const / non-const parameters (rare): expand each const one to a full column + for (auto index : parameters) { + if (col_const[index]) { + columns[index] = static_cast( + *block.get_by_position(arg_indexes[index]).column) + .convert_to_full_column(); + } else { + columns[index] 
= block.get_by_position(arg_indexes[index]).column; + } + } + } +} } // namespace doris::vectorized diff --git a/be/src/vec/columns/column_const.h b/be/src/vec/columns/column_const.h index 357f10dcee..3ab11d6768 100644 --- a/be/src/vec/columns/column_const.h +++ b/be/src/vec/columns/column_const.h @@ -21,11 +21,15 @@ #pragma once #include +#include +#include #include "vec/columns/column.h" #include "vec/columns/column_nullable.h" #include "vec/common/assert_cast.h" #include "vec/common/typeid_cast.h" +#include "vec/core/block.h" +#include "vec/core/column_numbers.h" #include "vec/core/field.h" namespace doris::vectorized { @@ -234,4 +238,29 @@ public: */ std::pair check_column_const_set_readability(const IColumn& column, const size_t row_num) noexcept; + +/* + * @warning using this function can sometimes cause performance problems with GCC. +*/ +template , T> = 0> +T index_check_const(T arg, bool constancy) noexcept { + return constancy ? 0 : arg; +} + +/* + * @return first : the data_column_ptr for a ColumnConst, the column itself otherwise. + * second : whether the column is a ColumnConst. +*/ +std::pair unpack_if_const(const ColumnPtr&) noexcept; + +/* + * For functions whose parameter columns (which are not data columns) are almost, but not quite, always const: use this function to preprocess those + * parameter columns. When two or more columns only provide parameters, it deals with the corner case, so you can specialize your + * implementations for "all const" or "all parameters const" without considering the case where only some parameters are const. + + * The transformation is applied only to the columns whose argument indexes are listed in parameters. 
+*/ +void default_preprocess_parameter_columns(ColumnPtr* columns, const bool* col_const, + const std::initializer_list& parameters, + Block& block, const ColumnNumbers& arg_indexes) noexcept; } // namespace doris::vectorized diff --git a/be/src/vec/columns/column_nullable.cpp b/be/src/vec/columns/column_nullable.cpp index 92fb114a4e..64db3e6455 100644 --- a/be/src/vec/columns/column_nullable.cpp +++ b/be/src/vec/columns/column_nullable.cpp @@ -27,6 +27,7 @@ #include "vec/common/nan_utils.h" #include "vec/common/typeid_cast.h" #include "vec/core/sort_block.h" +#include "vec/utils/util.hpp" namespace doris::vectorized { @@ -633,4 +634,14 @@ ColumnPtr ColumnNullable::index(const IColumn& indexes, size_t limit) const { return ColumnNullable::create(indexed_data, indexed_null_map); } +void check_set_nullable(ColumnPtr& argument_column, ColumnVector::MutablePtr& null_map) { + if (auto* nullable = check_and_get_column(*argument_column)) { + // Danger: the null map must be merged first! Reassigning argument_column to + // nullable->get_nested_column_ptr() below may release the nullable column, and with it + // the null map memory that `nullable` still refers to. + VectorizedUtils::update_null_map(null_map->get_data(), nullable->get_null_map_data()); + argument_column = nullable->get_nested_column_ptr(); + } +} + } // namespace doris::vectorized diff --git a/be/src/vec/columns/column_nullable.h b/be/src/vec/columns/column_nullable.h index 0610b139eb..05b0516cda 100644 --- a/be/src/vec/columns/column_nullable.h +++ b/be/src/vec/columns/column_nullable.h @@ -20,6 +20,8 @@ #pragma once +#include "vec/columns/column_vector.h" +#include "vec/core/types.h" #ifdef __aarch64__ #include #endif @@ -368,5 +370,7 @@ private: ColumnPtr make_nullable(const ColumnPtr& column, bool is_nullable = false); ColumnPtr remove_nullable(const ColumnPtr& column); - +// Checks whether the argument column is nullable. If so, merges its null map into null_map and replaces the argument with its nested column. +// TODO: use this to replace inner usages (existing inline copies of this logic). 
+void check_set_nullable(ColumnPtr&, ColumnVector::MutablePtr&); } // namespace doris::vectorized diff --git a/be/src/vec/common/sort/sorter.cpp b/be/src/vec/common/sort/sorter.cpp index c8eebed652..b5315a1489 100644 --- a/be/src/vec/common/sort/sorter.cpp +++ b/be/src/vec/common/sort/sorter.cpp @@ -306,8 +306,9 @@ Status FullSorter::append_block(Block* block) { << " type1: " << data[i].type->get_name() << " type2: " << arrival_data[i].type->get_name(); try { + //TODO: to eliminate unnecessary expansion, we need a `insert_range_from_const` for every column type. RETURN_IF_CATCH_BAD_ALLOC(data[i].column->assume_mutable()->insert_range_from( - *arrival_data[i].column->convert_to_full_column_if_const().get(), 0, sz)); + *arrival_data[i].column->convert_to_full_column_if_const(), 0, sz)); } catch (const doris::Exception& e) { return Status::Error(e.code(), e.to_string()); } diff --git a/be/src/vec/common/string_ref.cpp b/be/src/vec/common/string_ref.cpp index 20cfbad708..d292b8160f 100644 --- a/be/src/vec/common/string_ref.cpp +++ b/be/src/vec/common/string_ref.cpp @@ -20,6 +20,8 @@ #include "string_ref.h" +#include "common/compiler_util.h" + namespace doris { StringRef StringRef::trim() const { @@ -53,6 +55,19 @@ StringRef StringRef::max_string_val() { return StringRef((char*)(&StringRef::MAX_CHAR), 1); } +bool StringRef::start_with(char ch) const { + if (UNLIKELY(size == 0)) { + return false; + } + return data[0] == ch; +} +bool StringRef::end_with(char ch) const { + if (UNLIKELY(size == 0)) { + return false; + } + return data[size - 1] == ch; +} + bool StringRef::start_with(const StringRef& search_string) const { DCHECK(size >= search_string.size); if (search_string.size == 0) { diff --git a/be/src/vec/common/string_ref.h b/be/src/vec/common/string_ref.h index 7561b33c3c..9ecb1cfcf0 100644 --- a/be/src/vec/common/string_ref.h +++ b/be/src/vec/common/string_ref.h @@ -225,6 +225,12 @@ struct StringRef { StringRef substring(int start_pos) const { return 
substring(start_pos, size - start_pos); } + const char* begin() const { return data; } + const char* end() const { return data + size; } + // there's no border check in functions below. That's same with STL. + char front() const { return *data; } + char back() const { return *(data + size - 1); } + // Trims leading and trailing spaces. StringRef trim() const; @@ -237,6 +243,8 @@ struct StringRef { static StringRef min_string_val(); static StringRef max_string_val(); + bool start_with(char) const; + bool end_with(char) const; bool start_with(const StringRef& search_string) const; bool end_with(const StringRef& search_string) const; diff --git a/be/src/vec/functions/array/function_array_apply.cpp b/be/src/vec/functions/array/function_array_apply.cpp index 723dfc143a..bb81f04fda 100644 --- a/be/src/vec/functions/array/function_array_apply.cpp +++ b/be/src/vec/functions/array/function_array_apply.cpp @@ -37,6 +37,7 @@ public: String get_name() const override { return name; } size_t get_number_of_arguments() const override { return 3; } + ColumnNumbers get_arguments_that_are_always_constant() const override { return {1, 2}; } DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { DCHECK(is_array(arguments[0])) @@ -63,10 +64,7 @@ public: auto nested_type = assert_cast(*src_column_type).get_nested_type(); const std::string& condition = block.get_by_position(arguments[1]).column->get_data_at(0).to_string(); - if (!is_column_const(*block.get_by_position(arguments[2]).column)) { - return Status::RuntimeError( - "execute failed or unsupported column, only support const column"); - } + const ColumnConst& rhs_value_column = static_cast(*block.get_by_position(arguments[2]).column.get()); ColumnPtr result_ptr; @@ -137,7 +135,9 @@ private: size_t out_pos = 0; for (size_t i = 0; i < src_offsets.size(); ++i) { for (; in_pos < src_offsets[i]; ++in_pos) { - if (filter[in_pos]) ++out_pos; + if (filter[in_pos]) { + ++out_pos; + } } dst_offsets[i] = out_pos; } diff 
--git a/be/src/vec/functions/function_bitmap.cpp b/be/src/vec/functions/function_bitmap.cpp index 01991f1b2f..d6cf1652db 100644 --- a/be/src/vec/functions/function_bitmap.cpp +++ b/be/src/vec/functions/function_bitmap.cpp @@ -444,13 +444,26 @@ struct BitmapNot { using T1 = typename RightDataType::FieldType; using TData = std::vector; - static Status vector_vector(const TData& lvec, const TData& rvec, TData& res) { + static void vector_vector(const TData& lvec, const TData& rvec, TData& res) { size_t size = lvec.size(); for (size_t i = 0; i < size; ++i) { res[i] = lvec[i]; res[i] -= rvec[i]; } - return Status::OK(); + } + static void vector_scalar(const TData& lvec, const BitmapValue& rval, TData& res) { + size_t size = lvec.size(); + for (size_t i = 0; i < size; ++i) { + res[i] = lvec[i]; + res[i] -= rval; + } + } + static void scalar_vector(const BitmapValue& lval, const TData& rvec, TData& res) { + size_t size = rvec.size(); + for (size_t i = 0; i < size; ++i) { + res[i] = lval; + res[i] -= rvec[i]; + } } }; @@ -465,7 +478,7 @@ struct BitmapAndNot { using T1 = typename RightDataType::FieldType; using TData = std::vector; - static Status vector_vector(const TData& lvec, const TData& rvec, TData& res) { + static void vector_vector(const TData& lvec, const TData& rvec, TData& res) { size_t size = lvec.size(); BitmapValue mid_data; for (size_t i = 0; i < size; ++i) { @@ -475,7 +488,28 @@ struct BitmapAndNot { res[i] -= mid_data; mid_data.clear(); } - return Status::OK(); + } + static void vector_scalar(const TData& lvec, const BitmapValue& rval, TData& res) { + size_t size = lvec.size(); + BitmapValue mid_data; + for (size_t i = 0; i < size; ++i) { + mid_data = lvec[i]; + mid_data &= rval; + res[i] = lvec[i]; + res[i] -= mid_data; + mid_data.clear(); + } + } + static void scalar_vector(const BitmapValue& lval, const TData& rvec, TData& res) { + size_t size = rvec.size(); + BitmapValue mid_data; + for (size_t i = 0; i < size; ++i) { + mid_data = lval; + mid_data &= 
rvec[i]; + res[i] = lval; + res[i] -= mid_data; + mid_data.clear(); + } } }; @@ -491,7 +525,7 @@ struct BitmapAndNotCount { using TData = std::vector; using ResTData = typename ColumnVector::Container::value_type; - static Status vector_vector(const TData& lvec, const TData& rvec, ResTData* res) { + static void vector_vector(const TData& lvec, const TData& rvec, ResTData* res) { size_t size = lvec.size(); BitmapValue mid_data; for (size_t i = 0; i < size; ++i) { @@ -500,7 +534,26 @@ struct BitmapAndNotCount { res[i] = lvec[i].andnot_cardinality(mid_data); mid_data.clear(); } - return Status::OK(); + } + static void scalar_vector(const BitmapValue& lval, const TData& rvec, ResTData* res) { + size_t size = rvec.size(); + BitmapValue mid_data; + for (size_t i = 0; i < size; ++i) { + mid_data = lval; + mid_data &= rvec[i]; + res[i] = lval.andnot_cardinality(mid_data); + mid_data.clear(); + } + } + static void vector_scalar(const TData& lvec, const BitmapValue& rval, ResTData* res) { + size_t size = lvec.size(); + BitmapValue mid_data; + for (size_t i = 0; i < size; ++i) { + mid_data = lvec[i]; + mid_data &= rval; + res[i] = lvec[i].andnot_cardinality(mid_data); + mid_data.clear(); + } } }; @@ -622,11 +675,6 @@ public: Status execute_impl_internal(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) { - const auto& left = block.get_by_position(arguments[0]); - auto lcol = left.column->convert_to_full_column_if_const(); - const auto& right = block.get_by_position(arguments[1]); - auto rcol = right.column->convert_to_full_column_if_const(); - using ResultType = typename ResultDataType::FieldType; using ColVecResult = ColumnVector; @@ -634,10 +682,29 @@ public: auto& vec_res = col_res->get_data(); vec_res.resize(block.rows()); - const ColumnBitmap* l_bitmap_col = assert_cast(lcol.get()); - const ColumnBitmap* r_bitmap_col = assert_cast(rcol.get()); - BitmapAndNotCount::vector_vector( - l_bitmap_col->get_data(), 
r_bitmap_col->get_data(), vec_res.data()); + const auto& left = block.get_by_position(arguments[0]); + auto lcol = left.column; + const auto& right = block.get_by_position(arguments[1]); + auto rcol = right.column; + + if (is_column_const(*left.column)) { + BitmapAndNotCount::scalar_vector( + assert_cast( + assert_cast(lcol.get())->get_data_column()) + .get_data()[0], + assert_cast(rcol.get())->get_data(), vec_res.data()); + } else if (is_column_const(*right.column)) { + BitmapAndNotCount::vector_scalar( + assert_cast(lcol.get())->get_data(), + assert_cast( + assert_cast(rcol.get())->get_data_column()) + .get_data()[0], + vec_res.data()); + } else { + BitmapAndNotCount::vector_vector( + assert_cast(lcol.get())->get_data(), + assert_cast(rcol.get())->get_data(), vec_res.data()); + } auto& result_info = block.get_by_position(result); if (result_info.type->is_nullable()) { @@ -664,12 +731,23 @@ struct BitmapContains { using RTData = typename ColumnVector::Container; using ResTData = typename ColumnVector::Container; - static Status vector_vector(const LTData& lvec, const RTData& rvec, ResTData& res) { + static void vector_vector(const LTData& lvec, const RTData& rvec, ResTData& res) { size_t size = lvec.size(); for (size_t i = 0; i < size; ++i) { res[i] = lvec[i].contains(rvec[i]); } - return Status::OK(); + } + static void vector_scalar(const LTData& lvec, const T1& rval, ResTData& res) { + size_t size = lvec.size(); + for (size_t i = 0; i < size; ++i) { + res[i] = lvec[i].contains(rval); + } + } + static void scalar_vector(const BitmapValue& lval, const RTData& rvec, ResTData& res) { + size_t size = rvec.size(); + for (size_t i = 0; i < size; ++i) { + res[i] = lval.contains(rvec[i]); + } } }; @@ -685,14 +763,29 @@ struct BitmapHasAny { using TData = std::vector; using ResTData = typename ColumnVector::Container; - static Status vector_vector(const TData& lvec, const TData& rvec, ResTData& res) { + static void vector_vector(const TData& lvec, const TData& rvec, 
ResTData& res) { size_t size = lvec.size(); for (size_t i = 0; i < size; ++i) { auto bitmap = const_cast(lvec[i]); bitmap &= rvec[i]; res[i] = bitmap.cardinality() != 0; } - return Status::OK(); + } + static void vector_scalar(const TData& lvec, const BitmapValue& rval, ResTData& res) { + size_t size = lvec.size(); + for (size_t i = 0; i < size; ++i) { + auto bitmap = const_cast(lvec[i]); + bitmap &= rval; + res[i] = bitmap.cardinality() != 0; + } + } + static void scalar_vector(const BitmapValue& lval, const TData& rvec, ResTData& res) { + size_t size = rvec.size(); + for (size_t i = 0; i < size; ++i) { + auto bitmap = const_cast(lval); + bitmap &= rvec[i]; + res[i] = bitmap.cardinality() != 0; + } } }; @@ -708,7 +801,7 @@ struct BitmapHasAll { using TData = std::vector; using ResTData = typename ColumnVector::Container; - static Status vector_vector(const TData& lvec, const TData& rvec, ResTData& res) { + static void vector_vector(const TData& lvec, const TData& rvec, ResTData& res) { size_t size = lvec.size(); for (size_t i = 0; i < size; ++i) { uint64_t lhs_cardinality = lvec[i].cardinality(); @@ -716,7 +809,24 @@ struct BitmapHasAll { bitmap |= rvec[i]; res[i] = bitmap.cardinality() == lhs_cardinality; } - return Status::OK(); + } + static void vector_scalar(const TData& lvec, const BitmapValue& rval, ResTData& res) { + size_t size = lvec.size(); + for (size_t i = 0; i < size; ++i) { + uint64_t lhs_cardinality = lvec[i].cardinality(); + auto bitmap = const_cast(lvec[i]); + bitmap |= rval; + res[i] = bitmap.cardinality() == lhs_cardinality; + } + } + static void scalar_vector(const BitmapValue& lval, const TData& rvec, ResTData& res) { + size_t size = rvec.size(); + for (size_t i = 0; i < size; ++i) { + uint64_t lhs_cardinality = lval.cardinality(); + auto bitmap = const_cast(lval); + bitmap |= rvec[i]; + res[i] = bitmap.cardinality() == lhs_cardinality; + } } }; @@ -748,9 +858,9 @@ struct SubBitmap { using TData1 = std::vector; using TData2 = typename 
ColumnVector::Container; - static Status vector_vector(const TData1& bitmap_data, const TData2& offset_data, - const TData2& limit_data, NullMap& null_map, - size_t input_rows_count, TData1& res) { + static void vector3(const TData1& bitmap_data, const TData2& offset_data, + const TData2& limit_data, NullMap& null_map, size_t input_rows_count, + TData1& res) { for (int i = 0; i < input_rows_count; ++i) { if (null_map[i]) { continue; @@ -764,7 +874,23 @@ struct SubBitmap { null_map[i] = 1; } } - return Status::OK(); + } + static void vector_scalars(const TData1& bitmap_data, const Int64& offset_data, + const Int64& limit_data, NullMap& null_map, size_t input_rows_count, + TData1& res) { + for (int i = 0; i < input_rows_count; ++i) { + if (null_map[i]) { + continue; + } + if (limit_data <= 0) { + null_map[i] = 1; + continue; + } + if (const_cast(bitmap_data)[i].offset_limit(offset_data, limit_data, + &res[i]) == 0) { + null_map[i] = 1; + } + } } }; @@ -773,9 +899,9 @@ struct BitmapSubsetLimit { using TData1 = std::vector; using TData2 = typename ColumnVector::Container; - static Status vector_vector(const TData1& bitmap_data, const TData2& offset_data, - const TData2& limit_data, NullMap& null_map, - size_t input_rows_count, TData1& res) { + static void vector3(const TData1& bitmap_data, const TData2& offset_data, + const TData2& limit_data, NullMap& null_map, size_t input_rows_count, + TData1& res) { for (int i = 0; i < input_rows_count; ++i) { if (null_map[i]) { continue; @@ -786,7 +912,20 @@ struct BitmapSubsetLimit { } const_cast(bitmap_data)[i].sub_limit(offset_data[i], limit_data[i], &res[i]); } - return Status::OK(); + } + static void vector_scalars(const TData1& bitmap_data, const Int64& offset_data, + const Int64& limit_data, NullMap& null_map, size_t input_rows_count, + TData1& res) { + for (int i = 0; i < input_rows_count; ++i) { + if (null_map[i]) { + continue; + } + if (offset_data < 0 || limit_data < 0) { + null_map[i] = 1; + continue; + } + 
const_cast(bitmap_data)[i].sub_limit(offset_data, limit_data, &res[i]); + } } }; @@ -795,9 +934,9 @@ struct BitmapSubsetInRange { using TData1 = std::vector; using TData2 = typename ColumnVector::Container; - static Status vector_vector(const TData1& bitmap_data, const TData2& range_start, - const TData2& range_end, NullMap& null_map, size_t input_rows_count, - TData1& res) { + static void vector3(const TData1& bitmap_data, const TData2& range_start, + const TData2& range_end, NullMap& null_map, size_t input_rows_count, + TData1& res) { for (int i = 0; i < input_rows_count; ++i) { if (null_map[i]) { continue; @@ -808,7 +947,20 @@ struct BitmapSubsetInRange { } const_cast(bitmap_data)[i].sub_range(range_start[i], range_end[i], &res[i]); } - return Status::OK(); + } + static void vector_scalars(const TData1& bitmap_data, const Int64& range_start, + const Int64& range_end, NullMap& null_map, size_t input_rows_count, + TData1& res) { + for (int i = 0; i < input_rows_count; ++i) { + if (null_map[i]) { + continue; + } + if (range_start >= range_end || range_start < 0 || range_end < 0) { + null_map[i] = 1; + continue; + } + const_cast(bitmap_data)[i].sub_range(range_start, range_end, &res[i]); + } } }; @@ -835,25 +987,36 @@ public: DCHECK_EQ(arguments.size(), 3); auto res_null_map = ColumnUInt8::create(input_rows_count, 0); auto res_data_column = ColumnBitmap::create(input_rows_count); - ColumnPtr argument_columns[3]; + bool col_const[3]; + ColumnPtr argument_columns[3]; for (int i = 0; i < 3; ++i) { - argument_columns[i] = - block.get_by_position(arguments[i]).column->convert_to_full_column_if_const(); - if (auto* nullable = check_and_get_column(*argument_columns[i])) { - VectorizedUtils::update_null_map(res_null_map->get_data(), - nullable->get_null_map_data()); - argument_columns[i] = nullable->get_nested_column_ptr(); - } + col_const[i] = is_column_const(*block.get_by_position(arguments[i]).column); + } + argument_columns[0] = col_const[0] ? 
static_cast( + *block.get_by_position(arguments[0]).column) + .convert_to_full_column() + : block.get_by_position(arguments[0]).column; + + default_preprocess_parameter_columns(argument_columns, col_const, {1, 2}, block, arguments); + + for (int i = 0; i < 3; i++) { + check_set_nullable(argument_columns[i], res_null_map); } auto bitmap_column = assert_cast(argument_columns[0].get()); auto offset_column = assert_cast*>(argument_columns[1].get()); auto limit_column = assert_cast*>(argument_columns[2].get()); - Impl::vector_vector(bitmap_column->get_data(), offset_column->get_data(), - limit_column->get_data(), res_null_map->get_data(), input_rows_count, - res_data_column->get_data()); + if (col_const[1] && col_const[2]) { + Impl::vector_scalars(bitmap_column->get_data(), offset_column->get_element(0), + limit_column->get_element(0), res_null_map->get_data(), + input_rows_count, res_data_column->get_data()); + } else { + Impl::vector3(bitmap_column->get_data(), offset_column->get_data(), + limit_column->get_data(), res_null_map->get_data(), input_rows_count, + res_data_column->get_data()); + } block.get_by_position(result).column = ColumnNullable::create(std::move(res_data_column), std::move(res_null_map)); diff --git a/be/src/vec/functions/function_totype.h b/be/src/vec/functions/function_totype.h index bc16e066f8..344223a024 100644 --- a/be/src/vec/functions/function_totype.h +++ b/be/src/vec/functions/function_totype.h @@ -19,6 +19,8 @@ #include #include "vec/columns/column_complex.h" +#include "vec/columns/column_const.h" +#include "vec/columns/column_nullable.h" #include "vec/columns/column_string.h" #include "vec/columns/column_vector.h" #include "vec/data_types/data_type.h" @@ -143,10 +145,10 @@ public: Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t /*input_rows_count*/) override { DCHECK_EQ(arguments.size(), 2); - const auto& left = block.get_by_position(arguments[0]); - auto lcol = 
left.column->convert_to_full_column_if_const(); - const auto& right = block.get_by_position(arguments[1]); - auto rcol = right.column->convert_to_full_column_if_const(); + const auto& [lcol, left_const] = + unpack_if_const(block.get_by_position(arguments[0]).column); + const auto& [rcol, right_const] = + unpack_if_const(block.get_by_position(arguments[1]).column); using ResultDataType = typename Impl::ResultDataType; @@ -171,8 +173,17 @@ public: if (auto col_left = check_and_get_column(lcol.get())) { if (auto col_right = check_and_get_column(rcol.get())) { - Impl::vector_vector(col_left->get_data(), - col_right->get_data(), vec_res); + if (left_const) { + Impl::scalar_vector( + col_left->get_data()[0], col_right->get_data(), vec_res); + } else if (right_const) { + Impl::vector_scalar( + col_left->get_data(), col_right->get_data()[0], vec_res); + } else { + Impl::vector_vector( + col_left->get_data(), col_right->get_data(), vec_res); + } + block.replace_by_position(result, std::move(col_res)); return Status::OK(); } diff --git a/be/src/vec/functions/least_greast.cpp b/be/src/vec/functions/least_greast.cpp index 072ba8cbd5..5125f9a5b6 100644 --- a/be/src/vec/functions/least_greast.cpp +++ b/be/src/vec/functions/least_greast.cpp @@ -15,8 +15,10 @@ // specific language governing permissions and limitations // under the License. 
+#include "vec/columns/column_const.h" #include "vec/columns/column_vector.h" #include "vec/columns/columns_number.h" +#include "vec/common/assert_cast.h" #include "vec/core/accurate_comparison.h" #include "vec/data_types/data_type.h" #include "vec/data_types/data_type_number.h" @@ -34,60 +36,34 @@ struct CompareMultiImpl { static DataTypePtr get_return_type_impl(const DataTypes& arguments) { return arguments[0]; } - template - static void insert_result_data(MutableColumnPtr& result_column, ColumnPtr& argument_column, - const size_t input_rows_count) { - auto* __restrict result_raw_data = - reinterpret_cast(result_column.get())->get_data().data(); - auto* __restrict column_raw_data = - reinterpret_cast(argument_column.get())->get_data().data(); - - if constexpr (std::is_same_v) { - for (size_t i = 0; i < input_rows_count; ++i) { - result_raw_data[i] = Op::apply(column_raw_data[i], - result_raw_data[i]) - ? column_raw_data[i] - : result_raw_data[i]; - } - } else { - for (size_t i = 0; i < input_rows_count; ++i) { - using type = std::decay_t; - result_raw_data[i] = Op::apply(column_raw_data[i], result_raw_data[i]) - ? 
column_raw_data[i] - : result_raw_data[i]; - } - } - } - static ColumnPtr execute(Block& block, const ColumnNumbers& arguments, size_t input_rows_count) { - if (arguments.size() == 1) return block.get_by_position(arguments.back()).column; + if (arguments.size() == 1) { + return block.get_by_position(arguments.back()).column; + } const auto& data_type = block.get_by_position(arguments.back()).type; MutableColumnPtr result_column = data_type->create_column(); - Columns args; + Columns cols(arguments.size()); + std::unique_ptr col_const = std::make_unique(arguments.size()); for (int i = 0; i < arguments.size(); ++i) { - args.emplace_back( - block.get_by_position(arguments[i]).column->convert_to_full_column_if_const()); + std::tie(cols[i], col_const[i]) = + unpack_if_const(block.get_by_position(arguments[i]).column); } // because now the string types does not support random position writing, // so insert into result data have two methods, one is for string types, one is for others type remaining - bool is_string_result = result_column->is_column_string(); - if (is_string_result) { + if (result_column->is_column_string()) { result_column->reserve(input_rows_count); - } else { - result_column->insert_range_from(*(args[0]), 0, input_rows_count); - } - - if (is_string_result) { - const auto& column_string = reinterpret_cast(*args[0]); + const auto& column_string = reinterpret_cast(*cols[0]); auto& column_res = reinterpret_cast(*result_column); for (int i = 0; i < input_rows_count; ++i) { - auto str_data = column_string.get_data_at(i); - for (int j = 1; j < arguments.size(); ++j) { - auto temp_data = reinterpret_cast(*args[j]).get_data_at(i); + // const args are unpacked to their data column: always read them at row 0 + auto str_data = column_string.get_data_at(index_check_const(i, col_const[0])); + for (int cmp_col = 1; cmp_col < arguments.size(); ++cmp_col) { + auto temp_data = assert_cast(*cols[cmp_col]) + .get_data_at(index_check_const(i, col_const[cmp_col])); str_data = Op::apply(temp_data, str_data) ? 
temp_data : str_data; } @@ -95,12 +71,22 @@ struct CompareMultiImpl { } } else { + if (col_const[0]) { + // cols[0] is the data column of a ColumnConst: replicate its row 0 for every result row + for (size_t i = 0; i < input_rows_count; ++i) { + result_column->insert_range_from(*(cols[0]), 0, 1); + } + } else { + result_column->insert_range_from(*(cols[0]), 0, input_rows_count); + } WhichDataType which(data_type); -#define DISPATCH(TYPE, COLUMN_TYPE) \ - if (which.idx == TypeIndex::TYPE) { \ - for (int i = 1; i < arguments.size(); ++i) { \ - insert_result_data(result_column, args[i], input_rows_count); \ - } \ + +#define DISPATCH(TYPE, COLUMN_TYPE) \ + if (which.idx == TypeIndex::TYPE) { \ + for (int i = 1; i < arguments.size(); ++i) { \ + insert_result_data(result_column, cols[i], input_rows_count, \ + col_const[i]); \ + } \ } NUMERIC_TYPE_TO_COLUMN_TYPE(DISPATCH) DISPATCH(Decimal128, ColumnDecimal) @@ -110,6 +96,38 @@ struct CompareMultiImpl { return result_column; } + +private: + // Folds one argument column into result_column in place, keeping per row the value chosen by Op. + // When arg_const is set, argument_column is an unpacked const column and row 0 is used for every row. + template + static void insert_result_data(const MutableColumnPtr& result_column, + const ColumnPtr& argument_column, const size_t input_rows_count, + const bool arg_const) { + auto* __restrict result_raw_data = + reinterpret_cast(result_column.get())->get_data().data(); + auto* __restrict column_raw_data = + reinterpret_cast(argument_column.get())->get_data().data(); + + if constexpr (std::is_same_v) { + for (size_t i = 0; i < input_rows_count; ++i) { + result_raw_data[i] = Op::apply( + column_raw_data[index_check_const(i, arg_const)], + result_raw_data[i]) + ? column_raw_data[index_check_const(i, arg_const)] + : result_raw_data[i]; + } + } else { + for (size_t i = 0; i < input_rows_count; ++i) { + using type = std::decay_t; + result_raw_data[i] = + Op::apply(column_raw_data[index_check_const(i, arg_const)], + result_raw_data[i]) + ? column_raw_data[index_check_const(i, arg_const)] + : result_raw_data[i]; + } + } + } }; struct FunctionFieldImpl {