diff --git a/maxscale-system-test/masking_auto_firewall.cpp b/maxscale-system-test/masking_auto_firewall.cpp index 90d87e293..7f81f8577 100644 --- a/maxscale-system-test/masking_auto_firewall.cpp +++ b/maxscale-system-test/masking_auto_firewall.cpp @@ -106,6 +106,10 @@ void run(TestConnections& test) // This should NOT succeed as a function is used with a masked column // in a binary prepared statement. test_one_ps(test, "SELECT LENGTH(a), b FROM masking_auto_firewall", Expect::FAILURE); + + // This should NOT succeed as a masked column is used in a statement + // defining a variable. + test_one(test, "set @a = (SELECT a, b FROM masking_auto_firewall)", Expect::FAILURE); } } diff --git a/server/modules/filter/masking/maskingfiltersession.cc b/server/modules/filter/masking/maskingfiltersession.cc index fe1a033d1..53d99fc61 100644 --- a/server/modules/filter/masking/maskingfiltersession.cc +++ b/server/modules/filter/masking/maskingfiltersession.cc @@ -69,8 +69,6 @@ MaskingFilterSession* MaskingFilterSession::create(MXS_SESSION* pSession, const bool MaskingFilterSession::check_query(GWBUF* pPacket) { - bool rv = true; - const char* zUser = session_get_user(m_pSession); const char* zHost = session_get_remote(m_pSession); @@ -84,7 +82,9 @@ bool MaskingFilterSession::check_query(GWBUF* pPacket) zHost = ""; } - if (m_filter.config().prevent_function_usage()) + bool rv = true; + + if (rv && m_filter.config().prevent_function_usage()) { if (is_function_used(pPacket, zUser, zHost)) { @@ -92,6 +92,14 @@ bool MaskingFilterSession::check_query(GWBUF* pPacket) } } + if (rv && m_filter.config().check_user_variables()) + { + if (is_variable_defined(pPacket, zUser, zHost)) + { + rv = false; + } + } + return rv; } @@ -540,3 +548,44 @@ bool MaskingFilterSession::is_function_used(GWBUF* pPacket, const char* zUser, c return is_used; } + +bool MaskingFilterSession::is_variable_defined(GWBUF* pPacket, const char* zUser, const char* zHost) +{ + if (!qc_query_is_type(qc_get_type_mask(pPacket), QUERY_TYPE_USERVAR_WRITE)) + { + return false; + } + + bool is_defined = false; + + SMaskingRules sRules = m_filter.rules(); + + auto pred = [&sRules, zUser, zHost](const QC_FIELD_INFO& field_info) { + const MaskingRules::Rule* pRule = sRules->get_rule_for(field_info, zUser, zHost); + + return pRule ? true : false; + }; + + const QC_FIELD_INFO* pInfos; + size_t nInfos; + + qc_get_field_info(pPacket, &pInfos, &nInfos); + + const QC_FIELD_INFO* begin = pInfos; + const QC_FIELD_INFO* end = begin + nInfos; + + auto i = std::find_if(begin, end, pred); + + if (i != end) + { + std::stringstream ss; + ss << "The field " << i->column << " that should be masked for '" << zUser << "'@'" << zHost + << "' is used when defining a variable, access is denied."; + + set_response(create_error_response(ss.str().c_str())); + + is_defined = true; + } + + return is_defined; +} diff --git a/server/modules/filter/masking/maskingfiltersession.hh b/server/modules/filter/masking/maskingfiltersession.hh index d647b8fad..39da05596 100644 --- a/server/modules/filter/masking/maskingfiltersession.hh +++ b/server/modules/filter/masking/maskingfiltersession.hh @@ -66,6 +66,7 @@ private: void mask_values(ComPacket& response); bool is_function_used(GWBUF* pPacket, const char* zUser, const char* zHost); + bool is_variable_defined(GWBUF* pPacket, const char* zUser, const char* zHost); private: typedef std::shared_ptr SMaskingRules;