diff --git a/server/modules/filter/masking/maskingfilter.cc b/server/modules/filter/masking/maskingfilter.cc index 7cf088c16..ebdba6f1e 100644 --- a/server/modules/filter/masking/maskingfilter.cc +++ b/server/modules/filter/masking/maskingfilter.cc @@ -84,7 +84,7 @@ MaskingFilter* MaskingFilter::create(const char* zName, char** pzOptions, CONFIG MaskingFilterSession* MaskingFilter::newSession(SESSION* pSession) { - return MaskingFilterSession::create(pSession); + return MaskingFilterSession::create(pSession, this); } // static @@ -96,7 +96,12 @@ void MaskingFilter::diagnostics(DCB* pDcb) // static uint64_t MaskingFilter::getCapabilities() { - return 0; + return RCAP_TYPE_STMT_INPUT | RCAP_TYPE_STMT_OUTPUT; +} + +std::tr1::shared_ptr MaskingFilter::rules() const +{ + return m_sRules; } // static diff --git a/server/modules/filter/masking/maskingfilter.hh b/server/modules/filter/masking/maskingfilter.hh index 1a547bf93..029677eca 100644 --- a/server/modules/filter/masking/maskingfilter.hh +++ b/server/modules/filter/masking/maskingfilter.hh @@ -14,6 +14,7 @@ #include #include +#include #include #include "maskingfilterconfig.hh" #include "maskingfiltersession.hh" @@ -24,6 +25,7 @@ class MaskingRules; class MaskingFilter : public maxscale::Filter { public: + typedef std::tr1::shared_ptr SMaskingRules; typedef MaskingFilterConfig Config; ~MaskingFilter(); @@ -35,6 +37,8 @@ public: static uint64_t getCapabilities(); + SMaskingRules rules() const; + private: MaskingFilter(const Config& config, std::auto_ptr sRules); @@ -44,6 +48,6 @@ private: static bool process_params(char **pzOptions, CONFIG_PARAMETER *ppParams, Config& config); private: - Config m_config; - std::auto_ptr m_sRules; + Config m_config; + SMaskingRules m_sRules; }; diff --git a/server/modules/filter/masking/maskingfiltersession.cc b/server/modules/filter/masking/maskingfiltersession.cc index e172971d5..6c86fda2f 100644 --- a/server/modules/filter/masking/maskingfiltersession.cc +++ b/server/modules/filter/masking/maskingfiltersession.cc @@ -12,11 +12,23 @@ */ #include "maskingfiltersession.hh" +#include +#include #include +#include +#include +#include "maskingfilter.hh" +#include "mysql.hh" +using maxscale::Buffer; +using std::ostream; +using std::string; +using std::stringstream; -MaskingFilterSession::MaskingFilterSession(SESSION* pSession) +MaskingFilterSession::MaskingFilterSession(SESSION* pSession, const MaskingFilter* pFilter) : maxscale::FilterSession(pSession) + , m_filter(*pFilter) + , m_state(IGNORING_RESPONSE) { } @@ -25,7 +37,186 @@ MaskingFilterSession::~MaskingFilterSession() } //static -MaskingFilterSession* MaskingFilterSession::create(SESSION* pSession) +MaskingFilterSession* MaskingFilterSession::create(SESSION* pSession, const MaskingFilter* pFilter) { - return new MaskingFilterSession(pSession); + return new MaskingFilterSession(pSession, pFilter); +} + +int MaskingFilterSession::routeQuery(GWBUF* pPacket) +{ + ComRequest request(pPacket); + + switch (request.command()) + { + case MYSQL_COM_QUERY: + // TODO: Breaks if responses are not waited for, before the next request is sent. + m_res.reset(m_filter.rules()); + m_state = EXPECTING_RESPONSE; + break; + + default: + m_state = IGNORING_RESPONSE; + } + + return FilterSession::routeQuery(pPacket); +} + +int MaskingFilterSession::clientReply(GWBUF* pPacket) +{ + MXS_NOTICE("clientReply"); + ss_dassert(GWBUF_IS_CONTIGUOUS(pPacket)); + + switch (m_state) + { + case EXPECTING_NOTHING: + MXS_WARNING("Received data, although expected nothing."); + case IGNORING_RESPONSE: + break; + + case EXPECTING_RESPONSE: + handle_response(pPacket); + break; + + case EXPECTING_FIELD: + handle_field(pPacket); + break; + + case EXPECTING_ROW: + handle_row(pPacket); + break; + + case EXPECTING_FIELD_EOF: + case EXPECTING_ROW_EOF: + handle_eof(pPacket); + break; + } + + return FilterSession::clientReply(pPacket); +} + +void MaskingFilterSession::handle_response(GWBUF* pPacket) +{ + MXS_NOTICE("handle_response"); + ComResponse response(pPacket); + + switch (response.type()) + { + case 0x00: // OK + case 0xff: // ERR + case 0xfb: // GET_MORE_CLIENT_DATA/SEND_MORE_CLIENT_DATA + m_state = EXPECTING_NOTHING; + break; + + default: + { + ComQueryResponse query_response(response); + + m_res.set_total_fields(query_response.nFields()); + m_state = EXPECTING_FIELD; + } + } +} + +void MaskingFilterSession::handle_field(GWBUF* pPacket) +{ + MXS_NOTICE("handle_field"); + + ComQueryResponse::ColumnDef column_def(pPacket); + + const char *zUser = session_getUser(m_pSession); + const char *zHost = session_get_remote(m_pSession); + + if (!zUser) + { + zUser = ""; + } + + if (!zHost) + { + zHost = ""; + } + + const MaskingRules::Rule* pRule = m_res.rules()->get_rule_for(column_def, zUser, zHost); + + if (m_res.append_rule(pRule)) + { + // All fields have been read. + m_state = EXPECTING_FIELD_EOF; + } + + MXS_NOTICE("Stats: %s", column_def.to_string().c_str()); +} + +void MaskingFilterSession::handle_eof(GWBUF* pPacket) +{ + MXS_NOTICE("handle_eof"); + + ComResponse response(pPacket); + + if (response.is_eof()) + { + switch (m_state) + { + case EXPECTING_FIELD_EOF: + m_state = EXPECTING_ROW; + break; + + case EXPECTING_ROW_EOF: + m_state = EXPECTING_NOTHING; + break; + + default: + ss_dassert(!true); + m_state = IGNORING_RESPONSE; + } + } + else + { + MXS_ERROR("Expected EOF, got something else: %d", response.type()); + m_state = IGNORING_RESPONSE; + } +} + +void MaskingFilterSession::handle_row(GWBUF* pPacket) +{ + MXS_NOTICE("handle_row"); + + ComResponse response(pPacket); + + switch (response.type()) + { + case ComPacket::EOF_PACKET: + // EOF after last row. + MXS_NOTICE("EOF after last row received."); + m_state = EXPECTING_NOTHING; + break; + + case 0xfb: // NULL is sent as 0xfb + MXS_NOTICE("NULL"); + // We must ask for the rule so as not to get out of sync. + m_res.get_rule(); + break; + + default: + { + ComQueryResponse::Row row(response); + + ComQueryResponse::Row::iterator i = row.begin(); + while (i != row.end()) + { + const MaskingRules::Rule* pRule = m_res.get_rule(); + + if (pRule) + { + LEncString s = *i; + + pRule->rewrite(s); + + MXS_NOTICE("String: %s", (*i).to_string().c_str()); + } + ++i; + } + } + break; + } } diff --git a/server/modules/filter/masking/maskingfiltersession.hh b/server/modules/filter/masking/maskingfiltersession.hh index 11e8dfc2d..d413f366d 100644 --- a/server/modules/filter/masking/maskingfiltersession.hh +++ b/server/modules/filter/masking/maskingfiltersession.hh @@ -13,18 +13,101 @@ */ #include +#include +#include +#include #include +#include "maskingrules.hh" + +class MaskingFilter; class MaskingFilterSession : public maxscale::FilterSession { public: ~MaskingFilterSession(); - static MaskingFilterSession* create(SESSION* pSession); + static MaskingFilterSession* create(SESSION* pSession, const MaskingFilter* pFilter); + + int routeQuery(GWBUF* pPacket); + + int clientReply(GWBUF* pPacket); private: - MaskingFilterSession(SESSION* pSession); + MaskingFilterSession(SESSION* pSession, const MaskingFilter* pFilter); MaskingFilterSession(const MaskingFilterSession&); MaskingFilterSession& operator = (const MaskingFilterSession&); + + enum state_t + { + EXPECTING_NOTHING, + EXPECTING_RESPONSE, + EXPECTING_FIELD, + EXPECTING_FIELD_EOF, + EXPECTING_ROW, + EXPECTING_ROW_EOF, + IGNORING_RESPONSE + }; + + void handle_response(GWBUF* pPacket); + void handle_field(GWBUF* pPacket); + void handle_row(GWBUF* pPacket); + void handle_eof(GWBUF* pPacket); + +private: + typedef std::tr1::shared_ptr SMaskingRules; + + class ResponseState + { + public: + ResponseState() + : m_nTotal_fields(0) + , m_index(0) + {} + + void reset(const SMaskingRules& sRules) + { + m_sRules = sRules; + m_nTotal_fields = 0; + m_rules.clear(); + m_index = 0; + } + + const SMaskingRules& rules() const + { + return m_sRules; + } + + uint32_t total_fields() const { return m_nTotal_fields; } + + void set_total_fields(uint32_t n) { m_nTotal_fields = n; } + + bool append_rule(const MaskingRules::Rule* pRule) + { + m_rules.push_back(pRule); + + return m_rules.size() == m_nTotal_fields; + } + + const MaskingRules::Rule* get_rule() + { + ss_dassert(m_nTotal_fields == m_rules.size()); + ss_dassert(m_index < m_rules.size()); + const MaskingRules::Rule* pRule = m_rules[m_index++]; + // The rules will be used repeatedly for each row. Hence, once we hit + // the end, we need to continue from the start. + m_index = m_index % m_rules.size(); + return pRule; + } + + private: + SMaskingRules m_sRules; /* m_rules; /*