From 07e183f312b9a79220d4bb8e3a32ddf1f7fe8562 Mon Sep 17 00:00:00 2001 From: haohao022 Date: Thu, 10 Oct 2024 10:16:24 +0000 Subject: [PATCH] [CP] [to #2024091400104482136] fix: add defending check for reading proto buffer --- src/observer/mysql/obmp_stmt_execute.cpp | 16 ++++++------ src/observer/mysql/obmp_stmt_execute.h | 4 +-- .../mysql/obmp_stmt_send_long_data.cpp | 11 +++++--- src/observer/mysql/obmp_stmt_send_long_data.h | 1 + .../mysql/obmp_stmt_send_piece_data.cpp | 26 ++++++++++++------- .../mysql/obmp_stmt_send_piece_data.h | 2 ++ 6 files changed, 38 insertions(+), 22 deletions(-) diff --git a/src/observer/mysql/obmp_stmt_execute.cpp b/src/observer/mysql/obmp_stmt_execute.cpp index 231854e7a..fd46040a5 100644 --- a/src/observer/mysql/obmp_stmt_execute.cpp +++ b/src/observer/mysql/obmp_stmt_execute.cpp @@ -558,7 +558,8 @@ int ObMPStmtExecute::before_process() } const ObMySQLRawPacket &pkt = reinterpret_cast(req_->get_packet()); const char* pos = pkt.get_cdata(); - analysis_checker_.init(pos, pkt.get_clen()); + // pkt.get_cdata() do not include 1 byte for `request command code` + analysis_checker_.init(pos, pkt.get_clen() - 1); int32_t stmt_id = -1; //INVALID_STMT_ID uint32_t ps_stmt_checksum = 0; ObSQLSessionInfo *session = NULL; @@ -903,19 +904,18 @@ int ObMPStmtExecute::request_params(ObSQLSessionInfo *session, } } if (OB_SUCC(ret) && params_num_ > 0) { - // Step1: 处理空值位图 - int64_t bitmap_types = (params_num_ + 7) / 8; - const char *bitmap = pos; - pos += bitmap_types; ParamTypeArray ¶m_types = ps_session_info->get_param_types(); ParamTypeInfoArray param_type_infos; ParamCastArray param_cast_infos; - ParamTypeArray returning_param_types; ParamTypeInfoArray returning_param_type_infos; - int64_t len = bitmap_types + 1/*new_param_bound_flag*/; - PS_DEFENSE_CHECK(len) // bitmap_types + + // Step1: 处理空值位图 + const char *bitmap = pos; + int64_t bitmap_types = (params_num_ + 7) / 8; + PS_DEFENSE_CHECK(bitmap_types + 1) // null value bitmap + new param bound flag { + pos += bitmap_types; // Step2: 获取new_param_bound_flag字段 ObMySQLUtil::get_int1(pos, new_param_bound_flag); if (new_param_bound_flag == 1) { diff --git a/src/observer/mysql/obmp_stmt_execute.h b/src/observer/mysql/obmp_stmt_execute.h index 1ebe12c54..51b0592a6 100644 --- a/src/observer/mysql/obmp_stmt_execute.h +++ b/src/observer/mysql/obmp_stmt_execute.h @@ -78,8 +78,8 @@ public: #define PS_STATIC_DEFENSE_CHECK(checker, len) \ if (OB_FAIL(ret)) { \ - } else if (nullptr != checker \ - && OB_FAIL(checker->detection(len))) { \ + } else if (nullptr != (checker) \ + && OB_FAIL((checker)->detection(len))) { \ LOG_WARN("memory access out of bounds", K(ret)); \ } else diff --git a/src/observer/mysql/obmp_stmt_send_long_data.cpp b/src/observer/mysql/obmp_stmt_send_long_data.cpp index 1637ce56b..6c978702d 100644 --- a/src/observer/mysql/obmp_stmt_send_long_data.cpp +++ b/src/observer/mysql/obmp_stmt_send_long_data.cpp @@ -69,9 +69,14 @@ int ObMPStmtSendLongData::before_process() } else { const ObMySQLRawPacket &pkt = reinterpret_cast(req_->get_packet()); const char* pos = pkt.get_cdata(); - // stmt_id - ObMySQLUtil::get_int4(pos, stmt_id_); - ObMySQLUtil::get_uint2(pos, param_id_); + defender_.init(pos, pkt.get_clen() - 1); // pkt.get_cdata() do not include 1 byte for `request command code` + + PS_STATIC_DEFENSE_CHECK(&defender_, 4 + 2) + { + ObMySQLUtil::get_int4(pos, stmt_id_); + ObMySQLUtil::get_uint2(pos, param_id_); + } + if (OB_SUCC(ret) && stmt_id_ < 1) { ret = OB_ERR_PARAM_INVALID; LOG_WARN("send_long_data receive unexpected stmt_id_", K(ret), K(stmt_id_), K(param_id_)); diff --git a/src/observer/mysql/obmp_stmt_send_long_data.h b/src/observer/mysql/obmp_stmt_send_long_data.h index 0450b0bfd..ce586459b 100644 --- a/src/observer/mysql/obmp_stmt_send_long_data.h +++ b/src/observer/mysql/obmp_stmt_send_long_data.h @@ -77,6 +77,7 @@ private: uint64_t buffer_len_; common::ObString buffer_; bool need_disconnect_; + ObPSAnalysisChecker defender_; private: DISALLOW_COPY_AND_ASSIGN(ObMPStmtSendLongData); diff --git a/src/observer/mysql/obmp_stmt_send_piece_data.cpp b/src/observer/mysql/obmp_stmt_send_piece_data.cpp index 43d40e53c..59cc87fea 100644 --- a/src/observer/mysql/obmp_stmt_send_piece_data.cpp +++ b/src/observer/mysql/obmp_stmt_send_piece_data.cpp @@ -72,27 +72,35 @@ int ObMPStmtSendPieceData::before_process() } else { const ObMySQLRawPacket &pkt = reinterpret_cast(req_->get_packet()); const char* pos = pkt.get_cdata(); - // stmt_id - ObMySQLUtil::get_int4(pos, stmt_id_); - ObMySQLUtil::get_uint2(pos, param_id_); - ObMySQLUtil::get_int1(pos, piece_mode_); - int8_t is_null = 0; - ObMySQLUtil::get_int1(pos, is_null); - is_null_ = (1 == is_null); - ObMySQLUtil::get_int8(pos, buffer_len_); + defender_.init(pos, pkt.get_clen() - 1); // pkt.get_cdata() do not include 1 byte for `request command code` + + PS_STATIC_DEFENSE_CHECK(&defender_, 4 + 2 + 1 + 1 + 8) + { + ObMySQLUtil::get_int4(pos, stmt_id_); + ObMySQLUtil::get_uint2(pos, param_id_); + ObMySQLUtil::get_int1(pos, piece_mode_); + int8_t is_null = 0; + ObMySQLUtil::get_int1(pos, is_null); + is_null_ = (1 == is_null); + ObMySQLUtil::get_int8(pos, buffer_len_); + } + if (stmt_id_ < 1 || buffer_len_ < 0) { ret = OB_ERR_MALFORMED_PS_PACKET; LOG_WARN("send_piece receive unexpected params", K(ret), K(stmt_id_), K(buffer_len_)); } else if (param_id_ >= OB_PARAM_ID_OVERFLOW_RISK_THRESHOLD) { LOG_WARN("param_id_ has the risk of overflow", K(ret), K(stmt_id_), K(param_id_)); } - if (OB_SUCC(ret)) { + + PS_STATIC_DEFENSE_CHECK(&defender_, buffer_len_) + { buffer_.assign_ptr(pos, static_cast(buffer_len_)); pos += buffer_len_; LOG_INFO("resolve send_piece protocol packet successfully", K(ret), K(stmt_id_), K(param_id_), K(buffer_len_)); LOG_DEBUG("send_piece packet content", K(buffer_)); } + LOG_INFO("resolve send_piece protocol packet", K(ret), K(stmt_id_), K(param_id_), K(buffer_len_), K(piece_mode_), K(is_null_)); } diff --git a/src/observer/mysql/obmp_stmt_send_piece_data.h b/src/observer/mysql/obmp_stmt_send_piece_data.h index cf726e31c..cc652e36f 100644 --- a/src/observer/mysql/obmp_stmt_send_piece_data.h +++ b/src/observer/mysql/obmp_stmt_send_piece_data.h @@ -17,6 +17,7 @@ #include "observer/mysql/obmp_base.h" #include "observer/mysql/ob_query_retry_ctrl.h" #include "lib/rc/context.h" +#include "observer/mysql/obmp_stmt_execute.h" namespace oceanbase { @@ -77,6 +78,7 @@ private: common::ObString buffer_; int8_t piece_mode_; bool is_null_; + ObPSAnalysisChecker defender_; private: DISALLOW_COPY_AND_ASSIGN(ObMPStmtSendPieceData);