diff --git a/server/modules/filter/masking/maskingrules.cc b/server/modules/filter/masking/maskingrules.cc index 7fd368822..ea93e36df 100644 --- a/server/modules/filter/masking/maskingrules.cc +++ b/server/modules/filter/masking/maskingrules.cc @@ -12,7 +12,9 @@ */ #include "maskingrules.hh" +#include #include +#include #include #include #include @@ -612,6 +614,117 @@ auto_ptr MaskingRules::Rule::create_from(json_t* pRule) return sRule; } +namespace +{ + +class AccountMatcher : std::unary_function +{ +public: + AccountMatcher(const char* zUser, const char* zHost) + : m_zUser(zUser) + , m_zHost(zHost) + {} + + bool operator()(const MaskingRules::Rule::SAccount& sAccount) + { + return sAccount->matches(m_zUser, m_zHost); + } + +private: + const char* m_zUser; + const char* m_zHost; +}; + +} + +bool MaskingRules::Rule::matches(const ComQueryResponse::ColumnDef& column_def, + const char* zUser, + const char* zHost) const +{ + bool match = + (m_column == column_def.org_name()) && + (m_table.empty() || (m_table == column_def.org_table())) && + (m_database.empty() || (m_database == column_def.schema())); + + if (match) + { + // If the column matched, then we need to check whether the rule applies + // to the user and host. + + AccountMatcher matcher(zUser, zHost); + + if (m_applies_to.size() != 0) + { + match = false; + + vector::const_iterator i = std::find_if(m_applies_to.begin(), + m_applies_to.end(), + matcher); + + match = (i != m_applies_to.end()); + } + + if (match && (m_exempted.size() != 0)) + { + // If it is still a match, we need to check whether the user/host is + // exempted. + + vector::const_iterator i = std::find_if(m_exempted.begin(), + m_exempted.end(), + matcher); + + match = (i == m_exempted.end()); + } + } + + return match; +} + +void MaskingRules::Rule::rewrite(LEncString& s) const +{ + bool rewritten = false; + + size_t total_len = s.length(); + + if (!m_value.empty()) + { + if (m_value.length() == total_len) + { + std::copy(m_value.begin(), m_value.end(), s.begin()); + rewritten = true; + } + } + + if (!rewritten) + { + if (!m_fill.empty()) + { + LEncString::iterator i = s.begin(); + size_t len = m_fill.length(); + + while (total_len) + { + if (total_len < len) + { + len = total_len; + } + + std::copy(m_fill.data(), m_fill.data() + len, i); + + i += len; + total_len -= len; + } + } + else + { + MXS_ERROR("Length of returned value \"%s\" is %u, while length of " + "replacement value \"%s\" is %u, and no 'fill' value specified.", + s.to_string().c_str(), (unsigned)s.length(), + m_value.c_str(), (unsigned)m_value.length()); + } + } +} + // // MaskingRules // @@ -702,3 +815,50 @@ std::auto_ptr MaskingRules::create_from(json_t* pRoot) return sRules; } + +namespace +{ + +class RuleMatcher : std::unary_function +{ +public: + RuleMatcher(const ComQueryResponse::ColumnDef& column_def, + const char* zUser, + const char* zHost) + : m_column_def(column_def) + , m_zUser(zUser) + , m_zHost(zHost) + { + } + + bool operator()(const MaskingRules::SRule& sRule) + { + return sRule->matches(m_column_def, m_zUser, m_zHost); + } + +private: + const ComQueryResponse::ColumnDef& m_column_def; + const char* m_zUser; + const char* m_zHost; +}; + +} + +const MaskingRules::Rule* MaskingRules::get_rule_for(const ComQueryResponse::ColumnDef& column_def, + const char* zUser, + const char* zHost) const +{ + const Rule* pRule = NULL; + + RuleMatcher matcher(column_def, zUser, zHost); + vector::const_iterator i = std::find_if(m_rules.begin(), m_rules.end(), matcher); + + if (i != m_rules.end()) + { + const SRule& sRule = *i; + + pRule = sRule.get(); + } + + return pRule; +} diff --git a/server/modules/filter/masking/maskingrules.hh b/server/modules/filter/masking/maskingrules.hh index f42a7a615..b3021d786 100644 --- a/server/modules/filter/masking/maskingrules.hh +++ b/server/modules/filter/masking/maskingrules.hh @@ -18,6 +18,7 @@ #include #include #include +#include "mysql.hh" /** * @class MaskingRules @@ -26,6 +27,8 @@ */ class MaskingRules { + friend class MaskingRulesTester; + public: /** * @class Rule @@ -103,6 +106,21 @@ public: */ static std::auto_ptr create_from(json_t* pRule); + /** + * Establish whether a rule matches a column definition and user/host. + * + * @param column_def A column definition. + * @param zUser The current user. + * @param zHost The current host. + * + * @return True, if the rule matches. + */ + bool matches(const ComQueryResponse::ColumnDef& column_def, + const char* zUser, + const char* zHost) const; + + void rewrite(LEncString& s) const; + private: Rule(const Rule&); Rule& operator = (const Rule&); @@ -147,11 +165,24 @@ public: */ static std::auto_ptr create_from(json_t* pRoot); + /** + * Return the rule object that matches a column definition and user/host. + * + * @param column_def A column definition. + * @param zUser The current user. + * @param zHost The current host. + * + * @return A rule object that matches the column definition and user/host + * or NULL if no such rule object exists. + * + * @attention The returned object remains value only as long as the + * @c MaskingRules object remains valid. + */ + const Rule* get_rule_for(const ComQueryResponse::ColumnDef& column_def, + const char* zUser, + const char* zHost) const; + typedef std::tr1::shared_ptr SRule; - const std::vector& rules() const - { - return m_rules; - } private: MaskingRules(json_t* pRoot, const std::vector& rules); diff --git a/server/modules/filter/masking/test/testrules.cc b/server/modules/filter/masking/test/testrules.cc index 3c8f2293a..b1813ab38 100644 --- a/server/modules/filter/masking/test/testrules.cc +++ b/server/modules/filter/masking/test/testrules.cc @@ -11,6 +11,7 @@ * Public License. */ +#define TESTING_MASKINGRULES #include "maskingrules.hh" #include #include @@ -116,25 +117,6 @@ struct rule_test const size_t nRule_tests = (sizeof(rule_tests) / sizeof(rule_tests[0])); -int test_parsing() -{ - int rc = EXIT_SUCCESS; - - for (size_t i = 0; i < nRule_tests; i++) - { - const rule_test& test = rule_tests[i]; - - auto_ptr sRules = MaskingRules::parse(test.zJson); - - if ((sRules.get() && !test.valid) || (!sRules.get() && test.valid)) - { - rc = EXIT_FAILURE; - } - } - - return rc; -} - // Valid, lot's of users. const char valid_users[] = "{" @@ -190,49 +172,72 @@ struct expected_account const size_t nExpected_accounts = (sizeof(expected_accounts)/sizeof(expected_accounts[0])); -int test_account_handling() +class MaskingRulesTester { - int rc = EXIT_SUCCESS; - - auto_ptr sRules = MaskingRules::parse(valid_users); - ss_dassert(sRules.get()); - - const vector >& rules = sRules->rules(); - ss_dassert(rules.size() == 1); - - shared_ptr sRule = rules[0]; - - const vector >& accounts = sRule->applies_to(); - ss_dassert(accounts.size() == nExpected_accounts); - - int j = 0; - for (vector >::const_iterator i = accounts.begin(); - i != accounts.end(); - ++i) +public: + static int test_parsing() { - const expected_account& account = expected_accounts[j]; + int rc = EXIT_SUCCESS; - string user = (*i)->user(); - - if (user != account.zUser) + for (size_t i = 0; i < nRule_tests; i++) { - cout << j << ": Expected \"" << account.zUser << "\", got \"" << user << "\"." << endl; - rc = EXIT_FAILURE; + const rule_test& test = rule_tests[i]; + + auto_ptr sRules = MaskingRules::parse(test.zJson); + + if ((sRules.get() && !test.valid) || (!sRules.get() && test.valid)) + { + rc = EXIT_FAILURE; + } } - string host = (*i)->host(); - - if (host != account.zHost) - { - cout << j << ": Expected \"" << account.zHost << "\", got \"" << host << "\"." << endl; - rc = EXIT_FAILURE; - } - - ++j; + return rc; } - return rc; -} + static int test_account_handling() + { + int rc = EXIT_SUCCESS; + + auto_ptr sRules = MaskingRules::parse(valid_users); + ss_dassert(sRules.get()); + + const vector >& rules = sRules->m_rules; + ss_dassert(rules.size() == 1); + + shared_ptr sRule = rules[0]; + + const vector >& accounts = sRule->applies_to(); + ss_dassert(accounts.size() == nExpected_accounts); + + int j = 0; + for (vector >::const_iterator i = accounts.begin(); + i != accounts.end(); + ++i) + { + const expected_account& account = expected_accounts[j]; + + string user = (*i)->user(); + + if (user != account.zUser) + { + cout << j << ": Expected \"" << account.zUser << "\", got \"" << user << "\"." << endl; + rc = EXIT_FAILURE; + } + + string host = (*i)->host(); + + if (host != account.zHost) + { + cout << j << ": Expected \"" << account.zHost << "\", got \"" << host << "\"." << endl; + rc = EXIT_FAILURE; + } + + ++j; + } + + return rc; + } +}; int main() { @@ -240,8 +245,8 @@ int main() if (mxs_log_init(NULL, ".", MXS_LOG_TARGET_DEFAULT)) { - rc = (test_parsing() == EXIT_FAILURE) ? EXIT_FAILURE : EXIT_SUCCESS; - rc = (test_account_handling() == EXIT_FAILURE) ? EXIT_FAILURE : EXIT_SUCCESS; + rc = (MaskingRulesTester::test_parsing() == EXIT_FAILURE) ? EXIT_FAILURE : EXIT_SUCCESS; + rc = (MaskingRulesTester::test_account_handling() == EXIT_FAILURE) ? EXIT_FAILURE : EXIT_SUCCESS; } return rc;