[Opt](Exec) Support runtime update topn filter (#31250)
This commit is contained in:
@ -26,11 +26,10 @@
|
||||
#include "olap/column_predicate.h"
|
||||
#include "olap/predicate_creator.h"
|
||||
|
||||
namespace doris {
|
||||
namespace doris::vectorized {
|
||||
|
||||
namespace vectorized {
|
||||
|
||||
Status RuntimePredicate::init(const PrimitiveType type, const bool nulls_first) {
|
||||
Status RuntimePredicate::init(PrimitiveType type, bool nulls_first, bool is_asc,
|
||||
const std::string& col_name) {
|
||||
std::unique_lock<std::shared_mutex> wlock(_rwlock);
|
||||
|
||||
if (_inited) {
|
||||
@ -38,55 +37,53 @@ Status RuntimePredicate::init(const PrimitiveType type, const bool nulls_first)
|
||||
}
|
||||
|
||||
_nulls_first = nulls_first;
|
||||
|
||||
_predicate_arena.reset(new Arena());
|
||||
_is_asc = is_asc;
|
||||
// For ASC sort, create runtime predicate col_name <= max_top_value
|
||||
// since values that > min_top_value are large than any value in current topn values
|
||||
// For DESC sort, create runtime predicate col_name >= min_top_value
|
||||
// since values that < min_top_value are less than any value in current topn values
|
||||
_pred_constructor = is_asc ? create_comparison_predicate<PredicateType::LE>
|
||||
: create_comparison_predicate<PredicateType::GE>;
|
||||
_col_name = col_name;
|
||||
|
||||
// set get value function
|
||||
switch (type) {
|
||||
case PrimitiveType::TYPE_BOOLEAN: {
|
||||
_get_value_fn = get_bool_value;
|
||||
_get_value_fn = get_normal_value<TYPE_BOOLEAN>;
|
||||
break;
|
||||
}
|
||||
case PrimitiveType::TYPE_TINYINT: {
|
||||
_get_value_fn = get_tinyint_value;
|
||||
_get_value_fn = get_normal_value<TYPE_TINYINT>;
|
||||
break;
|
||||
}
|
||||
case PrimitiveType::TYPE_SMALLINT: {
|
||||
_get_value_fn = get_smallint_value;
|
||||
_get_value_fn = get_normal_value<TYPE_SMALLINT>;
|
||||
break;
|
||||
}
|
||||
case PrimitiveType::TYPE_INT: {
|
||||
_get_value_fn = get_int_value;
|
||||
_get_value_fn = get_normal_value<TYPE_INT>;
|
||||
break;
|
||||
}
|
||||
case PrimitiveType::TYPE_BIGINT: {
|
||||
_get_value_fn = get_bigint_value;
|
||||
_get_value_fn = get_normal_value<TYPE_BIGINT>;
|
||||
break;
|
||||
}
|
||||
case PrimitiveType::TYPE_LARGEINT: {
|
||||
_get_value_fn = get_largeint_value;
|
||||
break;
|
||||
}
|
||||
case PrimitiveType::TYPE_FLOAT: {
|
||||
_get_value_fn = get_float_value;
|
||||
break;
|
||||
}
|
||||
case PrimitiveType::TYPE_DOUBLE: {
|
||||
_get_value_fn = get_double_value;
|
||||
_get_value_fn = get_normal_value<TYPE_LARGEINT>;
|
||||
break;
|
||||
}
|
||||
case PrimitiveType::TYPE_CHAR:
|
||||
case PrimitiveType::TYPE_VARCHAR:
|
||||
case PrimitiveType::TYPE_STRING: {
|
||||
_get_value_fn = get_string_value;
|
||||
_get_value_fn = [](const Field& field) { return field.get<String>(); };
|
||||
break;
|
||||
}
|
||||
case PrimitiveType::TYPE_DATEV2: {
|
||||
_get_value_fn = get_datev2_value;
|
||||
_get_value_fn = get_normal_value<TYPE_DATEV2>;
|
||||
break;
|
||||
}
|
||||
case PrimitiveType::TYPE_DATETIMEV2: {
|
||||
_get_value_fn = get_datetimev2_value;
|
||||
_get_value_fn = get_normal_value<TYPE_DATETIMEV2>;
|
||||
break;
|
||||
}
|
||||
case PrimitiveType::TYPE_DATE: {
|
||||
@ -98,11 +95,11 @@ Status RuntimePredicate::init(const PrimitiveType type, const bool nulls_first)
|
||||
break;
|
||||
}
|
||||
case PrimitiveType::TYPE_DECIMAL32: {
|
||||
_get_value_fn = get_decimal32_value;
|
||||
_get_value_fn = get_decimal_value<TYPE_DECIMAL32>;
|
||||
break;
|
||||
}
|
||||
case PrimitiveType::TYPE_DECIMAL64: {
|
||||
_get_value_fn = get_decimal64_value;
|
||||
_get_value_fn = get_decimal_value<TYPE_DECIMAL64>;
|
||||
break;
|
||||
}
|
||||
case PrimitiveType::TYPE_DECIMALV2: {
|
||||
@ -110,19 +107,19 @@ Status RuntimePredicate::init(const PrimitiveType type, const bool nulls_first)
|
||||
break;
|
||||
}
|
||||
case PrimitiveType::TYPE_DECIMAL128I: {
|
||||
_get_value_fn = get_decimal128_value;
|
||||
_get_value_fn = get_decimal_value<TYPE_DECIMAL128I>;
|
||||
break;
|
||||
}
|
||||
case PrimitiveType::TYPE_DECIMAL256: {
|
||||
_get_value_fn = get_decimal256_value;
|
||||
_get_value_fn = get_decimal_value<TYPE_DECIMAL256>;
|
||||
break;
|
||||
}
|
||||
case PrimitiveType::TYPE_IPV4: {
|
||||
_get_value_fn = get_ipv4_value;
|
||||
_get_value_fn = get_normal_value<TYPE_IPV4>;
|
||||
break;
|
||||
}
|
||||
case PrimitiveType::TYPE_IPV6: {
|
||||
_get_value_fn = get_ipv6_value;
|
||||
_get_value_fn = get_normal_value<TYPE_IPV6>;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
@ -133,30 +130,20 @@ Status RuntimePredicate::init(const PrimitiveType type, const bool nulls_first)
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RuntimePredicate::update(const Field& value, const String& col_name, bool is_reverse) {
|
||||
// skip null value
|
||||
if (value.is_null()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (!_inited) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RuntimePredicate::update(const Field& value) {
|
||||
std::unique_lock<std::shared_mutex> wlock(_rwlock);
|
||||
// skip null value
|
||||
if (value.is_null() || !_inited || !_tablet_schema) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool updated = false;
|
||||
|
||||
if (UNLIKELY(_orderby_extrem.is_null())) {
|
||||
_orderby_extrem = value;
|
||||
updated = true;
|
||||
} else if (is_reverse) {
|
||||
if (value > _orderby_extrem) {
|
||||
_orderby_extrem = value;
|
||||
updated = true;
|
||||
}
|
||||
} else {
|
||||
if (value < _orderby_extrem) {
|
||||
if ((_is_asc && value < _orderby_extrem) || (!_is_asc && value > _orderby_extrem)) {
|
||||
_orderby_extrem = value;
|
||||
updated = true;
|
||||
}
|
||||
@ -166,38 +153,19 @@ Status RuntimePredicate::update(const Field& value, const String& col_name, bool
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// TODO defensive code
|
||||
if (!_tablet_schema || !_tablet_schema->have_column(col_name)) {
|
||||
return Status::OK();
|
||||
}
|
||||
// update _predictate
|
||||
int32_t col_unique_id = _tablet_schema->column(col_name).unique_id();
|
||||
const TabletColumn& column = _tablet_schema->column_by_uid(col_unique_id);
|
||||
uint32_t index = _tablet_schema->field_index(col_unique_id);
|
||||
auto val = _get_value_fn(_orderby_extrem);
|
||||
std::unique_ptr<ColumnPredicate> pred {nullptr};
|
||||
if (is_reverse) {
|
||||
// For DESC sort, create runtime predicate col_name >= min_top_value
|
||||
// since values that < min_top_value are less than any value in current topn values
|
||||
pred.reset(create_comparison_predicate<PredicateType::GE>(column, index, val, false,
|
||||
_predicate_arena.get()));
|
||||
} else {
|
||||
// For ASC sort, create runtime predicate col_name <= max_top_value
|
||||
// since values that > min_top_value are large than any value in current topn values
|
||||
pred.reset(create_comparison_predicate<PredicateType::LE>(column, index, val, false,
|
||||
_predicate_arena.get()));
|
||||
}
|
||||
|
||||
std::unique_ptr<ColumnPredicate> pred {
|
||||
_pred_constructor(_tablet_schema->column(_col_name), _predicate->column_id(),
|
||||
_get_value_fn(_orderby_extrem), false, &_predicate_arena)};
|
||||
// For NULLS FIRST, wrap a AcceptNullPredicate to return true for NULL
|
||||
// since ORDER BY ASC/DESC should get NULL first but pred returns NULL
|
||||
// and NULL in where predicate will be treated as FALSE
|
||||
if (_nulls_first) {
|
||||
pred = AcceptNullPredicate::create_unique(pred.release());
|
||||
}
|
||||
_predictate.reset(pred.release());
|
||||
|
||||
((SharedPredicate*)_predicate.get())->set_nested(pred.release());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace vectorized
|
||||
} // namespace doris
|
||||
} // namespace doris::vectorized
|
||||
|
||||
Reference in New Issue
Block a user