// TODO: compare with different SIMD implementations
/// Returns the index of the first element of `vec` at or after `start` that equals `byte`.
///
/// The search is a raw byte scan (std::memchr), so `T` must be a one-byte type; this is
/// enforced at compile time because a wider `T` would make both the match and the returned
/// index meaningless.
///
/// @param vec   buffer to search
/// @param start index of the first element to examine
/// @param byte  value to look for
/// @return the index of the first match; `vec.size()` when there is no match in
///         `[start, vec.size())`; `start` itself when `start >= vec.size()` (nothing is read).
template <class T>
inline static size_t find_byte(const std::vector<T>& vec, size_t start, T byte) {
    static_assert(sizeof(T) == 1, "find_byte scans raw bytes; T must be a one-byte type");
    if (start >= vec.size()) {
        // Out-of-range start: report it unchanged so callers comparing against an end bound
        // still see a value that is >= vec.size().
        return start;
    }
    const void* hit = std::memchr(vec.data() + start, static_cast<unsigned char>(byte),
                                  vec.size() - start);
    if (hit == nullptr) {
        return vec.size();
    }
    // Keep the pointer const: the buffer belongs to a const vector.
    return static_cast<size_t>(static_cast<const T*>(hit) - vec.data());
}

/// Index of the first element equal to 1 at or after `start` (see find_byte for the
/// not-found results). Note that only the exact value 1 matches, not every non-zero value.
// NOTE(review): element type uint8_t is reconstructed (template arguments were lost in this
// chunk) — confirm against the repository.
inline size_t find_nonzero(const std::vector<uint8_t>& vec, size_t start) {
    return find_byte<uint8_t>(vec, start, 1);
}

/// Index of the first element equal to 0 at or after `start` (see find_byte for the
/// not-found results).
inline size_t find_zero(const std::vector<uint8_t>& vec, size_t start) {
    return find_byte<uint8_t>(vec, start, 0);
}
Field; +class ColumnSorter; +using EqualFlags = std::vector; +using EqualRange = std::pair; + /// Declares interface to store columns in memory. class IColumn : public COW { private: @@ -529,6 +533,10 @@ public: virtual bool low_cardinality() const { return false; } + virtual void sort_column(const ColumnSorter* sorter, EqualFlags& flags, + IColumn::Permutation& perms, EqualRange& range, + bool last_column) const; + virtual ~IColumn() = default; IColumn() = default; IColumn(const IColumn&) = default; diff --git a/be/src/vec/columns/column_decimal.cpp b/be/src/vec/columns/column_decimal.cpp index 91145e4473..ea904c8c30 100644 --- a/be/src/vec/columns/column_decimal.cpp +++ b/be/src/vec/columns/column_decimal.cpp @@ -27,6 +27,7 @@ #include "vec/common/exception.h" #include "vec/common/sip_hash.h" #include "vec/common/unaligned.h" +#include "vec/core/sort_block.h" template bool decimal_less(T x, T y, doris::vectorized::UInt32 x_scale, doris::vectorized::UInt32 y_scale); @@ -330,6 +331,13 @@ void ColumnDecimal::get_extremes(Field& min, Field& max) const { max = NearestFieldType(cur_max, scale); } +template +void ColumnDecimal::sort_column(const ColumnSorter* sorter, EqualFlags& flags, + IColumn::Permutation& perms, EqualRange& range, + bool last_column) const { + sorter->template sort_column(static_cast(*this), flags, perms, range, last_column); +} + template <> Decimal32 ColumnDecimal::get_scale_multiplier() const { return common::exp10_i32(scale); diff --git a/be/src/vec/columns/column_decimal.h b/be/src/vec/columns/column_decimal.h index 2b3dac0ae5..81e037278c 100644 --- a/be/src/vec/columns/column_decimal.h +++ b/be/src/vec/columns/column_decimal.h @@ -226,6 +226,9 @@ public: data[self_row] = T(); } + void sort_column(const ColumnSorter* sorter, EqualFlags& flags, IColumn::Permutation& perms, + EqualRange& range, bool last_column) const override; + UInt32 get_scale() const { return scale; } T get_scale_multiplier() const; diff --git 
a/be/src/vec/columns/column_nullable.cpp b/be/src/vec/columns/column_nullable.cpp index cff4d0f7c3..4f18163479 100644 --- a/be/src/vec/columns/column_nullable.cpp +++ b/be/src/vec/columns/column_nullable.cpp @@ -27,6 +27,7 @@ #include "vec/common/nan_utils.h" #include "vec/common/sip_hash.h" #include "vec/common/typeid_cast.h" +#include "vec/core/sort_block.h" namespace doris::vectorized { @@ -508,6 +509,13 @@ void ColumnNullable::check_consistency() const { } } +void ColumnNullable::sort_column(const ColumnSorter* sorter, EqualFlags& flags, + IColumn::Permutation& perms, EqualRange& range, + bool last_column) const { + sorter->sort_column(static_cast(*this), flags, perms, range, + last_column); +} + ColumnPtr make_nullable(const ColumnPtr& column, bool is_nullable) { if (is_column_nullable(*column)) return column; diff --git a/be/src/vec/columns/column_nullable.h b/be/src/vec/columns/column_nullable.h index 397f3be1bd..2a8bf52dd2 100644 --- a/be/src/vec/columns/column_nullable.h +++ b/be/src/vec/columns/column_nullable.h @@ -305,6 +305,9 @@ public: get_nested_column().generate_hash_values_for_runtime_filter(); } + void sort_column(const ColumnSorter* sorter, EqualFlags& flags, IColumn::Permutation& perms, + EqualRange& range, bool last_column) const override; + private: WrappedPtr nested_column; WrappedPtr null_map; diff --git a/be/src/vec/columns/column_string.cpp b/be/src/vec/columns/column_string.cpp index c8b99e8ffa..18f0d74d23 100644 --- a/be/src/vec/columns/column_string.cpp +++ b/be/src/vec/columns/column_string.cpp @@ -25,6 +25,7 @@ #include "vec/common/assert_cast.h" #include "vec/common/memcmp_small.h" #include "vec/common/unaligned.h" +#include "vec/core/sort_block.h" namespace doris::vectorized { @@ -424,6 +425,12 @@ void ColumnString::get_extremes(Field& min, Field& max) const { get(max_idx, max); } +void ColumnString::sort_column(const ColumnSorter* sorter, EqualFlags& flags, + IColumn::Permutation& perms, EqualRange& range, + bool last_column) const 
{ + sorter->sort_column(static_cast(*this), flags, perms, range, last_column); +} + void ColumnString::protect() { get_chars().protect(); get_offsets().protect(); diff --git a/be/src/vec/columns/column_string.h b/be/src/vec/columns/column_string.h index d2744e20b7..8c504c6d08 100644 --- a/be/src/vec/columns/column_string.h +++ b/be/src/vec/columns/column_string.h @@ -265,6 +265,9 @@ public: ColumnPtr permute(const Permutation& perm, size_t limit) const override; + void sort_column(const ColumnSorter* sorter, EqualFlags& flags, IColumn::Permutation& perms, + EqualRange& range, bool last_column) const override; + // ColumnPtr index(const IColumn & indexes, size_t limit) const override; template diff --git a/be/src/vec/columns/column_vector.cpp b/be/src/vec/columns/column_vector.cpp index 13c5810026..d60f8816a1 100644 --- a/be/src/vec/columns/column_vector.cpp +++ b/be/src/vec/columns/column_vector.cpp @@ -35,6 +35,7 @@ #include "vec/common/nan_utils.h" #include "vec/common/sip_hash.h" #include "vec/common/unaligned.h" +#include "vec/core/sort_block.h" namespace doris::vectorized { @@ -111,6 +112,13 @@ void ColumnVector::update_hashes_with_value(std::vector& hashes, SIP_HASHES_FUNCTION_COLUMN_IMPL(); } +template +void ColumnVector::sort_column(const ColumnSorter* sorter, EqualFlags& flags, + IColumn::Permutation& perms, EqualRange& range, + bool last_column) const { + sorter->template sort_column(static_cast(*this), flags, perms, range, last_column); +} + template struct ColumnVector::less { const Self& parent; diff --git a/be/src/vec/columns/column_vector.h b/be/src/vec/columns/column_vector.h index 5101894b92..3a8ec82382 100644 --- a/be/src/vec/columns/column_vector.h +++ b/be/src/vec/columns/column_vector.h @@ -384,6 +384,9 @@ public: data[self_row] = T(); } + void sort_column(const ColumnSorter* sorter, EqualFlags& flags, IColumn::Permutation& perms, + EqualRange& range, bool last_column) const override; + protected: Container data; }; diff --git 
a/be/src/vec/core/sort_block.cpp b/be/src/vec/core/sort_block.cpp index d26566e1f9..657fd58d23 100644 --- a/be/src/vec/core/sort_block.cpp +++ b/be/src/vec/core/sort_block.cpp @@ -20,8 +20,6 @@ #include "vec/core/sort_block.h" -#include - #include "vec/columns/column_string.h" #include "vec/common/typeid_cast.h" @@ -100,12 +98,12 @@ void sort_block(Block& block, const SortDescription& description, UInt64 limit) ColumnsWithSortDescriptions columns_with_sort_desc = get_columns_with_sort_description(block, description); { - PartialSortingLess less(columns_with_sort_desc); + EqualFlags flags(size, 1); + EqualRange range {0, size}; - if (limit) { - std::partial_sort(perm.begin(), perm.begin() + limit, perm.end(), less); - } else { - pdqsort(perm.begin(), perm.end(), less); + for (size_t i = 0; i < columns_with_sort_desc.size(); i++) { + ColumnSorter sorter(columns_with_sort_desc[i], limit); + sorter.operator()(flags, perm, range, i == columns_with_sort_desc.size() - 1); } } diff --git a/be/src/vec/core/sort_block.h b/be/src/vec/core/sort_block.h index bbdbcdb783..2a2babf46f 100644 --- a/be/src/vec/core/sort_block.h +++ b/be/src/vec/core/sort_block.h @@ -19,7 +19,9 @@ // and modified by Doris #pragma once +#include +#include "util/simd/bits.h" #include "vec/core/block.h" #include "vec/core/sort_description.h" @@ -46,9 +48,412 @@ void stable_get_permutation(const Block& block, const SortDescription& descripti */ bool is_already_sorted(const Block& block, const SortDescription& description); -using ColumnsWithSortDescriptions = std::vector>; +using ColumnWithSortDescription = std::pair; + +using ColumnsWithSortDescriptions = std::vector; ColumnsWithSortDescriptions get_columns_with_sort_description(const Block& block, const SortDescription& description); +struct EqualRangeIterator { + int range_begin; + int range_end; + + EqualRangeIterator(const EqualFlags& flags) : EqualRangeIterator(flags, 0, flags.size()) {} + + EqualRangeIterator(const EqualFlags& flags, int begin, 
int end) : _flags(flags), _end(end) { + range_begin = begin; + range_end = end; + _cur_range_begin = begin; + _cur_range_end = end; + } + + bool next() { + if (_cur_range_begin >= _end) { + return false; + } + + // `_flags[i]=1` indicates that the i-th row is equal to the previous row, which means we + // should continue to sort this row according to current column. Using the first non-zero + // value and first zero value after first non-zero value as two bounds, we can get an equal range here + if (!(_cur_range_begin == 0) || !(_flags[_cur_range_begin] == 1)) { + _cur_range_begin = simd::find_nonzero(_flags, _cur_range_begin + 1); + if (_cur_range_begin >= _end) { + return false; + } + _cur_range_begin--; + } + + _cur_range_end = simd::find_zero(_flags, _cur_range_begin + 1); + DCHECK(_cur_range_end <= _end); + + if (_cur_range_begin >= _cur_range_end) { + return false; + } + + range_begin = _cur_range_begin; + range_end = _cur_range_end; + _cur_range_begin = _cur_range_end; + return true; + } + +private: + int _cur_range_begin; + int _cur_range_end; + + const EqualFlags& _flags; + const int _end; +}; + +struct ColumnPartialSortingLess { + const ColumnWithSortDescription& _column_with_sort_desc; + + explicit ColumnPartialSortingLess(const ColumnWithSortDescription& column) + : _column_with_sort_desc(column) {} + + bool operator()(size_t a, size_t b) const { + int res = _column_with_sort_desc.second.direction * + _column_with_sort_desc.first->compare_at( + a, b, *_column_with_sort_desc.first, + _column_with_sort_desc.second.nulls_direction); + if (res < 0) { + return true; + } else if (res > 0) { + return false; + } + return false; + } +}; + +template +struct PermutationWithInlineValue { + T inline_value; + uint32_t row_id; +}; + +template +using PermutationForColumn = std::vector>; + +class ColumnSorter { +public: + explicit ColumnSorter(const ColumnWithSortDescription& column, const int limit) + : _column_with_sort_desc(column), + _limit(limit), + 
_nulls_direction(column.second.nulls_direction), + _direction(column.second.direction) {} + + void operator()(EqualFlags& flags, IColumn::Permutation& perms, EqualRange& range, + bool last_column) const { + _column_with_sort_desc.first->sort_column(this, flags, perms, range, last_column); + } + + void sort_column(const IColumn& column, EqualFlags& flags, IColumn::Permutation& perms, + EqualRange& range, bool last_column) const { + int new_limit = _limit; + auto comparator = [&](const size_t a, const size_t b) { + return column.compare_at(a, b, *_column_with_sort_desc.first, _nulls_direction); + }; + ColumnPartialSortingLess less(_column_with_sort_desc); + auto do_sort = [&](size_t first_iter, size_t last_iter) { + auto begin = perms.begin() + first_iter; + auto end = perms.begin() + last_iter; + + if (UNLIKELY(_limit > 0 && first_iter < _limit && _limit <= last_iter)) { + int n = _limit - first_iter; + std::partial_sort(begin, begin + n, end, less); + + auto nth = perms[_limit - 1]; + size_t equal_count = 0; + for (auto iter = begin + n; iter < end; iter++) { + if (comparator(*iter, nth) == 0) { + std::iter_swap(iter, begin + n + equal_count); + equal_count++; + } + } + new_limit = _limit + equal_count; + } else { + pdqsort(begin, end, less); + } + }; + + EqualRangeIterator iterator(flags, range.first, range.second); + while (iterator.next()) { + int range_begin = iterator.range_begin; + int range_end = iterator.range_end; + + if (UNLIKELY(_limit > 0 && range_begin > _limit)) { + break; + } + if (LIKELY(range_end - range_begin > 1)) { + do_sort(range_begin, range_end); + if (!last_column) { + flags[range_begin] = 0; + for (int i = range_begin + 1; i < range_end; i++) { + flags[i] &= comparator(perms[i - 1], perms[i]) == 0; + } + } + } + } + _shrink_to_fit(perms, flags, new_limit); + } + + template