diff --git a/src/sql/code_generator/ob_static_engine_cg.cpp b/src/sql/code_generator/ob_static_engine_cg.cpp index f97412e815..582db2d64a 100644 --- a/src/sql/code_generator/ob_static_engine_cg.cpp +++ b/src/sql/code_generator/ob_static_engine_cg.cpp @@ -635,13 +635,18 @@ int ObStaticEngineCG::check_vectorize_supported(bool &support, // Expr rownum() shows up in both operator 0 and 2, which leads circular // dependency and breaks rownum's defination. // - bool has_rownum_expr = false; - for (int64_t i = 0; !has_rownum_expr && OB_SUCC(ret) && i < op->get_num_of_child(); i++) { - OZ(op->get_child(i)->find_rownum_expr(has_rownum_expr)); + const ObRawExpr *rownum_expr = NULL; + for (int64_t i = 0; rownum_expr == NULL && OB_SUCC(ret) && i < op->get_num_of_child(); i++) { + ObLogicalOperator *child = op->get_child(i); + if (OB_ISNULL(child)) { + ret = OB_ERR_UNEXPECTED; + LOG_WARN("op child is null", K(ret)); + } else if (OB_FAIL(child->find_rownum_expr(rownum_expr))) { + LOG_WARN("find rownum expr error", K(ret)); + } } - if (has_rownum_expr) { - LOG_DEBUG("rownum expr is in count operator's subplan tree. Stop vectorization execution", - K(has_rownum_expr)); + if (NULL != rownum_expr) { + LOG_DEBUG("rownum expr is in count operator's subplan tree. Stop vectorization exec"); disable_vectorize = true; } } else if (log_op_def::LOG_JOIN == op->get_type() && diff --git a/src/sql/optimizer/ob_logical_operator.cpp b/src/sql/optimizer/ob_logical_operator.cpp index 94b8641dba..e4df532075 100644 --- a/src/sql/optimizer/ob_logical_operator.cpp +++ b/src/sql/optimizer/ob_logical_operator.cpp @@ -3092,11 +3092,30 @@ int ObLogicalOperator::alloc_op_pre(AllocOpContext& ctx) ctx.gen_temp_op_id_ = true; } // disable nodes in COUNT-rownum situation - if (OB_SUCC(ret) && log_op_def::LOG_COUNT == get_type()) { - ObSEArray cur_path; - for (int64_t i = 0; OB_SUCC(ret) && i < get_num_of_child(); i++) { - if (OB_FAIL(get_child(i)->disable_rownum_expr(ctx.disabled_op_set_, cur_path))) { - LOG_WARN("fail to find rownum expr", K(ret)); + if (OB_SUCC(ret) && LOG_COUNT == get_type()) { + ObRawExpr *rownum_expr = NULL; + if (OB_ISNULL(get_stmt())) { + ret = OB_ERR_UNEXPECTED; + LOG_WARN("stmt is null", K(ret)); + } else if (OB_FAIL(get_stmt()->get_rownum_expr(rownum_expr))) { + LOG_WARN("get rownum expr failed", K(ret)); + } else if (OB_ISNULL(rownum_expr)) { + ret = OB_ERR_UNEXPECTED; + LOG_WARN("no rownum expr in stmt of count operator", K(ret)); + } else { + ObSysFunRawExpr *sys_rownum_expr = static_cast(rownum_expr); + sys_rownum_expr->set_op_id(op_id_); + } + } + const ObRawExpr *rownum_expr = NULL; + if (OB_SUCC(ret) && OB_SUCC(find_rownum_expr(rownum_expr)) && rownum_expr != NULL) { + const ObSysFunRawExpr *sys_rownum_expr = static_cast(rownum_expr); + uint64_t count_op_id = sys_rownum_expr->get_op_id(); + LOG_DEBUG("the coun_op_id of rownum is", K(count_op_id)); + // rownum expr may be above count + if (count_op_id != OB_INVALID_ID) { + if (OB_FAIL(disable_rownum_expr(ctx.disabled_op_set_, count_op_id))) { + LOG_WARN("fail to disable rownum", K(ret), K(count_op_id)); } } } @@ -5686,36 +5705,35 @@ int ObLogicalOperator::collect_batch_exec_param(void* ctx, return ret; } -int ObLogicalOperator::find_rownum_expr_recursively(bool &found, const ObRawExpr *expr) +int ObLogicalOperator::find_rownum_expr_recursively(const ObRawExpr *&rownum_expr, const ObRawExpr *expr) { int ret = OB_SUCCESS; if (OB_ISNULL(expr)) { ret = OB_ERR_UNEXPECTED; LOG_WARN("expr is null", K(ret)); } else if (expr->get_expr_type() == T_FUN_SYS_ROWNUM) { - found = true; + rownum_expr = expr; } else { - for (auto i = 0; OB_SUCC(ret) && !found && i < expr->get_param_count(); i++) { - if (OB_FAIL(SMART_CALL(find_rownum_expr_recursively(found, expr->get_param_expr(i))))) { + for (int64_t i = 0; OB_SUCC(ret) && rownum_expr == NULL && i < expr->get_param_count(); i++) { + if (OB_FAIL(SMART_CALL(find_rownum_expr_recursively(rownum_expr, expr->get_param_expr(i))))) { LOG_WARN("fail to find rownum expr recursively", K(ret)); } } } LOG_DEBUG("find_rownum_expr_recursively finished", K(expr->get_param_count()), - K(expr->get_expr_type()), K(found)); + K(expr->get_expr_type())); return ret; } -int ObLogicalOperator::find_rownum_expr(bool &found, const ObIArray &exprs) +int ObLogicalOperator::find_rownum_expr(const ObRawExpr *&rownum_expr, const ObIArray &exprs) { - LOG_DEBUG("find_rownum_expr begin", K(exprs.count()), K(found)); + LOG_DEBUG("find_rownum_expr begin", K(exprs.count())); int ret = OB_SUCCESS; - for (auto i = 0; OB_SUCC(ret) && !found && i < exprs.count(); i++) { + for (int64_t i = 0; OB_SUCC(ret) && rownum_expr == NULL && i < exprs.count(); i++) { ObRawExpr *expr = exprs.at(i); - ret = find_rownum_expr_recursively(found, expr); - LOG_DEBUG( - "find_rownum_expr_recursively done:", K(expr->get_expr_type()), - K(found), K(i), K(expr->get_param_count())); + ret = find_rownum_expr_recursively(rownum_expr, expr); + LOG_DEBUG("find_rownum_expr_recursively done:", K(expr->get_expr_type()), K(i), + K(expr->get_param_count())); } return ret; } @@ -5725,68 +5743,43 @@ int ObLogicalOperator::find_rownum_expr(bool &found, const ObIArray // - output expr // - join conditions: equal ("=") // - join conditions: filter (">", "<", ">=", "<=") -int ObLogicalOperator::find_rownum_expr(bool &found) +int ObLogicalOperator::find_rownum_expr(const ObRawExpr *&rownum_expr) { int ret = OB_SUCCESS; - LOG_DEBUG("find_rownum_expr debug: ", K(get_name()), K(found)); - if (OB_FAIL(find_rownum_expr(found, get_filter_exprs()))) { + LOG_DEBUG("find_rownum_expr debug: ", K(get_name())); + if (OB_FAIL(find_rownum_expr(rownum_expr, get_filter_exprs()))) { LOG_WARN("failure encountered during find rownum expr", K(ret)); - } else if (OB_FAIL(find_rownum_expr(found, get_output_exprs()))) { + } else if (OB_FAIL(find_rownum_expr(rownum_expr, get_output_exprs()))) { LOG_WARN("failure encountered during find rownum expr", K(ret)); - } else if (!found && get_type() == log_op_def::LOG_JOIN) { + } else if (rownum_expr == NULL && get_type() == log_op_def::LOG_JOIN) { ObLogJoin *join_op = dynamic_cast(this); // NO NPE check for join_op as it should NOT be nullptr if (OB_ISNULL(join_op)) { ret = OB_ERR_UNEXPECTED; LOG_WARN("join op is null", K(ret)); - } else if (OB_FAIL(find_rownum_expr(found, join_op->get_other_join_conditions()))) { + } else if (OB_FAIL(find_rownum_expr(rownum_expr, join_op->get_other_join_conditions()))) { LOG_WARN("failure encountered during find rownum expr", K(ret)); - } else if (OB_FAIL(find_rownum_expr(found, join_op->get_equal_join_conditions()))) { + } else if (OB_FAIL(find_rownum_expr(rownum_expr, join_op->get_equal_join_conditions()))) { LOG_WARN("failure encountered during find rownum expr", K(ret)); } } - - for (auto i = 0; !found && OB_SUCC(ret) && i < get_num_of_child(); i++) { - if (OB_FAIL(SMART_CALL(get_child(i)->find_rownum_expr(found)))) { - LOG_WARN("fail to find rownum expr", K(ret)); - } - } return ret; } /** -Starting from COUNT, search downwards and add all operators on the path from COUNT to rownum() to disabled_op_set. -1. Push current op into the stack. -2. Check if the operator has rownum() first. If yes, merge cur_path into disabled_op_set. - At this point, the children nodes may still have rownum(), so cannot return and need to continue searching recursively. -3. DFS to check children nodes. -4. Pop current op out from the stack and backtrack. +After finding the rownum expression, the count operator is searched from bottom to top, and +the operators on the path are marked as not being able to add materialization. */ -int ObLogicalOperator::disable_rownum_expr(hash::ObHashSet &disabled_op_set, ObIArray &cur_path) +int ObLogicalOperator::disable_rownum_expr(hash::ObHashSet &disabled_op_set, + const uint64_t &count_op_id) { int ret = OB_SUCCESS; - bool found = false; - if (OB_FAIL(cur_path.push_back(op_id_))) { - LOG_WARN("fail to push back path", K(ret)); - } else if (OB_FAIL(find_rownum_expr(found))) { - LOG_WARN("fail to find rownum expr", K(ret)); - } else { - if (found) { - for (int64_t i = 0; OB_SUCC(ret) && i < cur_path.count(); ++i) { - ret = disabled_op_set.set_refactored(op_id_); - if (ret != OB_SUCCESS && ret != OB_HASH_EXIST) { - LOG_WARN("set_refactored fail", K(ret)); - } else { - ret = OB_SUCCESS; - } - } - } - for (int64_t i = 0; OB_SUCC(ret) && i < get_num_of_child(); i++) { - if (OB_FAIL(SMART_CALL(get_child(i)->disable_rownum_expr(disabled_op_set, cur_path)))) { - LOG_WARN("fail to disable rownum expr", K(ret)); - } - } - cur_path.pop_back(); + uint64_t op_id = op_id_; + ObLogicalOperator *parent = get_parent(); + while (OB_SUCC(ret) && op_id != count_op_id && parent != NULL) { + ret = disabled_op_set.set_refactored(op_id); + op_id = parent->get_op_id(); + parent = parent->get_parent(); } return ret; } diff --git a/src/sql/optimizer/ob_logical_operator.h b/src/sql/optimizer/ob_logical_operator.h index a84a70aff7..2611ff5547 100644 --- a/src/sql/optimizer/ob_logical_operator.h +++ b/src/sql/optimizer/ob_logical_operator.h @@ -1358,10 +1358,11 @@ public: */ int alloc_op_pre(AllocOpContext& ctx); int alloc_op_post(AllocOpContext& ctx); - int find_rownum_expr_recursively(bool &found, const ObRawExpr *expr); - int find_rownum_expr(bool &found, const ObIArray &exprs); - int find_rownum_expr(bool &found); + int find_rownum_expr_recursively(const ObRawExpr *&rownum_expr, const ObRawExpr *expr); + int find_rownum_expr(const ObRawExpr *&rownum_expr, const ObIArray &exprs); + int find_rownum_expr(const ObRawExpr *&rownum_expr); int disable_rownum_expr(hash::ObHashSet &disabled_op_set, ObIArray &cur_path); + int disable_rownum_expr(hash::ObHashSet &disabled_op_set, const uint64_t &count_op_id); int gen_temp_op_id(AllocOpContext& ctx); int recursively_disable_alloc_op_above(AllocOpContext& ctx); int alloc_nodes_above(AllocOpContext& ctx, const uint64_t &flags);