diff --git a/server/core/modutil.c b/server/core/modutil.c index 9c06d1c05..d5fa3bc0a 100644 --- a/server/core/modutil.c +++ b/server/core/modutil.c @@ -34,6 +34,7 @@ #include #include #include +#include /** These are used when converting MySQL wildcards to regular expressions */ static SPINLOCK re_lock = SPINLOCK_INIT; @@ -765,7 +766,7 @@ static void modutil_reply_routing_error(DCB* backend_dcb, * @return Pointer to the first non-escaped, non-quoted occurrence of the character. * If the character is not found, NULL is returned. */ -void* strnchr_esc(char* ptr, char c, int len) +char* strnchr_esc(char* ptr, char c, int len) { char* p = (char*)ptr; char* start = p; @@ -804,27 +805,28 @@ void* strnchr_esc(char* ptr, char c, int len) /** * Find the first occurrence of a character in a string. This function ignores * escaped characters and all characters that are enclosed in single or double quotes. - * Also MySQL style comment blocks are ignored. + * MySQL style comment blocks and identifiers in backticks are also ignored. * @param ptr Pointer to area of memory to inspect * @param c Character to search for * @param len Size of the memory area * @return Pointer to the first non-escaped, non-quoted occurrence of the character. * If the character is not found, NULL is returned. */ -void* strnchr_esc_mysql(char* ptr, char c, int len) +char* strnchr_esc_mysql(char* ptr, char c, int len) { char* p = (char*) ptr; - char* start = p; - bool quoted = false, escaped = false, backtick = false; + char* start = p, *end = start + len; + bool quoted = false, escaped = false, backtick = false, comment = false; char qc; - while (p < start + len) + while (p < end) { if (escaped) { escaped = false; } - else + else if ((!comment && !quoted && !backtick) || (comment && *p == '*') || + (!comment && quoted && *p == qc) || (!comment && backtick && *p == '`')) { switch (*p) { @@ -832,7 +834,7 @@ void* strnchr_esc_mysql(char* ptr, char c, int len) escaped = true; break; - case'\'': + case '\'': case '"': if (!quoted) { @@ -845,31 +847,47 @@ void* strnchr_esc_mysql(char* ptr, char c, int len) } break; + case '/': + if (p + 1 < end && *(p + 1) == '*') + { + comment = true; + p += 1; + } + break; + + case '*': + if (comment && p + 1 < end && *(p + 1) == '/') + { + comment = false; + p += 1; + } + break; + case '`': backtick = !backtick; break; case '#': - if (!backtick) + return NULL; + + case '-': + if (p + 2 < end && *(p + 1) == '-' && + isspace(*(p + 2))) { return NULL; } break; - case '-': - if (!backtick && p + 1 < start + len && *(p + 1) == '-') - { - return NULL; - } + default: break; } - if (*p == c && !escaped && !quoted) + if (*p == c && !escaped && !quoted && !comment && !backtick) { return p; } - p++; } + p++; } return NULL; } @@ -883,7 +901,7 @@ bool is_mysql_comment_start(const char* start, int len) { const char *ptr = start; - while (ptr - start < len && isspace(*ptr)) + while (ptr < start + len && (isspace(*ptr) || *ptr == ';')) { ptr++; } @@ -891,7 +909,7 @@ bool is_mysql_comment_start(const char* start, int len) switch (*ptr) { case '-': - if (*(ptr + 1) == '-') + if (*(ptr + 1) == '-' && isspace(*(ptr + 2))) { return true; } @@ -910,6 +928,23 @@ bool is_mysql_comment_start(const char* start, int len) return false; } +/** + * @brief Check if the token is the END part of a BEGIN ... END block. + * @param ptr String with at least three non-whitespace characters in it + * @return True if the token is the final part of a BEGIN .. END block + */ +bool is_mysql_sp_end(const char* start, int len) +{ + const char *ptr = start; + + while (ptr < start + len && (isspace(*ptr) || *ptr == ';')) + { + ptr++; + } + + return ptr < start + len - 3 && strncasecmp(ptr, "end", 3) == 0; +} + /** * Create a COM_QUERY packet from a string. * @param query Query to create. diff --git a/server/include/modutil.h b/server/include/modutil.h index 288adcc86..1223f880a 100644 --- a/server/include/modutil.h +++ b/server/include/modutil.h @@ -69,8 +69,9 @@ int modutil_count_signal_packets(GWBUF*,int,int,int*); mxs_pcre2_result_t modutil_mysql_wildcard_match(const char* pattern, const char* string); /** Character and token searching functions */ -void* strnchr_esc(char* ptr, char c, int len); -void* strnchr_esc_mysql(char* ptr, char c, int len); +char* strnchr_esc(char* ptr, char c, int len); +char* strnchr_esc_mysql(char* ptr, char c, int len); bool is_mysql_comment_start(const char* start, int len); +bool is_mysql_sp_end(const char* start, int len); #endif diff --git a/server/modules/routing/readwritesplit/readwritesplit.c b/server/modules/routing/readwritesplit/readwritesplit.c index 45411e637..467f0dd79 100644 --- a/server/modules/routing/readwritesplit/readwritesplit.c +++ b/server/modules/routing/readwritesplit/readwritesplit.c @@ -5366,11 +5366,19 @@ static void check_for_multi_stmt(ROUTER_CLIENT_SES* rses, GWBUF *buf, if ((ptr = strnchr_esc_mysql(data, ';', buflen))) { - ptr++; - if (ptr - data < buflen && !is_mysql_comment_start(ptr, ptr - data)) + /** Skip stored procedures etc. */ + while (ptr && is_mysql_sp_end(ptr, ptr - data)) { - rses->forced_node = rses->rses_master_ref; - MXS_INFO("Multi-statement query, routing all future queries to master."); + ptr = strnchr_esc_mysql(ptr + 1, ';', ptr - data); + } + + if (ptr) + { + if (ptr < data + buflen && !is_mysql_comment_start(ptr, ptr - data)) + { + rses->forced_node = rses->rses_master_ref; + MXS_INFO("Multi-statement query, routing all future queries to master."); + } } } }