diff --git a/src/sql/code_generator/ob_static_engine_cg.cpp b/src/sql/code_generator/ob_static_engine_cg.cpp index 582db2d64a..3efae04d71 100644 --- a/src/sql/code_generator/ob_static_engine_cg.cpp +++ b/src/sql/code_generator/ob_static_engine_cg.cpp @@ -635,17 +635,9 @@ int ObStaticEngineCG::check_vectorize_supported(bool &support, // Expr rownum() shows up in both operator 0 and 2, which leads circular // dependency and breaks rownum's defination. // - const ObRawExpr *rownum_expr = NULL; - for (int64_t i = 0; rownum_expr == NULL && OB_SUCC(ret) && i < op->get_num_of_child(); i++) { - ObLogicalOperator *child = op->get_child(i); - if (OB_ISNULL(child)) { - ret = OB_ERR_UNEXPECTED; - LOG_WARN("op child is null", K(ret)); - } else if (OB_FAIL(child->find_rownum_expr(rownum_expr))) { - LOG_WARN("find rownum expr error", K(ret)); - } - } - if (NULL != rownum_expr) { + ObLogicalOperator *rownum_op = NULL; + if (OB_SUCC(ret) && OB_SUCC(op->find_rownum_expr(op->get_op_id(), rownum_op)) + && rownum_op != NULL && rownum_op != op) { LOG_DEBUG("rownum expr is in count operator's subplan tree. Stop vectorization exec"); disable_vectorize = true; } diff --git a/src/sql/optimizer/ob_logical_operator.cpp b/src/sql/optimizer/ob_logical_operator.cpp index 6f9f31ce91..f4854074cd 100644 --- a/src/sql/optimizer/ob_logical_operator.cpp +++ b/src/sql/optimizer/ob_logical_operator.cpp @@ -3107,16 +3107,19 @@ int ObLogicalOperator::alloc_op_pre(AllocOpContext& ctx) sys_rownum_expr->set_op_id(op_id_); } } - const ObRawExpr *rownum_expr = NULL; - if (OB_SUCC(ret) && OB_SUCC(find_rownum_expr(rownum_expr)) && rownum_expr != NULL) { - const ObSysFunRawExpr *sys_rownum_expr = static_cast(rownum_expr); - uint64_t count_op_id = sys_rownum_expr->get_op_id(); - LOG_DEBUG("the coun_op_id of rownum is", K(count_op_id)); - // rownum expr may be above count - if (count_op_id != OB_INVALID_ID) { - if (OB_FAIL(disable_rownum_expr(ctx.disabled_op_set_, count_op_id))) { - LOG_WARN("fail to disable rownum", K(ret), K(count_op_id)); + ObLogicalOperator *rownum_op = NULL; + if (OB_SUCC(ret) && OB_SUCC(find_rownum_expr(op_id_, rownum_op)) && rownum_op != NULL) { + uint64_t op_id = rownum_op->get_op_id(); + ObLogicalOperator *parent = rownum_op->get_parent(); + while (OB_SUCC(ret) && op_id != op_id_ && parent != NULL) { + ret = ctx.disabled_op_set_.set_refactored(op_id_); + if (ret != OB_SUCCESS && ret != OB_HASH_EXIST) { + LOG_WARN("set_refactored fail", K(ret)); + } else { + ret = OB_SUCCESS; } + op_id = parent->get_op_id(); + parent = parent->get_parent(); } } // disable all right childs of nlj and spf @@ -5690,81 +5693,75 @@ int ObLogicalOperator::collect_batch_exec_param(void* ctx, return ret; } -int ObLogicalOperator::find_rownum_expr_recursively(const ObRawExpr *&rownum_expr, const ObRawExpr *expr) +// Recursively search in all subexpressions +int ObLogicalOperator::find_rownum_expr_recursively(const uint64_t &count_op_id, + ObLogicalOperator *&rownum_op, + const ObRawExpr *expr) { int ret = OB_SUCCESS; if (OB_ISNULL(expr)) { ret = OB_ERR_UNEXPECTED; LOG_WARN("expr is null", K(ret)); - } else if (expr->get_expr_type() == T_FUN_SYS_ROWNUM) { - rownum_expr = expr; + } else if (!expr->has_flag(CNT_ROWNUM)) { + /* expr do not have rownum, need not to search recursively, do nothing */ + } else if (expr->get_expr_type() == T_FUN_SYS_ROWNUM + && static_cast(expr)->get_op_id() == count_op_id) { + rownum_op = this; } else { - for (int64_t i = 0; OB_SUCC(ret) && rownum_expr == NULL && i < expr->get_param_count(); i++) { - if (OB_FAIL(SMART_CALL(find_rownum_expr_recursively(rownum_expr, expr->get_param_expr(i))))) { + for (int64_t i = 0; OB_SUCC(ret) && NULL == rownum_op && i < expr->get_param_count(); i++) { + if (OB_FAIL(SMART_CALL( + find_rownum_expr_recursively(count_op_id, rownum_op, expr->get_param_expr(i))))) { LOG_WARN("fail to find rownum expr recursively", K(ret)); } } } LOG_DEBUG("find_rownum_expr_recursively finished", K(expr->get_param_count()), - K(expr->get_expr_type())); + K(expr->get_expr_type()), K(NULL == rownum_op)); return ret; } - -int ObLogicalOperator::find_rownum_expr(const ObRawExpr *&rownum_expr, const ObIArray &exprs) +int ObLogicalOperator::find_rownum_expr(const uint64_t &count_op_id, ObLogicalOperator *&rownum_op, + const ObIArray &exprs) { - LOG_DEBUG("find_rownum_expr begin", K(exprs.count())); + LOG_DEBUG("find_rownum_expr begin", K(exprs.count()), K(NULL == rownum_op)); int ret = OB_SUCCESS; - for (int64_t i = 0; OB_SUCC(ret) && rownum_expr == NULL && i < exprs.count(); i++) { + for (int64_t i = 0; OB_SUCC(ret) && NULL == rownum_op && i < exprs.count(); i++) { ObRawExpr *expr = exprs.at(i); - ret = find_rownum_expr_recursively(rownum_expr, expr); - LOG_DEBUG("find_rownum_expr_recursively done:", K(expr->get_expr_type()), K(i), - K(expr->get_param_count())); + ret = find_rownum_expr_recursively(count_op_id, rownum_op, expr); + LOG_DEBUG( + "find_rownum_expr_recursively done:", K(expr->get_expr_type()), + K(NULL == rownum_op), K(i), K(expr->get_param_count())); } return ret; } - -// rownum expr can show up in the following 4 cases, check them all -// - filter expr // - output expr // - join conditions: equal ("=") // - join conditions: filter (">", "<", ">=", "<=") -int ObLogicalOperator::find_rownum_expr(const ObRawExpr *&rownum_expr) +int ObLogicalOperator::find_rownum_expr(const uint64_t &count_op_id, ObLogicalOperator *&rownum_op) { int ret = OB_SUCCESS; - LOG_DEBUG("find_rownum_expr debug: ", K(get_name())); - if (OB_FAIL(find_rownum_expr(rownum_expr, get_filter_exprs()))) { + LOG_DEBUG("find_rownum_expr debug: ", K(get_name()), K(count_op_id)); + if (OB_FAIL(find_rownum_expr(count_op_id, rownum_op, get_filter_exprs()))) { LOG_WARN("failure encountered during find rownum expr", K(ret)); - } else if (OB_FAIL(find_rownum_expr(rownum_expr, get_output_exprs()))) { + } else if (OB_FAIL(find_rownum_expr(count_op_id, rownum_op, get_output_exprs()))) { LOG_WARN("failure encountered during find rownum expr", K(ret)); - } else if (rownum_expr == NULL && get_type() == log_op_def::LOG_JOIN) { + } else if (NULL == rownum_op && get_type() == log_op_def::LOG_JOIN) { ObLogJoin *join_op = dynamic_cast(this); // NO NPE check for join_op as it should NOT be nullptr if (OB_ISNULL(join_op)) { ret = OB_ERR_UNEXPECTED; LOG_WARN("join op is null", K(ret)); - } else if (OB_FAIL(find_rownum_expr(rownum_expr, join_op->get_other_join_conditions()))) { + } else if (OB_FAIL( + find_rownum_expr(count_op_id, rownum_op, join_op->get_other_join_conditions()))) { LOG_WARN("failure encountered during find rownum expr", K(ret)); - } else if (OB_FAIL(find_rownum_expr(rownum_expr, join_op->get_equal_join_conditions()))) { + } else if (OB_FAIL( + find_rownum_expr(count_op_id, rownum_op, join_op->get_equal_join_conditions()))) { LOG_WARN("failure encountered during find rownum expr", K(ret)); } } - return ret; -} - -/** -After finding the rownum expression, the count operator is searched from bottom to top, and -the operators on the path are marked as not being able to add materialization. -*/ -int ObLogicalOperator::disable_rownum_expr(hash::ObHashSet &disabled_op_set, - const uint64_t &count_op_id) -{ - int ret = OB_SUCCESS; - uint64_t op_id = op_id_; - ObLogicalOperator *parent = get_parent(); - while (OB_SUCC(ret) && op_id != count_op_id && parent != NULL) { - ret = disabled_op_set.set_refactored(op_id); - op_id = parent->get_op_id(); - parent = parent->get_parent(); + for (int64_t i = 0; NULL == rownum_op && OB_SUCC(ret) && i < get_num_of_child(); i++) { + if (OB_FAIL(SMART_CALL(get_child(i)->find_rownum_expr(count_op_id, rownum_op)))) { + LOG_WARN("fail to find rownum expr", K(ret)); + } } return ret; } diff --git a/src/sql/optimizer/ob_logical_operator.h b/src/sql/optimizer/ob_logical_operator.h index 2611ff5547..7580d1f111 100644 --- a/src/sql/optimizer/ob_logical_operator.h +++ b/src/sql/optimizer/ob_logical_operator.h @@ -1358,11 +1358,11 @@ public: */ int alloc_op_pre(AllocOpContext& ctx); int alloc_op_post(AllocOpContext& ctx); - int find_rownum_expr_recursively(const ObRawExpr *&rownum_expr, const ObRawExpr *expr); - int find_rownum_expr(const ObRawExpr *&rownum_expr, const ObIArray &exprs); - int find_rownum_expr(const ObRawExpr *&rownum_expr); - int disable_rownum_expr(hash::ObHashSet &disabled_op_set, ObIArray &cur_path); - int disable_rownum_expr(hash::ObHashSet &disabled_op_set, const uint64_t &count_op_id); + int find_rownum_expr_recursively(const uint64_t &op_id, ObLogicalOperator *&rownum_op, + const ObRawExpr *expr); + int find_rownum_expr(const uint64_t &op_id, ObLogicalOperator *&rownum_op, + const ObIArray &exprs); + int find_rownum_expr(const uint64_t &op_id, ObLogicalOperator *&rownum_op); int gen_temp_op_id(AllocOpContext& ctx); int recursively_disable_alloc_op_above(AllocOpContext& ctx); int alloc_nodes_above(AllocOpContext& ctx, const uint64_t &flags);