diff --git a/server/modules/filter/masking/maskingfiltersession.cc b/server/modules/filter/masking/maskingfiltersession.cc index 7b9e57c69..ea5e3fc1c 100644 --- a/server/modules/filter/masking/maskingfiltersession.cc +++ b/server/modules/filter/masking/maskingfiltersession.cc @@ -46,11 +46,12 @@ int MaskingFilterSession::routeQuery(GWBUF* pPacket) { ComRequest request(pPacket); + // TODO: Breaks if responses are not waited for, before the next request is sent. switch (request.command()) { case MYSQL_COM_QUERY: - // TODO: Breaks if responses are not waited for, before the next request is sent. - m_res.reset(m_filter.rules()); + case MYSQL_COM_STMT_EXECUTE: + m_res.reset(request.command(), m_filter.rules()); m_state = EXPECTING_RESPONSE; break; @@ -138,7 +139,7 @@ void MaskingFilterSession::handle_field(GWBUF* pPacket) const MaskingRules::Rule* pRule = m_res.rules()->get_rule_for(column_def, zUser, zHost); - if (m_res.append_rule(pRule)) + if (m_res.append_type_and_rule(column_def.type(), pRule)) { // All fields have been read. m_state = EXPECTING_FIELD_EOF; @@ -192,28 +193,60 @@ void MaskingFilterSession::handle_row(GWBUF* pPacket) break; default: + switch (m_res.command()) { - ComQueryResponse::Row row(response); - - ComQueryResponse::Row::iterator i = row.begin(); - while (i != row.end()) + case MYSQL_COM_QUERY: { - const MaskingRules::Rule* pRule = m_res.get_rule(); + ComQueryResponse::Row row(response); - if (pRule) + ComQueryResponse::Row::iterator i = row.begin(); + while (i != row.end()) { - LEncString s = *i; + const MaskingRules::Rule* pRule = m_res.get_rule(); - if (!s.is_null()) + if (pRule) { - pRule->rewrite(s); - } + LEncString s = *i; - MXS_NOTICE("String: %s", (*i).to_string().c_str()); + if (!s.is_null()) + { + pRule->rewrite(s); + } + + MXS_NOTICE("String: %s", (*i).to_string().c_str()); + } + ++i; } - ++i; } + break; + + case MYSQL_COM_STMT_EXECUTE: + { + ComQueryResponse::BinaryRow row(response, m_res.types()); + + ComQueryResponse::BinaryRow::iterator i = row.begin(); + while (i != row.end()) + { + const MaskingRules::Rule* pRule = m_res.get_rule(); + + if (pRule) + { + ComQueryResponse::BinaryRow::Value value = *i; + + if (value.is_string()) + { + LEncString s = value.as_string(); + pRule->rewrite(s); + } + } + ++i; + } + } + break; + + default: + MXS_ERROR("Unexpected request: %d", m_res.command()); + ss_dassert(!true); } - break; } } diff --git a/server/modules/filter/masking/maskingfiltersession.hh b/server/modules/filter/masking/maskingfiltersession.hh index d413f366d..fbd94bebb 100644 --- a/server/modules/filter/masking/maskingfiltersession.hh +++ b/server/modules/filter/masking/maskingfiltersession.hh @@ -61,18 +61,26 @@ private: { public: ResponseState() - : m_nTotal_fields(0) + : m_command(0) + , m_nTotal_fields(0) , m_index(0) {} - void reset(const SMaskingRules& sRules) + void reset(uint8_t command, const SMaskingRules& sRules) { + m_command = command; m_sRules = sRules; m_nTotal_fields = 0; + m_types.clear(); m_rules.clear(); m_index = 0; } + uint8_t command() const + { + return m_command; + } + const SMaskingRules& rules() const { return m_sRules; @@ -82,13 +90,19 @@ private: void set_total_fields(uint32_t n) { m_nTotal_fields = n; } - bool append_rule(const MaskingRules::Rule* pRule) + bool append_type_and_rule(enum_field_types type, const MaskingRules::Rule* pRule) { + m_types.push_back(type); m_rules.push_back(pRule); return m_rules.size() == m_nTotal_fields; } + const std::vector& types() const + { + return m_types; + } + const MaskingRules::Rule* get_rule() { ss_dassert(m_nTotal_fields == m_rules.size()); @@ -101,8 +115,10 @@ private: } private: + uint8_t m_command; /* m_types; /* m_rules; /*