diff --git a/server/modules/filter/masking/maskingfiltersession.cc b/server/modules/filter/masking/maskingfiltersession.cc index 00e14a8e8..483347e40 100644 --- a/server/modules/filter/masking/maskingfiltersession.cc +++ b/server/modules/filter/masking/maskingfiltersession.cc @@ -571,9 +571,19 @@ bool MaskingFilterSession::is_variable_defined(GWBUF* pPacket, const char* zUser SMaskingRules sRules = m_filter.rules(); auto pred = [&sRules, zUser, zHost](const QC_FIELD_INFO& field_info) { - const MaskingRules::Rule* pRule = sRules->get_rule_for(field_info, zUser, zHost); + bool rv = false; - return pRule ? true : false; + if (strcmp(field_info.column, "*") == 0) + { + // If "*" is used, then we must block if there is any rule for the current user. + rv = sRules->has_rule_for(zUser, zHost); + } + else + { + rv = sRules->get_rule_for(field_info, zUser, zHost) ? true : false; + } + + return rv; }; const QC_FIELD_INFO* pInfos; @@ -588,12 +598,22 @@ bool MaskingFilterSession::is_variable_defined(GWBUF* pPacket, const char* zUser if (i != end) { + const char* zColumn = i->column; + std::stringstream ss; - ss << "The field " << i->column << " that should be masked for '" << zUser << "'@'" << zHost - << "' is used when defining a variable, access is denied."; + + if (strcmp(zColumn, "*") == 0) + { + ss << "'*' is used in the definition of a variable and there are masking rules " + << "for '" << zUser << "'@'" << zHost << "', access is denied."; + } + else + { + ss << "The field " << i->column << " that should be masked for '" << zUser << "'@'" << zHost + << "' is used when defining a variable, access is denied."; + } set_response(create_error_response(ss.str().c_str())); - is_defined = true; } @@ -618,16 +638,12 @@ bool MaskingFilterSession::is_union_used(GWBUF* pPacket, const char* zUser, cons { if (strcmp(field_info.column, "*") == 0) { - rv = true; + // If "*" is used, then we must block if there is any rule for the current user. + rv = sRules->has_rule_for(zUser, zHost); } else { - const MaskingRules::Rule* pRule = sRules->get_rule_for(field_info, zUser, zHost); - - if (pRule) - { - rv = true; - } + rv = sRules->get_rule_for(field_info, zUser, zHost) ? true : false; } } @@ -652,9 +668,8 @@ bool MaskingFilterSession::is_union_used(GWBUF* pPacket, const char* zUser, cons if (strcmp(zColumn, "*") == 0) { - ss << "'*' is used in the second or subsequent SELECT of a UNION, which " - << "may include a field that should be masked for '" << zUser << "'@'" << zHost - << "', access is denied."; + ss << "'*' is used in the second or subsequent SELECT of a UNION and there are " + << "masking rules for '" << zUser << "'@'" << zHost << "', access is denied."; } else { @@ -663,7 +678,6 @@ bool MaskingFilterSession::is_union_used(GWBUF* pPacket, const char* zUser, cons } set_response(create_error_response(ss.str().c_str())); - is_used = true; } diff --git a/server/modules/filter/masking/maskingrules.cc b/server/modules/filter/masking/maskingrules.cc index 47029074b..b64ce11d0 100644 --- a/server/modules/filter/masking/maskingrules.cc +++ b/server/modules/filter/masking/maskingrules.cc @@ -1467,3 +1467,12 @@ const MaskingRules::Rule* MaskingRules::get_rule_for(const QC_FIELD_INFO& field_ return pRule; } + +bool MaskingRules::has_rule_for(const char* zUser, const char* zHost) const +{ + auto i = std::find_if(m_rules.begin(), m_rules.end(), [zUser, zHost](SRule sRule) { + return sRule->matches_account(zUser, zHost); + }); + + return i != m_rules.end(); +} diff --git a/server/modules/filter/masking/maskingrules.hh b/server/modules/filter/masking/maskingrules.hh index 9b7a80976..b74b5f096 100644 --- a/server/modules/filter/masking/maskingrules.hh +++ b/server/modules/filter/masking/maskingrules.hh @@ -143,13 +143,21 @@ public: */ virtual void rewrite(LEncString& s) const = 0; + /** + * Does this rule apply to a specific account. + * + * @param zUser The current user. + * @param zHost The current host. + * + * @return True, if the rule applies. + */ + bool matches_account(const char* zUser, + const char* zHost) const; + private: Rule(const Rule&); Rule& operator=(const Rule&); - bool matches_account(const char* zUser, - const char* zHost) const; - private: std::string m_column; std::string m_table; @@ -396,6 +404,16 @@ public: typedef std::shared_ptr SRule; + /** + * Is there any rule for the specified user. + * + * @param zUser The current user. + * @param zHost The current host. + * + * @return True, if there is a rule for that user/host combination. + */ + bool has_rule_for(const char* zUser, const char* zHost) const; + private: MaskingRules(json_t* pRoot, const std::vector& rules);