diff --git a/maxutils/maxsql/include/maxsql/mysql_plus.hh b/maxutils/maxsql/include/maxsql/mysql_plus.hh index 5e76ea13f..f73ef55f7 100644 --- a/maxutils/maxsql/include/maxsql/mysql_plus.hh +++ b/maxutils/maxsql/include/maxsql/mysql_plus.hh @@ -817,6 +817,12 @@ public: return m_command; } + bool server_will_respond() const + { + return m_command != MXS_COM_STMT_SEND_LONG_DATA // what? + && m_command != MXS_COM_QUIT + && m_command != MXS_COM_STMT_CLOSE; + } private: uint8_t m_command; }; diff --git a/maxutils/maxsql/include/maxsql/packet_tracker.hh b/maxutils/maxsql/include/maxsql/packet_tracker.hh index a6628afe1..2bdd132ec 100644 --- a/maxutils/maxsql/include/maxsql/packet_tracker.hh +++ b/maxutils/maxsql/include/maxsql/packet_tracker.hh @@ -27,7 +27,9 @@ class ComResponse; class PacketTracker { public: - enum class State {FirstPacket, Field, FieldEof, ComFieldList, Row, Done, ErrorPacket, Error}; + enum class State {FirstPacket, Field, FieldEof, Row, + ComFieldList, ComStatistics, ComStmtFetch, + Done, ErrorPacket, Error}; PacketTracker() = default; explicit PacketTracker(GWBUF* pQuery); // Track this query @@ -38,16 +40,19 @@ public: private: // State functions. - State first_packet(const ComResponse& com_packet); - State field(const ComResponse& com_packet); - State field_eof(const ComResponse& com_packet); - State com_field_list(const ComResponse& com_packet); - State row(const ComResponse& com_packet); - State expect_no_more(const ComResponse& com_packet); // states: Done, ErrorPacket, Error + State first_packet(const ComResponse& response); + State field(const ComResponse& response); + State field_eof(const ComResponse& response); + State row(const ComResponse& response); + State com_field_list(const ComResponse& response); + State com_statistics(const ComResponse& response); + State com_stmt_fetch(const ComResponse& response); + + State expect_no_more(const ComResponse& response); // states: Done, ErrorPacket, Error State m_state = State::Error; - bool m_client_packet_bool = false; - bool m_server_packet_bool = false; + bool m_client_packet_internal = false; + bool m_server_packet_internal = false; int m_command; int m_total_fields; diff --git a/maxutils/maxsql/src/packet_tracker.cc b/maxutils/maxsql/src/packet_tracker.cc index 647d34066..f43a515a7 100644 --- a/maxutils/maxsql/src/packet_tracker.cc +++ b/maxutils/maxsql/src/packet_tracker.cc @@ -12,9 +12,8 @@ */ // TODO handle client split packets -// TODO handle https://mariadb.com/kb/en/library/com_statistics/ -// TODO handle https://mariadb.com/kb/en/library/com_stmt_fetch/ -// TODO handle local infile local packet +// TODO handle local infile +// TODO do cursors need more special handling than ComStmtFetch has. #include #include @@ -25,9 +24,10 @@ namespace maxsql { -static const std::array state_names = { - "FirstPacket", "Field", "FieldEof", "ComFieldList", - "Row", "Done", "ErrorPacket", "Error" +static const std::array state_names = { + "FirstPacket", "Field", "FieldEof", "Row", + "ComFieldList", "ComStatistics", "ComStmtFetch", + "Done", "ErrorPacket", "Error" }; std::ostream& operator<<(std::ostream& os, PacketTracker::State state) @@ -37,12 +37,37 @@ std::ostream& operator<<(std::ostream& os, PacketTracker::State state) } PacketTracker::PacketTracker(GWBUF* pPacket) - : m_command(ComRequest(ComPacket(pPacket, &m_client_packet_bool)).command()) { + ComRequest request(ComPacket(pPacket, &m_client_packet_internal)); + m_command = request.command(); + MXS_SINFO("PacketTracker Command: " << STRPACKETTYPE(m_command)); // TODO remove or change to debug - // TODO mxs_mysql_command_will_respond() => ComRequest::mariadb_will_respond(); - m_state = (m_command == COM_FIELD_LIST) ? State::ComFieldList : State::FirstPacket; + if (request.server_will_respond()) + { + switch (m_command) + { + case MXS_COM_FIELD_LIST: + m_state = State::ComFieldList; + break; + + case MXS_COM_STATISTICS: + m_state = State::ComStatistics; + break; + + case MXS_COM_STMT_FETCH: + m_state = State::ComStmtFetch; + break; + + default: + m_state = State::FirstPacket; + break; + } + } + else + { + m_state = State::Done; + } } bool PacketTracker::expecting_more_packets() const @@ -59,53 +84,65 @@ bool PacketTracker::expecting_more_packets() const } } -static const std::array data_states { - PacketTracker::State::Field, PacketTracker::State::ComFieldList, PacketTracker::State::Row +static constexpr std::array data_states { + PacketTracker::State::Field, PacketTracker::State::Row, + PacketTracker::State::ComFieldList, PacketTracker::State::ComStatistics, + PacketTracker::State::ComStmtFetch }; void PacketTracker::update(GWBUF* pPacket) { - ComPacket com_packet(pPacket, &m_server_packet_bool); + ComPacket com_packet(pPacket, &m_server_packet_internal); bool expect_data_only = std::find(begin(data_states), end(data_states), m_state) != end(data_states); ComResponse response(com_packet, expect_data_only); + if (response.is_split_continuation()) + { // no state change, just more of the same data + return; + } + if (response.is_err()) { m_state = State::ErrorPacket; return; } - if (!response.is_split_continuation()) + switch (m_state) { - switch (m_state) - { - case State::FirstPacket: - m_state = first_packet(response); - break; + case State::FirstPacket: + m_state = first_packet(response); + break; - case State::Field: - m_state = field(response); - break; + case State::Field: + m_state = field(response); + break; - case State::FieldEof: - m_state = field_eof(response); - break; + case State::FieldEof: + m_state = field_eof(response); + break; - case State::ComFieldList: - m_state = com_field_list(response); - break; + case State::Row: + m_state = row(response); + break; - case State::Row: - m_state = row(response); - break; + case State::ComFieldList: + m_state = com_field_list(response); + break; - case State::Done: - case State::ErrorPacket: - case State::Error: - m_state = expect_no_more(response); - break; - } + case State::ComStatistics: + m_state = com_statistics(response); + break; + + case State::ComStmtFetch: + m_state = com_stmt_fetch(response); + break; + + case State::Done: + case State::ErrorPacket: + case State::Error: + m_state = expect_no_more(response); + break; } } @@ -172,23 +209,6 @@ PacketTracker::State PacketTracker::field_eof(const ComResponse& response) return new_state; } -PacketTracker::State PacketTracker::com_field_list(const ComResponse& response) -{ - State new_state = m_state; - - if (response.is_eof()) - { - new_state = State::Done; - } - else if (!response.is_data()) - { - MXS_SERROR("PacketTracker unexpected " << response.type() << " in state " << m_state); - new_state = State::Error; - } - - return new_state; -} - PacketTracker::State PacketTracker::row(const ComResponse& response) { State new_state = m_state; @@ -210,6 +230,65 @@ PacketTracker::State PacketTracker::row(const ComResponse& response) return new_state; } +PacketTracker::State PacketTracker::com_field_list(const ComResponse& response) +{ + State new_state = m_state; + + if (response.is_data()) + { + // ok + } + else if (response.is_eof()) + { + new_state = State::Done; + } + else + { + MXS_SERROR("PacketTracker unexpected " << response.type() << " in state " << m_state); + new_state = State::Error; + } + + return new_state; +} + +PacketTracker::State PacketTracker::com_statistics(const maxsql::ComResponse& response) +{ + State new_state = m_state; + + if (response.is_data()) + { + new_state = State::Done; + } + else + { + MXS_SERROR("PacketTracker unexpected " << response.type() << " in state " << m_state); + new_state = State::Error; + } + + return new_state; +} + +PacketTracker::State PacketTracker::com_stmt_fetch(const maxsql::ComResponse& response) +{ + State new_state = m_state; + + if (response.is_data()) + { + // ok + } + else if (response.is_eof()) + { + new_state = (ComEOF(response).more_results_exist()) ? State::ComStmtFetch : State::Done; + } + else + { + MXS_SERROR("PacketTracker unexpected " << response.type() << " in state " << m_state); + new_state = State::Error; + } + + return new_state; +} + PacketTracker::State PacketTracker::expect_no_more(const ComResponse& response) { MXS_SERROR("PacketTracker unexpected " << response.type() << " in state " << m_state);