diff --git a/query_classifier/qc_mysqlembedded/qc_mysqlembedded.cc b/query_classifier/qc_mysqlembedded/qc_mysqlembedded.cc index e72aebcaa..b2ebde1f2 100644 --- a/query_classifier/qc_mysqlembedded/qc_mysqlembedded.cc +++ b/query_classifier/qc_mysqlembedded/qc_mysqlembedded.cc @@ -74,6 +74,20 @@ #include #include +#if MYSQL_VERSION_MAJOR >= 10 && MYSQL_VERSION_MINOR >= 2 +#define CTE_SUPPORTED +#endif + +#if defined(CTE_SUPPORTED) +// We need to be able to access private data of With_element that has no +// public access methods. So, we use this very questionable method of +// making the private parts public. Ok, as qc_myselembedded is only +// used for verifying the output of qc_sqlite. +#define private public +#include +#undef private +#endif + /** * Defines what a particular name should be mapped to. */ @@ -159,6 +173,31 @@ static bool parse_query(GWBUF* querybuf); static bool query_is_parsed(GWBUF* buf); int32_t qc_mysql_get_field_info(GWBUF* buf, const QC_FIELD_INFO** infos, uint32_t* n_infos); +#if MYSQL_VERSION_MAJOR >= 10 && MYSQL_VERSION_MINOR >= 3 +inline void get_string_and_length(const LEX_CSTRING& ls, const char** s, size_t* length) +{ + *s = ls.str; + *length = ls.length; +} +#else +inline void get_string_and_length(const char* cs, const char** s, size_t* length) +{ + *s = cs; + *length = cs ? strlen(cs) : 0; +} +#endif + +static struct +{ + qc_sql_mode_t sql_mode; + pthread_mutex_t sql_mode_mutex; + NAME_MAPPING* function_name_mappings; +} this_unit = +{ + QC_SQL_MODE_DEFAULT, + PTHREAD_MUTEX_INITIALIZER, + function_name_mappings_default +}; #if MYSQL_VERSION_MAJOR >= 10 && MYSQL_VERSION_MINOR >= 3 inline void get_string_and_length(const LEX_CSTRING& ls, const char** s, size_t* length) @@ -2442,6 +2481,7 @@ static bool should_function_be_ignored(parsing_info_t* pi, const char* func_name (strcasecmp(func_name, "cast_as_unsigned") == 0) || (strcasecmp(func_name, "get_user_var") == 0) || (strcasecmp(func_name, "get_system_var") == 0) || + (strcasecmp(func_name, "not") == 0) || (strcasecmp(func_name, "set_user_var") == 0) || (strcasecmp(func_name, "set_system_var") == 0)) { @@ -2678,7 +2718,25 @@ static void update_field_infos(parsing_info_t* pi, break; case Item_subselect::EXISTS_SUBS: - // TODO: Handle these explicitly as well. + { + Item_exists_subselect* exists_subselect_item = + static_cast(item); + + st_select_lex* ssl = exists_subselect_item->get_select_lex(); + if (ssl) + { + uint32_t sub_usage = usage; + + sub_usage &= ~QC_USED_IN_SELECT; + sub_usage |= QC_USED_IN_SUBSELECT; + + update_field_infos(pi, + get_lex(pi), + ssl, + sub_usage, + excludep); + } + } break; case Item_subselect::SINGLEROW_SUBS: @@ -2706,6 +2764,22 @@ static void update_field_infos(parsing_info_t* pi, } } +#ifdef CTE_SUPPORTED +static void update_field_infos(parsing_info_t* pi, + LEX* lex, + st_select_lex_unit* select, + uint32_t usage, + List* excludep) +{ + st_select_lex* s = select->first_select(); + + if (s) + { + update_field_infos(pi, lex, s, usage, excludep); + } +} +#endif + static void update_field_infos(parsing_info_t* pi, LEX* lex, st_select_lex* select, @@ -2821,6 +2895,33 @@ int32_t qc_mysql_get_field_info(GWBUF* buf, const QC_FIELD_INFO** infos, uint32_ update_field_infos(pi, lex, &lex->select_lex, usage, NULL); +#ifdef CTE_SUPPORTED + if (lex->with_clauses_list) + { + With_clause* with_clause = lex->with_clauses_list; + + while (with_clause) + { + SQL_I_List& with_list = with_clause->with_list; + With_element* element = with_list.first; + + while (element) + { + update_field_infos(pi, lex, element->spec, usage, NULL); + + if (element->first_recursive) + { + update_field_infos(pi, lex, element->first_recursive, usage, NULL); + } + + element = element->next; + } + + with_clause = with_clause->next_with_clause; + } + } +#endif + List_iterator ilist(lex->value_list); while (Item* item = ilist++) { diff --git a/query_classifier/qc_sqlite/qc_sqlite.c b/query_classifier/qc_sqlite/qc_sqlite.c index 4bac109f1..8eecb6b79 100644 --- a/query_classifier/qc_sqlite/qc_sqlite.c +++ b/query_classifier/qc_sqlite/qc_sqlite.c @@ -220,6 +220,8 @@ static void update_field_infos_from_select(QC_SQLITE_INFO* info, const Select* pSelect, uint32_t usage, const ExprList* pExclude); +static void update_field_infos_from_with(QC_SQLITE_INFO* info, + const With* pWith); static void update_function_info(QC_SQLITE_INFO* info, const char* name, uint32_t usage); @@ -1070,6 +1072,20 @@ static void update_field_infos_from_expr(QC_SQLITE_INFO* info, } } +static void update_field_infos_from_with(QC_SQLITE_INFO* info, + const With* pWith) +{ + for (int i = 0; i < pWith->nCte; ++i) + { + const struct Cte* pCte = &pWith->a[i]; + + if (pCte->pSelect) + { + update_field_infos_from_select(info, pCte->pSelect, QC_USED_IN_SELECT, NULL); + } + } +} + static const char* get_token_symbol(int token) { switch (token) @@ -1466,6 +1482,11 @@ static void update_field_infos_from_select(QC_SQLITE_INFO* info, update_field_infos(info, 0, pSelect->pHaving, 0, QC_TOKEN_MIDDLE, pSelect->pEList); #endif } + + if (pSelect->pWith) + { + update_field_infos_from_with(info, pSelect->pWith); + } } static void update_database_names(QC_SQLITE_INFO* info, const char* zDatabase) diff --git a/query_classifier/test/CMakeLists.txt b/query_classifier/test/CMakeLists.txt index e9ad22ec0..eda63761f 100644 --- a/query_classifier/test/CMakeLists.txt +++ b/query_classifier/test/CMakeLists.txt @@ -47,6 +47,9 @@ if (BUILD_QC_MYSQLEMBEDDED) add_test(TestQC_version_sensitivity version_sensitivity) + if(NOT (MYSQL_EMBEDDED_VERSION VERSION_LESS 10.2)) + add_test(TestQC_cte_simple compare -v 2 ${CMAKE_CURRENT_SOURCE_DIR}/cte_simple.test) + endif() if(NOT (MYSQL_EMBEDDED_VERSION VERSION_LESS 10.3)) add_test(TestQC_Oracle-binlog_stm_ps compare -v 2 ${CMAKE_CURRENT_SOURCE_DIR}/oracle/binlog_stm_ps.test) add_test(TestQC_Oracle-binlog_stm_sp compare -v 2 ${CMAKE_CURRENT_SOURCE_DIR}/oracle/binlog_stm_sp.test) diff --git a/query_classifier/test/cte_simple.test b/query_classifier/test/cte_simple.test new file mode 100644 index 000000000..68c0aa8e9 --- /dev/null +++ b/query_classifier/test/cte_simple.test @@ -0,0 +1,15 @@ +WITH t AS (SELECT a FROM t1 WHERE b >= 'c') SELECT * FROM t2,t WHERE t2.c=t.a; + +SELECT t1.a,t1.b FROM t1,t2 + WHERE t1.a>t2.c AND + t2.c in (WITH t as (SELECT * FROM t1 WHERE t1.a<5) + SELECT t2.c FROM t2,t WHERE t2.c=t.a); + +WITH engineers AS ( + SELECT * FROM employees WHERE dept IN ('Development','Support') +) +SELECT * FROM engineers E1 + WHERE NOT EXISTS (SELECT 1 + FROM engineers E2 + WHERE E2.country=E1.country + AND E2.name <> E1.name);