diff --git a/include/maxscale/protocol/mysql.h b/include/maxscale/protocol/mysql.h index 07e2764aa..6a33eeffe 100644 --- a/include/maxscale/protocol/mysql.h +++ b/include/maxscale/protocol/mysql.h @@ -95,6 +95,21 @@ MXS_BEGIN_DECLS #define GW_MYSQL_SCRAMBLE_SIZE 20 #define GW_SCRAMBLE_LENGTH_323 8 +/** + * Prepared statement payload response offsets for a COM_STMT_PREPARE response: + * + * [0] OK (1) -- always 0x00 + * [1-4] statement_id (4) -- statement-id + * [5-6] num_columns (2) -- number of columns + * [7-8] num_params (2) -- number of parameters + * [9] filler + * [10-11] warning_count (2) -- number of warnings + */ +#define MYSQL_PS_ID_OFFSET MYSQL_HEADER_LEN + 1 +#define MYSQL_PS_COLS_OFFSET MYSQL_HEADER_LEN + 5 +#define MYSQL_PS_PARAMS_OFFSET MYSQL_HEADER_LEN + 7 +#define MYSQL_PS_WARN_OFFSET MYSQL_HEADER_LEN + 10 + /** Name of the default server side authentication plugin */ #define DEFAULT_MYSQL_AUTH_PLUGIN "mysql_native_password" @@ -437,6 +452,15 @@ bool mxs_mysql_is_ok_packet(GWBUF *buffer); */ bool mxs_mysql_is_result_set(GWBUF *buffer); +/** + * @brief Check if the buffer contains a prepared statement OK packet + * + * @param buffer Buffer to check + * + * @return True if the @c buffer contains a prepared statement OK packet + */ +bool mxs_mysql_is_prep_stmt_ok(GWBUF *buffer); + /** * @brief Check if the OK packet is followed by another result * diff --git a/server/modules/protocol/MySQL/MySQLBackend/mysql_backend.c b/server/modules/protocol/MySQL/MySQLBackend/mysql_backend.c index 14e8b57c6..9e50c4b2b 100644 --- a/server/modules/protocol/MySQL/MySQLBackend/mysql_backend.c +++ b/server/modules/protocol/MySQL/MySQLBackend/mysql_backend.c @@ -604,6 +604,36 @@ static inline bool expecting_resultset(MySQLProtocol *proto) proto->current_command == MYSQL_COM_STMT_FETCH; } +static inline bool expecting_ps_response(MySQLProtocol *proto) +{ + return proto->current_command == MYSQL_COM_STMT_PREPARE; +} + +static inline bool complete_ps_response(GWBUF *buffer) +{ + ss_dassert(GWBUF_IS_CONTIGUOUS(buffer)); + uint16_t cols = gw_mysql_get_byte2(GWBUF_DATA(buffer) + MYSQL_PS_COLS_OFFSET); + uint16_t params = gw_mysql_get_byte2(GWBUF_DATA(buffer) + MYSQL_PS_PARAMS_OFFSET); + int expected_eof = 0; + + if (cols > 0) + { + expected_eof++; + } + + if (params > 0) + { + expected_eof++; + } + + bool more; + int n_eof = modutil_count_signal_packets(buffer, 0, &more); + + MXS_DEBUG("Expecting %u EOF, have %u", n_eof, expected_eof); + + return n_eof == expected_eof; +} + static inline bool collecting_resultset(MySQLProtocol *proto, uint64_t capabilities) { return rcap_type_required(capabilities, RCAP_TYPE_RESULTSET_OUTPUT) || @@ -649,6 +679,7 @@ gw_read_and_write(DCB *dcb) /** Ask what type of output the router/filter chain expects */ uint64_t capabilities = service_get_capabilities(session->service); + bool result_collected = false; if (rcap_type_required(capabilities, RCAP_TYPE_STMT_OUTPUT)) { @@ -682,19 +713,35 @@ gw_read_and_write(DCB *dcb) return 0; } - if (collecting_resultset(proto, capabilities) && - expecting_resultset(proto) && - mxs_mysql_is_result_set(read_buffer)) + if (collecting_resultset(proto, capabilities)) { - bool more = false; - if (modutil_count_signal_packets(read_buffer, 0, &more) != 2) + if (expecting_resultset(proto) && + mxs_mysql_is_result_set(read_buffer)) { - dcb->dcb_readqueue = read_buffer; - return 0; - } + bool more = false; + if (modutil_count_signal_packets(read_buffer, 0, &more) != 2) + { + dcb->dcb_readqueue = gwbuf_append(read_buffer, dcb->dcb_readqueue); + return 0; + } - // Collected the complete result - proto->collect_result = false; + // Collected the complete result + proto->collect_result = false; + result_collected = true; + } + else if (expecting_ps_response(proto) && + mxs_mysql_is_prep_stmt_ok(read_buffer)) + { + if (!complete_ps_response(read_buffer)) + { + dcb->dcb_readqueue = gwbuf_append(read_buffer, dcb->dcb_readqueue); + return 0; + } + + // Collected the complete result + proto->collect_result = false; + result_collected = true; + } } } } @@ -764,7 +811,8 @@ gw_read_and_write(DCB *dcb) } } else if (rcap_type_required(capabilities, RCAP_TYPE_STMT_OUTPUT) && - !rcap_type_required(capabilities, RCAP_TYPE_RESULTSET_OUTPUT)) + !rcap_type_required(capabilities, RCAP_TYPE_RESULTSET_OUTPUT) && + !result_collected) { stmt = modutil_get_next_MySQL_packet(&read_buffer); } diff --git a/server/modules/protocol/MySQL/mysql_common.c b/server/modules/protocol/MySQL/mysql_common.c index ed33b902f..83e530db1 100644 --- a/server/modules/protocol/MySQL/mysql_common.c +++ b/server/modules/protocol/MySQL/mysql_common.c @@ -1558,6 +1558,20 @@ bool mxs_mysql_is_result_set(GWBUF *buffer) return rval; } +bool mxs_mysql_is_prep_stmt_ok(GWBUF *buffer) +{ + bool rval = false; + uint8_t cmd; + + if (gwbuf_copy_data(buffer, MYSQL_HEADER_LEN, 1, &cmd) && + cmd == MYSQL_REPLY_OK) + { + rval = true; + } + + return rval; +} + bool mxs_mysql_more_results_after_ok(GWBUF *buffer) { bool rval = false;