diff --git a/query_classifier/qc_mysqlembedded/qc_mysqlembedded.cc b/query_classifier/qc_mysqlembedded/qc_mysqlembedded.cc index 6a6c4e6aa..f6da5a334 100644 --- a/query_classifier/qc_mysqlembedded/qc_mysqlembedded.cc +++ b/query_classifier/qc_mysqlembedded/qc_mysqlembedded.cc @@ -88,6 +88,7 @@ typedef struct parsing_info_st QC_FUNCTION_INFO* function_infos; size_t function_infos_len; size_t function_infos_capacity; + GWBUF* preparable_stmt; #if defined(SS_DEBUG) skygw_chk_t pi_chk_tail; #endif @@ -1547,6 +1548,8 @@ static void parsing_info_done(void* ptr) } free(pi->function_infos); + gwbuf_free(pi->preparable_stmt); + free(pi); } } @@ -1775,10 +1778,84 @@ int32_t qc_mysql_get_prepare_name(GWBUF* stmt, char** namep) int32_t qc_mysql_get_preparable_stmt(GWBUF* stmt, GWBUF** preparable_stmt) { - *preparable_stmt = NULL; + if (stmt) + { + if (ensure_query_is_parsed(stmt)) + { + LEX* lex = get_lex(stmt); + + if (lex->sql_command == SQLCOM_PREPARE) + { + parsing_info_t* pi = get_pinfo(stmt); + + if (!pi->preparable_stmt) + { + // This is terriby inefficient, but as qc_mysqlembedded is not used + // for anything else but comparisons it is ok. + const char* preparable_str = lex->prepared_stmt_code.str; + size_t preparable_str_len = lex->prepared_stmt_code.length; + + // MySQL does not parse e.g. "select * from x where ?=5". To work + // around that we'll replace all "?":s with "@a":s. We might replace + // something unnecessarily, but that won't hurt the classification. + size_t n_questions = 0; + const char* p = preparable_str; + while (p < preparable_str + preparable_str_len) + { + if (*p == '?') + { + ++n_questions; + } + + ++p; + } + + size_t preparable_stmt_len = preparable_str_len + n_questions * 2; + size_t payload_len = preparable_stmt_len + 1; + size_t packet_len = MYSQL_HEADER_LEN + payload_len; + + GWBUF* preperable_stmt = gwbuf_alloc(packet_len); + + if (preperable_stmt) + { + // Encode the length of the payload in the 3 first bytes. + *((unsigned char*)GWBUF_DATA(preperable_stmt) + 0) = payload_len; + *((unsigned char*)GWBUF_DATA(preperable_stmt) + 1) = (payload_len >> 8); + *((unsigned char*)GWBUF_DATA(preperable_stmt) + 2) = (payload_len >> 16); + // Sequence id + *((unsigned char*)GWBUF_DATA(preperable_stmt) + 3) = 0x00; + // Payload, starts with command. + *((unsigned char*)GWBUF_DATA(preperable_stmt) + 4) = COM_QUERY; + // Is followed by the statement. + char *s = (char*)GWBUF_DATA(preperable_stmt) + 5; + p = preparable_str; + + while (p < preparable_str + preparable_str_len) + { + switch (*p) + { + case '?': + *s++ = '@'; + *s = 'a'; + break; + + default: + *s = *p; + } + + ++p; + ++s; + } + } + + pi->preparable_stmt = preperable_stmt; + } + + *preparable_stmt = pi->preparable_stmt; + } + } + } - // TODO: Extract preparable stmt. - ss_dassert(!true); return QC_RESULT_OK; } diff --git a/query_classifier/qc_sqlite/qc_sqlite.c b/query_classifier/qc_sqlite/qc_sqlite.c index e254c8890..69121e4c0 100644 --- a/query_classifier/qc_sqlite/qc_sqlite.c +++ b/query_classifier/qc_sqlite/qc_sqlite.c @@ -75,9 +75,7 @@ typedef struct qc_sqlite_info int keyword_1; // The first encountered keyword. int keyword_2; // The second encountered keyword. char* prepare_name; // The name of a prepared statement. - qc_query_op_t prepare_operation; // The operation of a prepared statement. - char* preparable_stmt; // The preparable statement. - size_t preparable_stmt_length; // The length of the preparable statement. + GWBUF* preparable_stmt; // The preparable statement. QC_FIELD_INFO *field_infos; // Pointer to array of QC_FIELD_INFOs. size_t field_infos_len; // The used entries in field_infos. size_t field_infos_capacity; // The capacity of the field_infos array. @@ -347,7 +345,7 @@ static void info_finish(QC_SQLITE_INFO* info) free(info->created_table_name); free_string_array(info->database_names); free(info->prepare_name); - free(info->preparable_stmt); + gwbuf_free(info->preparable_stmt); free_field_infos(info->field_infos, info->field_infos_len); free_function_infos(info->function_infos, info->function_infos_len); } @@ -384,9 +382,7 @@ static QC_SQLITE_INFO* info_init(QC_SQLITE_INFO* info) info->keyword_1 = 0; // Sqlite3 starts numbering tokens from 1, so 0 means info->keyword_2 = 0; // that we have not seen a keyword. info->prepare_name = NULL; - info->prepare_operation = QUERY_OP_UNDEFINED; info->preparable_stmt = NULL; - info->preparable_stmt_length = 0; info->field_infos = NULL; info->field_infos_len = 0; info->field_infos_capacity = 0; @@ -529,29 +525,6 @@ static bool parse_query(GWBUF* query) this_thread.info->query = NULL; this_thread.info->query_len = 0; - if ((info->types & QUERY_TYPE_PREPARE_NAMED_STMT) && info->preparable_stmt) - { - QC_SQLITE_INFO* preparable_info = info_alloc(); - - if (preparable_info) - { - this_thread.info = preparable_info; - - const char *preparable_s = info->preparable_stmt; - size_t preparable_len = info->preparable_stmt_length; - - this_thread.info->query = preparable_s; - this_thread.info->query_len = preparable_len; - parse_query_string(preparable_s, preparable_len); - this_thread.info->query = NULL; - this_thread.info->query_len = 0; - - info->prepare_operation = preparable_info->operation; - - info_free(preparable_info); - } - } - // TODO: Add return value to gwbuf_add_buffer_object. // Always added; also when it was not recognized. If it was not recognized now, // it won't be if we try a second time. @@ -2328,11 +2301,25 @@ void maxscalePrepare(Parse* pParse, Token* pName, Token* pStmt) info->prepare_name[pName->n] = 0; } - info->preparable_stmt_length = pStmt->n - 2; - info->preparable_stmt = MXS_MALLOC(info->preparable_stmt_length); + size_t preparable_stmt_len = pStmt->n - 2; + size_t payload_len = 1 + preparable_stmt_len; + size_t packet_len = MYSQL_HEADER_LEN + payload_len; + + info->preparable_stmt = gwbuf_alloc(packet_len); + if (info->preparable_stmt) { - memcpy(info->preparable_stmt, pStmt->z + 1, pStmt->n - 2); + uint8_t* ptr = GWBUF_DATA(info->preparable_stmt); + // Payload length + *ptr++ = payload_len; + *ptr++ = (payload_len >> 8); + *ptr++ = (payload_len >> 16); + // Sequence id + *ptr++ = 0x00; + // Command + *ptr++ = MYSQL_COM_QUERY; + + memcpy(ptr, pStmt->z + 1, pStmt->n - 2); } } @@ -3396,8 +3383,7 @@ int32_t qc_sqlite_get_preparable_stmt(GWBUF* stmt, GWBUF** preparable_stmt) { if (qc_info_is_valid(info->status)) { - // TODO: Extract the preparable stmt. - ss_dassert(!true); + *preparable_stmt = info->preparable_stmt; rv = QC_RESULT_OK; } else if (MXS_LOG_PRIORITY_IS_ENABLED(LOG_INFO)) diff --git a/query_classifier/test/compare.cc b/query_classifier/test/compare.cc index 871d5f044..649d191c8 100644 --- a/query_classifier/test/compare.cc +++ b/query_classifier/test/compare.cc @@ -80,6 +80,7 @@ struct State size_t n_errors; struct timespec time1; struct timespec time2; + string indent; } global = { false, // query_printed "", // query VERBOSITY_NORMAL, // verbosity @@ -90,7 +91,9 @@ struct State 0, // n_statements 0, // n_errors { 0, 0 }, // time1 - { 0, 0} }; // time2 + { 0, 0}, // time2 + "" // indent +}; ostream& operator << (ostream& out, qc_parse_result_t x) { @@ -237,7 +240,7 @@ void report(bool success, const string& s) if (global.verbosity >= VERBOSITY_MAX) { - cout << s << endl; + cout << global.indent << s << endl; global.result_printed = true; } } @@ -252,7 +255,7 @@ void report(bool success, const string& s) report_query(); } - cout << s << endl; + cout << global.indent << s << endl; global.result_printed = true; } } @@ -324,7 +327,7 @@ bool compare_parse(QUERY_CLASSIFIER* pClassifier1, GWBUF* pCopy1, if (rv1 == rv2) { - ss << "Ok : " << rv1; + ss << "Ok : " << static_cast(rv1); success = true; } else @@ -1157,35 +1160,71 @@ bool compare_get_function_info(QUERY_CLASSIFIER* pClassifier1, GWBUF* pCopy1, } -bool compare(QUERY_CLASSIFIER* pClassifier1, QUERY_CLASSIFIER* pClassifier2, const string& s) +bool compare(QUERY_CLASSIFIER* pClassifier1, GWBUF* pBuf1, + QUERY_CLASSIFIER* pClassifier2, GWBUF* pBuf2) { - GWBUF* pCopy1 = create_gwbuf(s); - GWBUF* pCopy2 = create_gwbuf(s); - int errors = 0; - errors += !compare_parse(pClassifier1, pCopy1, pClassifier2, pCopy2); - errors += !compare_get_type(pClassifier1, pCopy1, pClassifier2, pCopy2); - errors += !compare_get_operation(pClassifier1, pCopy1, pClassifier2, pCopy2); - errors += !compare_get_created_table_name(pClassifier1, pCopy1, pClassifier2, pCopy2); - errors += !compare_is_drop_table_query(pClassifier1, pCopy1, pClassifier2, pCopy2); - errors += !compare_get_table_names(pClassifier1, pCopy1, pClassifier2, pCopy2, false); - errors += !compare_get_table_names(pClassifier1, pCopy1, pClassifier2, pCopy2, true); - errors += !compare_query_has_clause(pClassifier1, pCopy1, pClassifier2, pCopy2); - errors += !compare_get_database_names(pClassifier1, pCopy1, pClassifier2, pCopy2); - errors += !compare_get_prepare_name(pClassifier1, pCopy1, pClassifier2, pCopy2); - errors += !compare_get_field_info(pClassifier1, pCopy1, pClassifier2, pCopy2); - errors += !compare_get_function_info(pClassifier1, pCopy1, pClassifier2, pCopy2); - - gwbuf_free(pCopy1); - gwbuf_free(pCopy2); + errors += !compare_parse(pClassifier1, pBuf1, pClassifier2, pBuf2); + errors += !compare_get_type(pClassifier1, pBuf1, pClassifier2, pBuf2); + errors += !compare_get_operation(pClassifier1, pBuf1, pClassifier2, pBuf2); + errors += !compare_get_created_table_name(pClassifier1, pBuf1, pClassifier2, pBuf2); + errors += !compare_is_drop_table_query(pClassifier1, pBuf1, pClassifier2, pBuf2); + errors += !compare_get_table_names(pClassifier1, pBuf1, pClassifier2, pBuf2, false); + errors += !compare_get_table_names(pClassifier1, pBuf1, pClassifier2, pBuf2, true); + errors += !compare_query_has_clause(pClassifier1, pBuf1, pClassifier2, pBuf2); + errors += !compare_get_database_names(pClassifier1, pBuf1, pClassifier2, pBuf2); + errors += !compare_get_prepare_name(pClassifier1, pBuf1, pClassifier2, pBuf2); + errors += !compare_get_field_info(pClassifier1, pBuf1, pClassifier2, pBuf2); + errors += !compare_get_function_info(pClassifier1, pBuf1, pClassifier2, pBuf2); if (global.result_printed) { cout << endl; } - return errors == 0; + bool success = (errors == 0); + + uint32_t type_mask1; + pClassifier1->qc_get_type(pBuf1, &type_mask1); + + uint32_t type_mask2; + pClassifier2->qc_get_type(pBuf2, &type_mask2); + + if ((type_mask1 == type_mask2) && + ((type_mask1 & QUERY_TYPE_PREPARE_NAMED_STMT) || (type_mask1 & QUERY_TYPE_PREPARE_STMT))) + { + GWBUF* pPreparable1; + pClassifier1->qc_get_preparable_stmt(pBuf1, &pPreparable1); + ss_dassert(pPreparable1); + + GWBUF* pPreparable2; + pClassifier2->qc_get_preparable_stmt(pBuf2, &pPreparable2); + ss_dassert(pPreparable2); + + string indent = global.indent; + global.indent += string(4, ' '); + + success = compare(pClassifier1, pPreparable1, + pClassifier2, pPreparable2); + + global.indent = indent; + } + + return success; +} + +bool compare(QUERY_CLASSIFIER* pClassifier1, QUERY_CLASSIFIER* pClassifier2, const string& s) +{ + GWBUF* pCopy1 = create_gwbuf(s); + GWBUF* pCopy2 = create_gwbuf(s); + + bool success = compare(pClassifier1, pCopy1, pClassifier2, pCopy2); + + gwbuf_free(pCopy1); + gwbuf_free(pCopy2); + + return success; } inline void ltrim(std::string &s)