Modify the logic of recursive search for rownum

This commit is contained in:
obdev
2024-02-09 00:49:18 +00:00
committed by ob-robot
parent 238cc0735c
commit 9d088dff7e
3 changed files with 54 additions and 65 deletions

View File

@ -3107,16 +3107,19 @@ int ObLogicalOperator::alloc_op_pre(AllocOpContext& ctx)
sys_rownum_expr->set_op_id(op_id_);
}
}
const ObRawExpr *rownum_expr = NULL;
if (OB_SUCC(ret) && OB_SUCC(find_rownum_expr(rownum_expr)) && rownum_expr != NULL) {
const ObSysFunRawExpr *sys_rownum_expr = static_cast<const ObSysFunRawExpr *>(rownum_expr);
uint64_t count_op_id = sys_rownum_expr->get_op_id();
LOG_DEBUG("the coun_op_id of rownum is", K(count_op_id));
// rownum expr may be above count
if (count_op_id != OB_INVALID_ID) {
if (OB_FAIL(disable_rownum_expr(ctx.disabled_op_set_, count_op_id))) {
LOG_WARN("fail to disable rownum", K(ret), K(count_op_id));
ObLogicalOperator *rownum_op = NULL;
if (OB_SUCC(ret) && OB_SUCC(find_rownum_expr(op_id_, rownum_op)) && rownum_op != NULL) {
uint64_t op_id = rownum_op->get_op_id();
ObLogicalOperator *parent = rownum_op->get_parent();
while (OB_SUCC(ret) && op_id != op_id_ && parent != NULL) {
ret = ctx.disabled_op_set_.set_refactored(op_id_);
if (ret != OB_SUCCESS && ret != OB_HASH_EXIST) {
LOG_WARN("set_refactored fail", K(ret));
} else {
ret = OB_SUCCESS;
}
op_id = parent->get_op_id();
parent = parent->get_parent();
}
}
// disable all right childs of nlj and spf
@ -5690,81 +5693,75 @@ int ObLogicalOperator::collect_batch_exec_param(void* ctx,
return ret;
}
int ObLogicalOperator::find_rownum_expr_recursively(const ObRawExpr *&rownum_expr, const ObRawExpr *expr)
// Recursively search in all subexpressions
int ObLogicalOperator::find_rownum_expr_recursively(const uint64_t &count_op_id,
ObLogicalOperator *&rownum_op,
const ObRawExpr *expr)
{
int ret = OB_SUCCESS;
if (OB_ISNULL(expr)) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("expr is null", K(ret));
} else if (expr->get_expr_type() == T_FUN_SYS_ROWNUM) {
rownum_expr = expr;
} else if (!expr->has_flag(CNT_ROWNUM)) {
/* expr do not have rownum, need not to search recursively, do nothing */
} else if (expr->get_expr_type() == T_FUN_SYS_ROWNUM
&& static_cast<const ObSysFunRawExpr *>(expr)->get_op_id() == count_op_id) {
rownum_op = this;
} else {
for (int64_t i = 0; OB_SUCC(ret) && rownum_expr == NULL && i < expr->get_param_count(); i++) {
if (OB_FAIL(SMART_CALL(find_rownum_expr_recursively(rownum_expr, expr->get_param_expr(i))))) {
for (int64_t i = 0; OB_SUCC(ret) && NULL == rownum_op && i < expr->get_param_count(); i++) {
if (OB_FAIL(SMART_CALL(
find_rownum_expr_recursively(count_op_id, rownum_op, expr->get_param_expr(i))))) {
LOG_WARN("fail to find rownum expr recursively", K(ret));
}
}
}
LOG_DEBUG("find_rownum_expr_recursively finished", K(expr->get_param_count()),
K(expr->get_expr_type()));
K(expr->get_expr_type()), K(NULL == rownum_op));
return ret;
}
int ObLogicalOperator::find_rownum_expr(const ObRawExpr *&rownum_expr, const ObIArray<ObRawExpr *> &exprs)
int ObLogicalOperator::find_rownum_expr(const uint64_t &count_op_id, ObLogicalOperator *&rownum_op,
const ObIArray<ObRawExpr *> &exprs)
{
LOG_DEBUG("find_rownum_expr begin", K(exprs.count()));
LOG_DEBUG("find_rownum_expr begin", K(exprs.count()), K(NULL == rownum_op));
int ret = OB_SUCCESS;
for (int64_t i = 0; OB_SUCC(ret) && rownum_expr == NULL && i < exprs.count(); i++) {
for (int64_t i = 0; OB_SUCC(ret) && NULL == rownum_op && i < exprs.count(); i++) {
ObRawExpr *expr = exprs.at(i);
ret = find_rownum_expr_recursively(rownum_expr, expr);
LOG_DEBUG("find_rownum_expr_recursively done:", K(expr->get_expr_type()), K(i),
K(expr->get_param_count()));
ret = find_rownum_expr_recursively(count_op_id, rownum_op, expr);
LOG_DEBUG(
"find_rownum_expr_recursively done:", K(expr->get_expr_type()),
K(NULL == rownum_op), K(i), K(expr->get_param_count()));
}
return ret;
}
// rownum expr can show up in the following 4 cases, check them all
// - filter expr
// - output expr
// - join conditions: equal ("=")
// - join conditions: filter (">", "<", ">=", "<=")
int ObLogicalOperator::find_rownum_expr(const ObRawExpr *&rownum_expr)
int ObLogicalOperator::find_rownum_expr(const uint64_t &count_op_id, ObLogicalOperator *&rownum_op)
{
int ret = OB_SUCCESS;
LOG_DEBUG("find_rownum_expr debug: ", K(get_name()));
if (OB_FAIL(find_rownum_expr(rownum_expr, get_filter_exprs()))) {
LOG_DEBUG("find_rownum_expr debug: ", K(get_name()), K(count_op_id));
if (OB_FAIL(find_rownum_expr(count_op_id, rownum_op, get_filter_exprs()))) {
LOG_WARN("failure encountered during find rownum expr", K(ret));
} else if (OB_FAIL(find_rownum_expr(rownum_expr, get_output_exprs()))) {
} else if (OB_FAIL(find_rownum_expr(count_op_id, rownum_op, get_output_exprs()))) {
LOG_WARN("failure encountered during find rownum expr", K(ret));
} else if (rownum_expr == NULL && get_type() == log_op_def::LOG_JOIN) {
} else if (NULL == rownum_op && get_type() == log_op_def::LOG_JOIN) {
ObLogJoin *join_op = dynamic_cast<ObLogJoin *>(this);
// NO NPE check for join_op as it should NOT be nullptr
if (OB_ISNULL(join_op)) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("join op is null", K(ret));
} else if (OB_FAIL(find_rownum_expr(rownum_expr, join_op->get_other_join_conditions()))) {
} else if (OB_FAIL(
find_rownum_expr(count_op_id, rownum_op, join_op->get_other_join_conditions()))) {
LOG_WARN("failure encountered during find rownum expr", K(ret));
} else if (OB_FAIL(find_rownum_expr(rownum_expr, join_op->get_equal_join_conditions()))) {
} else if (OB_FAIL(
find_rownum_expr(count_op_id, rownum_op, join_op->get_equal_join_conditions()))) {
LOG_WARN("failure encountered during find rownum expr", K(ret));
}
}
return ret;
}
/**
After finding the rownum expression, the count operator is searched from bottom to top, and
the operators on the path are marked as not being able to add materialization.
*/
int ObLogicalOperator::disable_rownum_expr(hash::ObHashSet<uint64_t> &disabled_op_set,
const uint64_t &count_op_id)
{
int ret = OB_SUCCESS;
uint64_t op_id = op_id_;
ObLogicalOperator *parent = get_parent();
while (OB_SUCC(ret) && op_id != count_op_id && parent != NULL) {
ret = disabled_op_set.set_refactored(op_id);
op_id = parent->get_op_id();
parent = parent->get_parent();
for (int64_t i = 0; NULL == rownum_op && OB_SUCC(ret) && i < get_num_of_child(); i++) {
if (OB_FAIL(SMART_CALL(get_child(i)->find_rownum_expr(count_op_id, rownum_op)))) {
LOG_WARN("fail to find rownum expr", K(ret));
}
}
return ret;
}

View File

@ -1358,11 +1358,11 @@ public:
*/
int alloc_op_pre(AllocOpContext& ctx);
int alloc_op_post(AllocOpContext& ctx);
int find_rownum_expr_recursively(const ObRawExpr *&rownum_expr, const ObRawExpr *expr);
int find_rownum_expr(const ObRawExpr *&rownum_expr, const ObIArray<ObRawExpr *> &exprs);
int find_rownum_expr(const ObRawExpr *&rownum_expr);
int disable_rownum_expr(hash::ObHashSet<uint64_t> &disabled_op_set, ObIArray<uint64_t> &cur_path);
int disable_rownum_expr(hash::ObHashSet<uint64_t> &disabled_op_set, const uint64_t &count_op_id);
int find_rownum_expr_recursively(const uint64_t &op_id, ObLogicalOperator *&rownum_op,
const ObRawExpr *expr);
int find_rownum_expr(const uint64_t &op_id, ObLogicalOperator *&rownum_op,
const ObIArray<ObRawExpr *> &exprs);
int find_rownum_expr(const uint64_t &op_id, ObLogicalOperator *&rownum_op);
int gen_temp_op_id(AllocOpContext& ctx);
int recursively_disable_alloc_op_above(AllocOpContext& ctx);
int alloc_nodes_above(AllocOpContext& ctx, const uint64_t &flags);