diff --git a/Documentation/Filters/Masking.md b/Documentation/Filters/Masking.md index b3030cf57..a16146a4c 100644 --- a/Documentation/Filters/Masking.md +++ b/Documentation/Filters/Masking.md @@ -71,6 +71,14 @@ Please see the configuration parameter [prevent_function_usage](#prevent_function_usage) for how to change the default behaviour. +From MaxScale 2.3.5 onwards, the masking filter will check the +definition of user variables and reject statements that define a user +variable using a statement that refers to columns that should be masked. + +Please see the configuration parameter +[check_user_variables](#check_user_variables) +for how to change the default behaviour. + ## Limitations The masking filter can _only_ be used for masking columns of the following @@ -170,6 +178,20 @@ prevent_function_usage=false ``` The default value is `true`. +#### `check_user_variables` + +This optional parameter specifies how the masking filter should +behave with respect to user variables. If true, then a statement like +``` +set @a = (select ssn from customer where id = 1); +``` +will be rejected if `ssn` is a column that should be masked. +``` +check_user_variables=false +``` + +The default value is `true`. + ## Rules The masking rules are expressed as a JSON object. diff --git a/maxscale-system-test/masking_auto_firewall.cpp b/maxscale-system-test/masking_auto_firewall.cpp index a6f91e0a8..7f81f8577 100644 --- a/maxscale-system-test/masking_auto_firewall.cpp +++ b/maxscale-system-test/masking_auto_firewall.cpp @@ -28,6 +28,50 @@ void init(TestConnections& test) test.try_query(pMysql, "INSERT INTO masking_auto_firewall VALUES ('hello', 'world')"); } +enum class Expect +{ + FAILURE, + SUCCESS +}; + +void test_one(TestConnections& test, const char* zQuery, Expect expect) +{ + MYSQL* pMysql = test.maxscales->conn_rwsplit[0]; + + const char* zExpect = (expect == Expect::SUCCESS ? "SHOULD" : "should NOT"); + + test.tprintf("Executing '%s', %s succeed.", zQuery, zExpect); + int rv = execute_query_silent(pMysql, zQuery); + + if (expect == Expect::SUCCESS) + { + test.add_result(rv, "Could NOT execute query '%s'.", zQuery); + } + else + { + test.add_result(rv == 0, "COULD execute query '%s'.", zQuery); + } +} + +void test_one_ps(TestConnections& test, const char* zQuery, Expect expect) +{ + MYSQL* pMysql = test.maxscales->conn_rwsplit[0]; + + MYSQL_STMT* pPs = mysql_stmt_init(pMysql); + int rv = mysql_stmt_prepare(pPs, zQuery, strlen(zQuery)); + + if (expect == Expect::SUCCESS) + { + test.add_result(rv, "Could NOT prepare statement."); + } + else + { + test.add_result(rv == 0, "COULD prepare statement."); + } + + mysql_stmt_close(pPs); +} + void run(TestConnections& test) { init(test); @@ -36,18 +80,38 @@ void run(TestConnections& test) int rv; - // This should go through, a is simply masked. - static const char* zMasked_query = "SELECT a, b FROM masking_auto_firewall"; - test.tprintf("Executing '%s', SHOULD succeed.", zMasked_query); - rv = execute_query(pMysql, "%s", zMasked_query); - test.add_result(rv, "Could NOT execute query '%s'.", zMasked_query); + // This SHOULD go through, a is simply masked. + test_one(test, "SELECT a, b FROM masking_auto_firewall", Expect::SUCCESS); // This should NOT go through as a function is used with a masked column. - static const char* zRejected_query = "SELECT LENGTH(a), b FROM masking_auto_firewall"; - test.tprintf("Executing '%s', should NOT succeed.", zRejected_query); - rv = execute_query_silent(pMysql, zRejected_query); - test.add_result(rv == 0, "COULD execute query '%s'.", zRejected_query); + test_one(test, "SELECT LENGTH(a), b FROM masking_auto_firewall", Expect::FAILURE); + + // This SHOULD go through as a function is NOT used with a masked column + // in a prepared statement. + test_one(test, "PREPARE ps1 FROM 'SELECT a, LENGTH(b) FROM masking_auto_firewall'", Expect::SUCCESS); + + // This should NOT go through as a function is used with a masked column + // in a prepared statement. + test_one(test, "PREPARE ps2 FROM 'SELECT LENGTH(a), b FROM masking_auto_firewall'", Expect::FAILURE); + + rv = execute_query_silent(pMysql, "set @a = 'SELECT LENGTH(a), b FROM masking_auto_firewall'"); + test.add_result(rv, "Could NOT set variable."); + // This should NOT go through as a prepared statement is prepared from a variable. + test_one(test, "PREPARE ps3 FROM @a", Expect::FAILURE); + + // This SHOULD succeed as a function is NOT used with a masked column + // in a binary prepared statement. + test_one_ps(test, "SELECT a, LENGTH(b) FROM masking_auto_firewall", Expect::SUCCESS); + + // This should NOT succeed as a function is used with a masked column + // in a binary prepared statement. + test_one_ps(test, "SELECT LENGTH(a), b FROM masking_auto_firewall", Expect::FAILURE); + + // This should NOT succeed as a masked column is used in a statement + // defining a variable. + test_one(test, "set @a = (SELECT a, b FROM masking_auto_firewall)", Expect::FAILURE); } + } int main(int argc, char* argv[]) diff --git a/maxscale-system-test/mysqlmon_fail_switch_events.cpp b/maxscale-system-test/mysqlmon_fail_switch_events.cpp index 5e7e14029..4a0e6e2be 100644 --- a/maxscale-system-test/mysqlmon_fail_switch_events.cpp +++ b/maxscale-system-test/mysqlmon_fail_switch_events.cpp @@ -22,6 +22,12 @@ const char EVENT_SHCEDULER[] = "SET GLOBAL event_scheduler = %s;"; const char USE_TEST[] = "USE test;"; const char DELETE_EVENT[] = "DROP EVENT %s;"; +const char EV_STATE_ENABLED[] = "ENABLED"; +const char EV_STATE_DISABLED[] = "DISABLED"; +const char EV_STATE_SLAVE_DISABLED[] = "SLAVESIDE_DISABLED"; + +const char WRONG_MASTER_FMT[] = "%s is not master as expected. Current master id: %i."; + int read_incremented_field(TestConnections& test) { int rval = -1; @@ -137,12 +143,57 @@ bool check_event_status(TestConnections& test, int node, else { rval = true; - cout << "Event '" << event_name << "' is '" << status << "' as it should.\n"; + cout << "Event '" << event_name << "' is '" << status << "' on node " << node << + " as it should.\n"; } } return rval; } +void set_event_state(TestConnections& test, const string& event_name, const string& new_state) +{ + bool success = false; + test.maxscales->connect_maxscale(0); + MYSQL* conn = test.maxscales->conn_rwsplit[0]; + const char query_fmt[] = "ALTER EVENT %s %s;"; + + if ((test.try_query(conn, USE_TEST) == 0) + && (test.try_query(conn, query_fmt, event_name.c_str(), new_state.c_str()) == 0)) + { + success = true; + } + test.expect(success, "ALTER EVENT failed: %s", mysql_error(conn)); + if (success) + { + cout << "Event '" << event_name << "' set to '" << new_state << "'.\n"; + } +} + +void switchover(TestConnections& test, const string& new_master) +{ + string switch_cmd = "call command mysqlmon switchover MySQL-Monitor " + new_master; + test.maxscales->execute_maxadmin_command_print(0, switch_cmd.c_str()); + test.maxscales->wait_for_monitor(2); + // Check success. + auto new_master_status = test.get_server_status(new_master.c_str()); + auto new_master_id = test.get_master_server_id(); + string status_string; + for (auto elem : new_master_status) + { + status_string += elem + ", "; + } + + bool success = (new_master_status.count("Master") == 1); + test.expect(success, + "%s is not master as expected. Status: %s. Current master id: %i", + new_master.c_str(), status_string.c_str(), new_master_id); + + if (success) + { + cout << "Switchover success, " + new_master + " is new master.\n"; + } +} + int main(int argc, char** argv) { Mariadb_nodes::require_gtid(true); @@ -154,10 +205,17 @@ int main(int argc, char** argv) // Schedule a repeating event. create_event(test); + int server1_ind = 0; + int server2_ind = 1; + int server1_id = test.repl->get_server_id(server1_ind); + + const char* server_names[] = {"server1", "server2", "server3", "server4"}; + auto server1_name = server_names[server1_ind]; + auto server2_name = server_names[server2_ind]; + int master_id_begin = test.get_master_server_id(); - int node0_id = test.repl->get_server_id(0); - test.expect(master_id_begin == node0_id, - "First server is not the master: master id: %i", master_id_begin); + + test.expect(master_id_begin == server1_id, WRONG_MASTER_FMT, server1_name, master_id_begin); // If initialisation failed, fail the test immediately. if (test.global_result != 0) @@ -167,8 +225,8 @@ int main(int argc, char** argv) } // Part 1: Do a failover - cout << "Step 1: Stop master and wait for failover. Check that another server is promoted.\n"; - test.repl->stop_node(0); + cout << "\nStep 1: Stop master and wait for failover. Check that another server is promoted.\n"; + test.repl->stop_node(server1_ind); test.maxscales->wait_for_monitor(3); get_output(test); int master_id_failover = test.get_master_server_id(); @@ -187,21 +245,21 @@ int main(int argc, char** argv) } // Part 2: Start node 0, let it join the cluster and check that the event is properly disabled. - cout << "Step 2: Restart node 0. It should join the cluster.\n"; - test.repl->start_node(0); + cout << "\nStep 2: Restart " << server1_name << ". It should join the cluster.\n"; + test.repl->start_node(server1_ind); test.maxscales->wait_for_monitor(4); get_output(test); - const char server_name[] = "server1"; - auto states = test.get_server_status(server_name); + + auto states = test.get_server_status(server1_name); if (states.count("Slave") < 1) { test.expect(false, "%s is not a slave as expected. Status: %s", - server_name, string_set_to_string(states).c_str()); + server1_name, string_set_to_string(states).c_str()); } else { // Old master joined as slave, check that event is disabled. - check_event_status(test, 0, EVENT_NAME, "SLAVESIDE_DISABLED"); + check_event_status(test, server1_ind, EVENT_NAME, EV_STATE_SLAVE_DISABLED); } if (test.global_result != 0) @@ -212,29 +270,49 @@ int main(int argc, char** argv) // Part 3: Switchover back to server1 as master. The event will most likely not run because the old // master doesn't have event scheduler on anymore. - cout << "Step 3: Switchover back to server1. Check that event is enabled. Don't check that the " - "event is running since the scheduler process is likely off.\n"; - string switch_cmd = "call command mysqlmon switchover MySQL-Monitor server1"; - test.maxscales->execute_maxadmin_command_print(0, switch_cmd.c_str()); - test.maxscales->wait_for_monitor(1); - get_output(test); - // Check success. - int master_id_switchover = test.get_master_server_id(); - test.expect(master_id_switchover == node0_id, - "server1 is not master as expected. Current master: %i.", master_id_switchover); - check_event_status(test, 0, EVENT_NAME, "ENABLED"); - if (test.global_result != 0) + cout << "\nStep 3: Switchover back to " << server1_name << ". Check that event is enabled. " + "Don't check that the event is running since the scheduler process is likely off.\n"; + switchover(test, server1_name); + if (test.ok()) { - try_delete_event(test); - return test.global_result; + check_event_status(test, server1_ind, EVENT_NAME, EV_STATE_ENABLED); } - // Check that all other nodes are slaves. - for (int i = 1; i < test.repl->N; i++) + // Part 4: Disable the event on master. The event should still be "SLAVESIDE_DISABLED" on slaves. + // Check that after switchover, the event is not enabled. + cout << "\nStep 4: Disable event on master, switchover to " << server2_name << ". " + "Check that event is still disabled.\n"; + if (test.ok()) { - string server_name = "server" + std::to_string(i + 1); - auto states = test.maxscales->get_server_status(server_name.c_str()); - test.expect(states.count("Slave") == 1, "%s is not a slave.", server_name.c_str()); + set_event_state(test, EVENT_NAME, "DISABLE"); + test.maxscales->wait_for_monitor(); // Wait for the monitor to detect the change. + check_event_status(test, server1_ind, EVENT_NAME, EV_STATE_DISABLED); + check_event_status(test, server2_ind, EVENT_NAME, EV_STATE_SLAVE_DISABLED); + + if (test.ok()) + { + cout << "Event is disabled on master and slaveside-disabled on slave.\n"; + switchover(test, server2_name); + if (test.ok()) + { + // Event should not have been touched. + check_event_status(test, server2_ind, EVENT_NAME, EV_STATE_SLAVE_DISABLED); + } + + // Switchover back. + switchover(test, server1_name); + } + } + + if (test.ok()) + { + // Check that all other nodes are slaves. + for (int i = 1; i < test.repl->N; i++) + { + string server_name = server_names[i]; + auto states = test.maxscales->get_server_status(server_name.c_str()); + test.expect(states.count("Slave") == 1, "%s is not a slave.", server_name.c_str()); + } } try_delete_event(test); diff --git a/maxscale-system-test/testconnections.cpp b/maxscale-system-test/testconnections.cpp index bdb44e0c0..a35afd457 100644 --- a/maxscale-system-test/testconnections.cpp +++ b/maxscale-system-test/testconnections.cpp @@ -816,11 +816,16 @@ void TestConnections::copy_one_mariadb_log(int i, std::string filename) for (auto cmd : log_retrive_commands) { - std::ofstream outfile(filename + std::to_string(j++)); + auto output = repl->ssh_output(cmd, i).second; - if (outfile) + if (!output.empty()) { - outfile << repl->ssh_output(cmd, i).second; + std::ofstream outfile(filename + std::to_string(j++)); + + if (outfile) + { + outfile << output; + } } } } diff --git a/query_classifier/qc_sqlite/qc_sqlite.cc b/query_classifier/qc_sqlite/qc_sqlite.cc index d76496cd1..890cf989b 100644 --- a/query_classifier/qc_sqlite/qc_sqlite.cc +++ b/query_classifier/qc_sqlite/qc_sqlite.cc @@ -2653,6 +2653,12 @@ public: return rv; } + void maxscaleSetStatusCap(int cap) + { + mxb_assert(cap >= QC_QUERY_TOKENIZED && cap <= QC_QUERY_PARSED); + m_status_cap = static_cast(cap); + } + void maxscaleRenameTable(Parse* pParse, SrcList* pTables) { mxb_assert(this_thread.initialized); @@ -3065,6 +3071,7 @@ private: QcSqliteInfo(uint32_t cllct) : m_refs(1) , m_status(QC_QUERY_INVALID) + , m_status_cap(QC_QUERY_PARSED) , m_collect(cllct) , m_collected(0) , m_pQuery(NULL) @@ -3270,6 +3277,7 @@ public: // TODO: Make these private once everything's been updated. int32_t m_refs; // The reference count. qc_parse_result_t m_status; // The validity of the information in this structure. + qc_parse_result_t m_status_cap; // The cap on 'm_status', it won't be set to higher than this. uint32_t m_collect; // What information should be collected. uint32_t m_collected; // What information has been collected. const char* m_pQuery; // The query passed to sqlite. @@ -3371,6 +3379,7 @@ extern "C" extern int maxscaleComment(); extern int maxscaleKeyword(int token); + extern void maxscaleSetStatusCap(int cap); extern int maxscaleTranslateKeyword(int token); } @@ -3422,6 +3431,11 @@ static void parse_query_string(const char* query, int len, bool suppress_logging const char* suffix = (len > max_len ? "..." : ""); const char* format; + if (this_thread.pInfo->m_status > this_thread.pInfo->m_status_cap) + { + this_thread.pInfo->m_status = this_thread.pInfo->m_status_cap; + } + if (this_thread.pInfo->m_operation == QUERY_OP_EXPLAIN) { this_thread.pInfo->m_status = QC_QUERY_PARSED; @@ -4321,6 +4335,18 @@ void maxscaleLock(Parse* pParse, mxs_lock_t type, SrcList* pTables) QC_EXCEPTION_GUARD(pInfo->maxscaleLock(pParse, type, pTables)); } +void maxscaleSetStatusCap(int cap) +{ + QC_TRACE(); + + mxb_assert((cap >= QC_QUERY_INVALID) && (cap <= QC_QUERY_PARSED)); + + QcSqliteInfo* pInfo = this_thread.pInfo; + mxb_assert(pInfo); + + QC_EXCEPTION_GUARD(pInfo->maxscaleSetStatusCap(cap)); +} + int maxscaleTranslateKeyword(int token) { QC_TRACE(); diff --git a/query_classifier/qc_sqlite/sqlite-src-3110100/src/tokenize.c b/query_classifier/qc_sqlite/sqlite-src-3110100/src/tokenize.c index 3284a2e2e..11360dc56 100644 --- a/query_classifier/qc_sqlite/sqlite-src-3110100/src/tokenize.c +++ b/query_classifier/qc_sqlite/sqlite-src-3110100/src/tokenize.c @@ -259,10 +259,18 @@ int sqlite3GetToken(const unsigned char *z, int *tokenType){ // MySQL-specific code for (i=3, c=z[2]; (c!='*' || z[i]!='/') && (c=z[i])!=0; i++){} if (c=='*' && z[i]=='/'){ - char* znc = (char*) z; - znc[0]=znc[1]=znc[2]=znc[i-1]=znc[i]=' '; // Remove comment chars, i.e. "/*!" and "*/". - for (i=3; sqlite3Isdigit(z[i]); ++i){} // Jump over the MySQL version number. - for (; sqlite3Isspace(z[i]); ++i){} // Jump over any space. + if (sqlite3Isdigit(z[3])) { + // A version specific executable comment, e.g. "/*!99999 ..." => never parsed. + extern void maxscaleSetStatusCap(int); + maxscaleSetStatusCap(2); // QC_QUERY_PARTIALLY_PARSED, see query_classifier.h:qc_parse_result + ++i; // Next after the trailing '/' + } + else { + // A non-version specific executable comment, e.g. "/*! select 1 */ => always parsed. + char* znc = (char*) z; + znc[0]=znc[1]=znc[2]=znc[i-1]=znc[i]=' '; // Remove comment chars, i.e. "/*!" and "*/". + for (i=3; sqlite3Isspace(z[i]); ++i){} // Jump over any space. + } } } else { for(i=3, c=z[2]; (c!='*' || z[i]!='/') && (c=z[i])!=0; i++){} diff --git a/server/core/buffer.cc b/server/core/buffer.cc index ecd7bae8b..340908bc9 100644 --- a/server/core/buffer.cc +++ b/server/core/buffer.cc @@ -219,18 +219,19 @@ GWBUF* gwbuf_clone(GWBUF* buf) return rval; } -GWBUF* gwbuf_deep_clone(const GWBUF* buf) +static GWBUF* gwbuf_deep_clone_portion(const GWBUF* buf, size_t length) { mxb_assert(buf->owner == RoutingWorker::get_current_id()); GWBUF* rval = NULL; if (buf) { - size_t buflen = gwbuf_length(buf); - rval = gwbuf_alloc(buflen); + rval = gwbuf_alloc(length); - if (rval && gwbuf_copy_data(buf, 0, buflen, GWBUF_DATA(rval)) == buflen) + if (rval && gwbuf_copy_data(buf, 0, length, GWBUF_DATA(rval)) == length) { + // The copying of the type is done to retain the type characteristic of the buffer without + // having a link the orginal data or parsing info. rval->gwbuf_type = buf->gwbuf_type; } else @@ -243,7 +244,12 @@ GWBUF* gwbuf_deep_clone(const GWBUF* buf) return rval; } -static GWBUF* gwbuf_clone_portion(GWBUF* buf, +GWBUF* gwbuf_deep_clone(const GWBUF* buf) +{ + return gwbuf_deep_clone_portion(buf, gwbuf_length(buf)); +} + +static GWBUF *gwbuf_clone_portion(GWBUF* buf, size_t start_offset, size_t length) { @@ -310,7 +316,7 @@ GWBUF* gwbuf_split(GWBUF** buf, size_t length) if (length > 0) { mxb_assert(GWBUF_LENGTH(buffer) > length); - GWBUF* partial = gwbuf_clone_portion(buffer, 0, length); + GWBUF* partial = gwbuf_deep_clone_portion(buffer, length); /** If the head points to the original head of the buffer chain * and we are splitting a contiguous buffer, we only need to return diff --git a/server/core/config.cc b/server/core/config.cc index e21f6ffd5..b87536c24 100644 --- a/server/core/config.cc +++ b/server/core/config.cc @@ -450,6 +450,7 @@ const char* config_pre_parse_global_params[] = CN_MAXLOG, CN_LOG_AUGMENTATION, CN_LOG_TO_SHM, + CN_SUBSTITUTE_VARIABLES, NULL }; diff --git a/server/core/queryclassifier.cc b/server/core/queryclassifier.cc index 0c462c434..1decf1f12 100644 --- a/server/core/queryclassifier.cc +++ b/server/core/queryclassifier.cc @@ -148,8 +148,6 @@ uint32_t get_prepare_type(GWBUF* buffer) } } - mxb_assert((type & (QUERY_TYPE_PREPARE_STMT | QUERY_TYPE_PREPARE_NAMED_STMT)) == 0); - return type; } diff --git a/server/modules/filter/binlogfilter/binlogfiltersession.cc b/server/modules/filter/binlogfilter/binlogfiltersession.cc index 33c7fd650..75d369306 100644 --- a/server/modules/filter/binlogfilter/binlogfiltersession.cc +++ b/server/modules/filter/binlogfilter/binlogfiltersession.cc @@ -229,7 +229,7 @@ int BinlogFilterSession::clientReply(GWBUF* pPacket) // they are replaced by a RAND_EVENT event packet if (m_skip) { - replaceEvent(&pPacket); + replaceEvent(&pPacket, hdr); } break; @@ -323,7 +323,7 @@ bool BinlogFilterSession::checkEvent(GWBUF* buffer, * Some events skipped. * Set next pos to 0 instead of real one and new CRC32 */ - fixEvent(event + MYSQL_HEADER_LEN + 1, hdr.event_size); + fixEvent(event + MYSQL_HEADER_LEN + 1, hdr.event_size, hdr); } break; @@ -444,12 +444,12 @@ static void event_set_crc32(uint8_t* event, uint32_t event_size) * @param event Pointer to event data * @event_size The event size */ -void BinlogFilterSession::fixEvent(uint8_t* event, uint32_t event_size) +void BinlogFilterSession::fixEvent(uint8_t* event, uint32_t event_size, const REP_HEADER& hdr) { // Set next pos to 0. // The next_pos offset is the 13th byte in replication event header 19 bytes // + 4 (time) + 1 (type) + 4 (server_id) + 4 (event_size) - gw_mysql_set_byte4(event + 4 + 1 + 4 + 4, 0); + gw_mysql_set_byte4(event + 4 + 1 + 4 + 4, hdr.next_pos); // Set CRC32 in the new event if (m_crc) @@ -466,7 +466,7 @@ void BinlogFilterSession::fixEvent(uint8_t* event, uint32_t event_size) * * @param pPacket The GWBUF with event data */ -void BinlogFilterSession::replaceEvent(GWBUF** ppPacket) +void BinlogFilterSession::replaceEvent(GWBUF** ppPacket, const REP_HEADER& hdr) { uint32_t buf_len = gwbuf_length(*ppPacket); @@ -596,7 +596,7 @@ void BinlogFilterSession::replaceEvent(GWBUF** ppPacket) } // Fix Event Next pos = 0 and set new CRC32 - fixEvent(ptr + MYSQL_HEADER_LEN + 1, new_event_size); + fixEvent(ptr + MYSQL_HEADER_LEN + 1, new_event_size, hdr); } /** diff --git a/server/modules/filter/binlogfilter/binlogfiltersession.hh b/server/modules/filter/binlogfilter/binlogfiltersession.hh index 8b34c1ef2..af276003e 100644 --- a/server/modules/filter/binlogfilter/binlogfiltersession.hh +++ b/server/modules/filter/binlogfilter/binlogfiltersession.hh @@ -85,13 +85,13 @@ private: void filterError(GWBUF* pPacket); // Fix event: set next pos to 0 and set new CRC32 - void fixEvent(uint8_t* data, uint32_t event_size); + void fixEvent(uint8_t* data, uint32_t event_size, const REP_HEADER& hdr); // Whether to skip current event bool checkEvent(GWBUF* data, const REP_HEADER& hdr); // Filter the replication event - void replaceEvent(GWBUF** data); + void replaceEvent(GWBUF** data, const REP_HEADER& hdr); // Handle event size void handlePackets(uint32_t len, const REP_HEADER& hdr); diff --git a/server/modules/filter/dbfwfilter/dbfwfilter.cc b/server/modules/filter/dbfwfilter/dbfwfilter.cc index 99eb7e5b0..ab50c3bef 100644 --- a/server/modules/filter/dbfwfilter/dbfwfilter.cc +++ b/server/modules/filter/dbfwfilter/dbfwfilter.cc @@ -1479,71 +1479,81 @@ int DbfwSession::routeQuery(GWBUF* buffer) if (qc_query_is_type(type, QUERY_TYPE_PREPARE_NAMED_STMT)) { analyzed_queue = qc_get_preparable_stmt(buffer); - mxb_assert(analyzed_queue); + + // 'analyzed_queue' will be NULL if the statement is prepared from + // a variable like in : "prepare ps from @a". } - SUser suser = find_user_data(this_thread->users(m_instance), user(), remote()); bool query_ok = false; - if (command_is_mandatory(buffer)) + if (!analyzed_queue) { - query_ok = true; + set_error("Firewall rejects statements prepared from a variable."); } - else if (suser) + else { - char* rname = NULL; - bool match = suser->match(m_instance, this, analyzed_queue, &rname); + SUser suser = find_user_data(this_thread->users(m_instance), user(), remote()); - switch (m_instance->get_action()) + if (command_is_mandatory(buffer)) { - case FW_ACTION_ALLOW: - query_ok = match; - break; - - case FW_ACTION_BLOCK: - query_ok = !match; - break; - - case FW_ACTION_IGNORE: query_ok = true; - break; - - default: - MXS_ERROR("Unknown dbfwfilter action: %d", m_instance->get_action()); - mxb_assert(false); - break; } - - if (m_instance->get_log_bitmask() != FW_LOG_NONE) + else if (suser) { - if (match && m_instance->get_log_bitmask() & FW_LOG_MATCH) - { - MXS_NOTICE("[%s] Rule '%s' for '%s' matched by %s@%s: %s", - m_session->service->name(), - rname, - suser->name(), - user().c_str(), - remote().c_str(), - get_sql(buffer).c_str()); - } - else if (!match && m_instance->get_log_bitmask() & FW_LOG_NO_MATCH) - { - MXS_NOTICE("[%s] Query for '%s' by %s@%s was not matched: %s", - m_session->service->name(), - suser->name(), - user().c_str(), - remote().c_str(), - get_sql(buffer).c_str()); - } - } + char* rname = NULL; + bool match = suser->match(m_instance, this, analyzed_queue, &rname); - MXS_FREE(rname); - } - /** If the instance is in whitelist mode, only users that have a rule - * defined for them are allowed */ - else if (m_instance->get_action() != FW_ACTION_ALLOW) - { - query_ok = true; + switch (m_instance->get_action()) + { + case FW_ACTION_ALLOW: + query_ok = match; + break; + + case FW_ACTION_BLOCK: + query_ok = !match; + break; + + case FW_ACTION_IGNORE: + query_ok = true; + break; + + default: + MXS_ERROR("Unknown dbfwfilter action: %d", m_instance->get_action()); + mxb_assert(false); + break; + } + + if (m_instance->get_log_bitmask() != FW_LOG_NONE) + { + if (match && m_instance->get_log_bitmask() & FW_LOG_MATCH) + { + MXS_NOTICE("[%s] Rule '%s' for '%s' matched by %s@%s: %s", + m_session->service->name, + rname, + suser->name(), + user().c_str(), + remote().c_str(), + get_sql(buffer).c_str()); + } + else if (!match && m_instance->get_log_bitmask() & FW_LOG_NO_MATCH) + { + MXS_NOTICE("[%s] Query for '%s' by %s@%s was not matched: %s", + m_session->service->name, + suser->name(), + user().c_str(), + remote().c_str(), + get_sql(buffer).c_str()); + } + } + + MXS_FREE(rname); + } + /** If the instance is in whitelist mode, only users that have a rule + * defined for them are allowed */ + else if (m_instance->get_action() != FW_ACTION_ALLOW) + { + query_ok = true; + } } if (query_ok) diff --git a/server/modules/filter/masking/maskingfilter.cc b/server/modules/filter/masking/maskingfilter.cc index 0db2e1ad8..55eee86df 100644 --- a/server/modules/filter/masking/maskingfilter.cc +++ b/server/modules/filter/masking/maskingfilter.cc @@ -90,13 +90,10 @@ extern "C" MXS_MODULE* MXS_CREATE_MODULE() "V1.0.0", RCAP_TYPE_CONTIGUOUS_INPUT | RCAP_TYPE_CONTIGUOUS_OUTPUT, &MaskingFilter::s_object, - NULL, /* Process init. - * */ - NULL, /* Process finish. - * */ - NULL, /* Thread init. */ - NULL, /* Thread finish. - * */ + NULL, /* Process init. */ + NULL, /* Process finish. */ + NULL, /* Thread init. */ + NULL, /* Thread finish. */ { { Config::rules_name, @@ -124,6 +121,12 @@ extern "C" MXS_MODULE* MXS_CREATE_MODULE() Config::prevent_function_usage_default, MXS_MODULE_OPT_NONE, }, + { + Config::check_user_variables_name, + MXS_MODULE_PARAM_BOOL, + Config::check_user_variables_default, + MXS_MODULE_OPT_NONE, + }, {MXS_END_MODULE_PARAMS} } }; diff --git a/server/modules/filter/masking/maskingfilterconfig.cc b/server/modules/filter/masking/maskingfilterconfig.cc index f004103aa..1544243ad 100644 --- a/server/modules/filter/masking/maskingfilterconfig.cc +++ b/server/modules/filter/masking/maskingfilterconfig.cc @@ -27,6 +27,7 @@ const char config_value_never[] = "never"; const char config_value_always[] = "always"; const char config_name_prevent_function_usage[] = "prevent_function_usage"; +const char config_check_user_variables[] = "check_user_variables"; const char config_value_true[] = "true"; } @@ -84,6 +85,14 @@ const char* MaskingFilterConfig::prevent_function_usage_name = config_name_preve // static const char* MaskingFilterConfig::prevent_function_usage_default = config_value_true; +/* + * PARAM check_user_variables + */ +const char* MaskingFilterConfig::check_user_variables_name = config_check_user_variables; + +// static +const char* MaskingFilterConfig::check_user_variables_default = config_value_true; + /* * MaskingFilterConfig */ @@ -115,3 +124,9 @@ bool MaskingFilterConfig::get_prevent_function_usage(const MXS_CONFIG_PARAMETER* { return pParams->get_bool(prevent_function_usage_name); } + +// static +bool MaskingFilterConfig::get_check_user_variables(const MXS_CONFIG_PARAMETER* pParams) +{ + return config_get_bool(pParams, check_user_variables_name); +} diff --git a/server/modules/filter/masking/maskingfilterconfig.hh b/server/modules/filter/masking/maskingfilterconfig.hh index 902bfc76d..faad0d470 100644 --- a/server/modules/filter/masking/maskingfilterconfig.hh +++ b/server/modules/filter/masking/maskingfilterconfig.hh @@ -45,12 +45,16 @@ public: static const char* prevent_function_usage_name; static const char* prevent_function_usage_default; + static const char* check_user_variables_name; + static const char* check_user_variables_default; + MaskingFilterConfig(const char* zName, const MXS_CONFIG_PARAMETER* pParams) : m_name(zName) , m_large_payload(get_large_payload(pParams)) , m_rules(get_rules(pParams)) , m_warn_type_mismatch(get_warn_type_mismatch(pParams)) , m_prevent_function_usage(get_prevent_function_usage(pParams)) + , m_check_user_variables(get_check_user_variables(pParams)) { } ~MaskingFilterConfig() @@ -82,6 +86,11 @@ public: return m_prevent_function_usage; } + bool check_user_variables() const + { + return m_check_user_variables; + } + void set_large_payload(large_payload_t l) { m_large_payload = l; @@ -101,10 +110,21 @@ public: m_prevent_function_usage = b; } + void set_check_user_variables(bool b) + { + m_check_user_variables = b; + } + + bool is_parsing_needed() const + { + return prevent_function_usage() || check_user_variables(); + } + static large_payload_t get_large_payload(const MXS_CONFIG_PARAMETER* pParams); static std::string get_rules(const MXS_CONFIG_PARAMETER* pParams); static warn_type_mismatch_t get_warn_type_mismatch(const MXS_CONFIG_PARAMETER* pParams); static bool get_prevent_function_usage(const MXS_CONFIG_PARAMETER* pParams); + static bool get_check_user_variables(const MXS_CONFIG_PARAMETER* pParams); private: std::string m_name; @@ -112,4 +132,5 @@ private: std::string m_rules; warn_type_mismatch_t m_warn_type_mismatch; bool m_prevent_function_usage; + bool m_check_user_variables; }; diff --git a/server/modules/filter/masking/maskingfiltersession.cc b/server/modules/filter/masking/maskingfiltersession.cc index 940d91f1e..ec30f2f2b 100644 --- a/server/modules/filter/masking/maskingfiltersession.cc +++ b/server/modules/filter/masking/maskingfiltersession.cc @@ -31,6 +31,25 @@ using std::ostream; using std::string; using std::stringstream; +namespace +{ + +GWBUF* create_error_response(const char* zMessage) +{ + return modutil_create_mysql_err_msg(1, 0, 1141, "HY000", zMessage); +} + +GWBUF* create_parse_error_response() +{ + const char* zMessage = + "The statement could not be fully parsed and will hence be " + "rejected (masking filter)."; + + return create_error_response(zMessage); +} + +} + MaskingFilterSession::MaskingFilterSession(MXS_SESSION* pSession, const MaskingFilter* pFilter) : maxscale::FilterSession(pSession) , m_filter(*pFilter) @@ -48,6 +67,96 @@ MaskingFilterSession* MaskingFilterSession::create(MXS_SESSION* pSession, const return new MaskingFilterSession(pSession, pFilter); } +bool MaskingFilterSession::check_query(GWBUF* pPacket) +{ + const char* zUser = session_get_user(m_pSession); + const char* zHost = session_get_remote(m_pSession); + + if (!zUser) + { + zUser = ""; + } + + if (!zHost) + { + zHost = ""; + } + + bool rv = true; + + if (rv && m_filter.config().prevent_function_usage()) + { + if (is_function_used(pPacket, zUser, zHost)) + { + rv = false; + } + } + + if (rv && m_filter.config().check_user_variables()) + { + if (is_variable_defined(pPacket, zUser, zHost)) + { + rv = false; + } + } + + return rv; +} + +bool MaskingFilterSession::check_textual_query(GWBUF* pPacket) +{ + bool rv = false; + + if (qc_parse(pPacket, QC_COLLECT_FIELDS | QC_COLLECT_FUNCTIONS) == QC_QUERY_PARSED) + { + if (qc_query_is_type(qc_get_type_mask(pPacket), QUERY_TYPE_PREPARE_NAMED_STMT)) + { + GWBUF* pP = qc_get_preparable_stmt(pPacket); + + if (pP) + { + rv = check_textual_query(pP); + } + else + { + // If pP is NULL, it indicates that we have a "prepare ps from @a". It must + // be rejected as we currently have no means for checking what columns are + // referred to. + const char* zMessage = + "A statement prepared from a variable is rejected (masking filter)."; + + set_response(create_error_response(zMessage)); + } + } + else + { + rv = check_query(pPacket); + } + } + else + { + set_response(create_parse_error_response()); + } + + return rv; +} + +bool MaskingFilterSession::check_binary_query(GWBUF* pPacket) +{ + bool rv = false; + + if (qc_parse(pPacket, QC_COLLECT_FIELDS | QC_COLLECT_FUNCTIONS) == QC_QUERY_PARSED) + { + rv = check_query(pPacket); + } + else + { + set_response(create_parse_error_response()); + } + + return rv; +} + int MaskingFilterSession::routeQuery(GWBUF* pPacket) { ComRequest request(pPacket); @@ -58,9 +167,16 @@ int MaskingFilterSession::routeQuery(GWBUF* pPacket) case MXS_COM_QUERY: m_res.reset(request.command(), m_filter.rules()); - if (m_filter.config().prevent_function_usage() && reject_if_function_used(pPacket)) + if (m_filter.config().is_parsing_needed()) { - m_state = EXPECTING_NOTHING; + if (check_textual_query(pPacket)) + { + m_state = EXPECTING_RESPONSE; + } + else + { + m_state = EXPECTING_NOTHING; + } } else { @@ -68,6 +184,24 @@ int MaskingFilterSession::routeQuery(GWBUF* pPacket) } break; + case MXS_COM_STMT_PREPARE: + if (m_filter.config().is_parsing_needed()) + { + if (check_binary_query(pPacket)) + { + m_state = IGNORING_RESPONSE; + } + else + { + m_state = EXPECTING_NOTHING; + } + } + else + { + m_state = IGNORING_RESPONSE; + } + break; + case MXS_COM_STMT_EXECUTE: m_res.reset(request.command(), m_filter.rules()); m_state = EXPECTING_RESPONSE; @@ -370,39 +504,26 @@ void MaskingFilterSession::mask_values(ComPacket& response) } } -bool MaskingFilterSession::reject_if_function_used(GWBUF* pPacket) +bool MaskingFilterSession::is_function_used(GWBUF* pPacket, const char* zUser, const char* zHost) { - bool rejected = false; + bool is_used = false; SMaskingRules sRules = m_filter.rules(); - const char* zUser = session_get_user(m_pSession); - const char* zHost = session_get_remote(m_pSession); - - if (!zUser) - { - zUser = ""; - } - - if (!zHost) - { - zHost = ""; - } - auto pred1 = [&sRules, zUser, zHost](const QC_FIELD_INFO& field_info) { - const MaskingRules::Rule* pRule = sRules->get_rule_for(field_info, zUser, zHost); + const MaskingRules::Rule* pRule = sRules->get_rule_for(field_info, zUser, zHost); - return pRule ? true : false; - }; + return pRule ? true : false; + }; auto pred2 = [&sRules, zUser, zHost, &pred1](const QC_FUNCTION_INFO& function_info) { - const QC_FIELD_INFO* begin = function_info.fields; - const QC_FIELD_INFO* end = begin + function_info.n_fields; + const QC_FIELD_INFO* begin = function_info.fields; + const QC_FIELD_INFO* end = begin + function_info.n_fields; - auto i = std::find_if(begin, end, pred1); + auto i = std::find_if(begin, end, pred1); - return i != end; - }; + return i != end; + }; const QC_FUNCTION_INFO* pInfos; size_t nInfos; @@ -420,11 +541,51 @@ bool MaskingFilterSession::reject_if_function_used(GWBUF* pPacket) ss << "The function " << i->name << " is used in conjunction with a field " << "that should be masked for '" << zUser << "'@'" << zHost << "', access is denied."; - GWBUF* pResponse = modutil_create_mysql_err_msg(1, 0, 1141, "HY000", ss.str().c_str()); - set_response(pResponse); + set_response(create_error_response(ss.str().c_str())); - rejected = true; + is_used = true; } - return rejected; + return is_used; +} + +bool MaskingFilterSession::is_variable_defined(GWBUF* pPacket, const char* zUser, const char* zHost) +{ + if (!qc_query_is_type(qc_get_type_mask(pPacket), QUERY_TYPE_USERVAR_WRITE)) + { + return false; + } + + bool is_defined = false; + + SMaskingRules sRules = m_filter.rules(); + + auto pred = [&sRules, zUser, zHost](const QC_FIELD_INFO& field_info) { + const MaskingRules::Rule* pRule = sRules->get_rule_for(field_info, zUser, zHost); + + return pRule ? true : false; + }; + + const QC_FIELD_INFO* pInfos; + size_t nInfos; + + qc_get_field_info(pPacket, &pInfos, &nInfos); + + const QC_FIELD_INFO* begin = pInfos; + const QC_FIELD_INFO* end = begin + nInfos; + + auto i = std::find_if(begin, end, pred); + + if (i != end) + { + std::stringstream ss; + ss << "The field " << i->column << " that should be masked for '" << zUser << "'@'" << zHost + << "' is used when defining a variable, access is denied."; + + set_response(create_error_response(ss.str().c_str())); + + is_defined = true; + } + + return is_defined; } diff --git a/server/modules/filter/masking/maskingfiltersession.hh b/server/modules/filter/masking/maskingfiltersession.hh index 302f60a2b..39da05596 100644 --- a/server/modules/filter/masking/maskingfiltersession.hh +++ b/server/modules/filter/masking/maskingfiltersession.hh @@ -53,6 +53,10 @@ private: SUPPRESSING_RESPONSE }; + bool check_query(GWBUF* pPacket); + bool check_textual_query(GWBUF* pPacket); + bool check_binary_query(GWBUF* pPacket); + void handle_response(GWBUF* pPacket); void handle_field(GWBUF* pPacket); void handle_row(GWBUF* pPacket); @@ -61,7 +65,8 @@ private: void mask_values(ComPacket& response); - bool reject_if_function_used(GWBUF* pPacket); + bool is_function_used(GWBUF* pPacket, const char* zUser, const char* zHost); + bool is_variable_defined(GWBUF* pPacket, const char* zUser, const char* zHost); private: typedef std::shared_ptr SMaskingRules; diff --git a/server/modules/protocol/MySQL/rwbackend.cc b/server/modules/protocol/MySQL/rwbackend.cc index 2364551c0..995ba55e3 100644 --- a/server/modules/protocol/MySQL/rwbackend.cc +++ b/server/modules/protocol/MySQL/rwbackend.cc @@ -191,6 +191,14 @@ void RWBackend::process_reply(GWBUF* buffer) // TODO: Don't clone the buffer GWBUF* tmp = gwbuf_clone(buffer); tmp = gwbuf_consume(tmp, mxs_mysql_get_packet_len(tmp)); + + // Consume repeating OK packets + while (mxs_mysql_more_results_after_ok(buffer) && have_next_packet(tmp)) + { + tmp = gwbuf_consume(tmp, mxs_mysql_get_packet_len(tmp)); + mxb_assert(tmp); + } + process_reply(tmp); gwbuf_free(tmp); return;