From 3648b5e7027722403db2541c6800ac42323ea4cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20M=C3=A4kel=C3=A4?= Date: Mon, 4 Sep 2017 10:07:16 +0300 Subject: [PATCH] MXS-1346: Clean up dbfwfilter.cc Remove redundant code, move assignments to struct constructors, organize variable declarations, use standard library functions. --- .../modules/filter/dbfwfilter/dbfwfilter.cc | 173 +++++++----------- server/modules/filter/dbfwfilter/user.hh | 13 +- 2 files changed, 76 insertions(+), 110 deletions(-) diff --git a/server/modules/filter/dbfwfilter/dbfwfilter.cc b/server/modules/filter/dbfwfilter/dbfwfilter.cc index 55e899b1a..fd8661488 100644 --- a/server/modules/filter/dbfwfilter/dbfwfilter.cc +++ b/server/modules/filter/dbfwfilter/dbfwfilter.cc @@ -65,6 +65,8 @@ #include #include #include +#include +#include #include #include @@ -721,30 +723,24 @@ void dbfw_yyerror(void* scanner, const char* error) */ static SRule find_rule_by_name(const RuleList& rules, std::string name) { - for (RuleList::const_iterator it = rules.begin(); it != rules.end(); it++) + class RuleNameComparator { - const SRule& rule = *it; + public: + RuleNameComparator(std::string name): + m_name(name) + {} - if (rule->name() == name) + bool operator()(const SRule& rule) { - return rule; + return rule->name() == m_name; } - } - return SRule(); -} + private: + std::string m_name; + }; -/** - * Create a new rule - * - * The rule is created with the default type which will always match. The rule - * is later specialized by the definition of the actual rule. - * @param scanner Current scanner - * @param name Name of the rule - */ -static Rule* create_rule(const std::string& name) -{ - return new Rule(name); + RuleList::const_iterator it = std::find_if(rules.begin(), rules.end(), RuleNameComparator(name)); + return it != rules.end() ? *it : SRule(); } bool set_rule_name(void* scanner, char* name) @@ -859,10 +855,7 @@ bool create_user_templates(void* scanner) for (ValueList::const_iterator it = rstack->user.begin(); it != rstack->user.end(); it++) { - SUserTemplate newtemp = SUserTemplate(new UserTemplate); - newtemp->name = *it; - newtemp->rulenames = rstack->active_rules; - newtemp->type = rstack->active_mode; + SUserTemplate newtemp = SUserTemplate(new UserTemplate(*it, rstack->active_rules, rstack->active_mode)); rstack->templates.push_back(newtemp); } @@ -1205,6 +1198,7 @@ static MXS_FILTER_SESSION* newSession(MXS_FILTER *instance, MXS_SESSION *session { my_session->session = session; my_session->instance = my_instance; + my_session->errmsg = NULL; } return (MXS_FILTER_SESSION*)my_session; @@ -1260,52 +1254,23 @@ setDownstream(MXS_FILTER *instance, MXS_FILTER_SESSION *session, MXS_DOWNSTREAM */ GWBUF* gen_dummy_error(FW_SESSION* session, char* msg) { - GWBUF* buf; - char* errmsg; - DCB* dcb; - MYSQL_session* mysql_session; - unsigned int errlen; - - if (session == NULL || session->session == NULL || - session->session->client_dcb == NULL || - session->session->client_dcb->data == NULL) - { - MXS_ERROR("Firewall filter session missing data."); - return NULL; - } - - dcb = session->session->client_dcb; + ss_dassert(session && session->session && session->session->client_dcb); + DCB* dcb = session->session->client_dcb; const char* db = mxs_mysql_get_current_db(session->session); - errlen = msg != NULL ? strlen(msg) : 0; - errmsg = (char*) MXS_MALLOC((512 + errlen) * sizeof(char)); + std::stringstream ss; + ss << "Access denied for user '" << dcb->user << "'@'" << dcb->remote << "'"; - if (errmsg == NULL) + if (db[0]) { - return NULL; + ss << " to database '" << db << "'"; } - - if (db[0] == '\0') + if (msg) { - sprintf(errmsg, "Access denied for user '%s'@'%s'", dcb->user, dcb->remote); - } - else - { - sprintf(errmsg, "Access denied for user '%s'@'%s' to database '%s'", - dcb->user, dcb->remote, db); + ss << ": " << msg; } - if (msg != NULL) - { - char* ptr = strchr(errmsg, '\0'); - sprintf(ptr, ": %s", msg); - - } - - buf = modutil_create_mysql_err_msg(1, 0, 1141, "HY000", (const char*) errmsg); - MXS_FREE(errmsg); - - return buf; + return modutil_create_mysql_err_msg(1, 0, 1141, "HY000", ss.str().c_str()); } /** @@ -1340,11 +1305,7 @@ bool inside_timerange(TIMERANGE* comp) to_before = difftime(now, before); to_after = difftime(now, after); - if (to_before > 0.0 && to_after < 0.0) - { - return true; - } - return false; + return to_before > 0.0 && to_after < 0.0; } /** @@ -1354,21 +1315,23 @@ bool inside_timerange(TIMERANGE* comp) */ bool rule_is_active(SRule rule) { - TIMERANGE* times; - if (rule->active != NULL) + bool rval = true; + + if (rule->active) { - times = (TIMERANGE*) rule->active; - while (times) + rval = false; + + for (TIMERANGE* times = rule->active; times; times = times->next) { if (inside_timerange(times)) { - return true; + rval = true; + break; } - times = times->next; } - return false; } - return true; + + return rval; } /** @@ -1467,40 +1430,23 @@ bool rule_matches(FW_INSTANCE* my_instance, if (parse_result == QC_QUERY_INVALID) { msg = create_parse_error(my_instance, "tokenized", query, &matches); - goto queryresolved; } else if (parse_result != QC_QUERY_PARSED && rule->need_full_parsing(queue)) { msg = create_parse_error(my_instance, "parsed completely", query, &matches); - goto queryresolved; } } - if (rule->matches_query_type(queue)) + if (msg == NULL && rule->matches_query_type(queue)) { - if (rule->matches_query(my_session, queue, &msg)) + if ((matches = rule->matches_query(my_session, queue, &msg))) { - /** New style rule matched */ - matches = true; - goto queryresolved; + rule->times_matched++; } } -queryresolved: - if (msg) - { - if (my_session->errmsg) - { - MXS_FREE(my_session->errmsg); - } - - my_session->errmsg = msg; - } - - if (matches) - { - rule->times_matched++; - } + MXS_FREE(my_session->errmsg); + my_session->errmsg = msg; return matches; } @@ -1560,6 +1506,24 @@ static bool command_is_mandatory(const GWBUF *buffer) } } +static bool update_rules(FW_INSTANCE* my_instance) +{ + bool rval = true; + int rule_version = my_instance->rule_version; + + if (this_thread.rule_version < rule_version) + { + if (!replace_rules(my_instance)) + { + rval = false; + } + + this_thread.rule_version = rule_version; + } + + return rval; +} + /** * The routeQuery entry point. This is passed the query buffer * to which the filter should be applied. Once processed the @@ -1573,22 +1537,17 @@ static bool command_is_mandatory(const GWBUF *buffer) static int routeQuery(MXS_FILTER *instance, MXS_FILTER_SESSION *session, GWBUF *queue) { - FW_SESSION *my_session = (FW_SESSION *) session; FW_INSTANCE *my_instance = (FW_INSTANCE *) instance; + + if (!update_rules(my_instance)) + { + return 0; + } + + FW_SESSION *my_session = (FW_SESSION *) session; DCB *dcb = my_session->session->client_dcb; int rval = 0; ss_dassert(dcb && dcb->session); - int rule_version = my_instance->rule_version; - - if (this_thread.rule_version < rule_version) - { - if (!replace_rules(my_instance)) - { - return 0; - } - this_thread.rule_version = rule_version; - } - uint32_t type = 0; if (modutil_is_SQL(queue) || modutil_is_SQL_prepare(queue)) diff --git a/server/modules/filter/dbfwfilter/user.hh b/server/modules/filter/dbfwfilter/user.hh index 451ae928e..467c62893 100644 --- a/server/modules/filter/dbfwfilter/user.hh +++ b/server/modules/filter/dbfwfilter/user.hh @@ -22,9 +22,16 @@ */ struct UserTemplate { - std::string name; /** Name of the user */ - enum match_type type; /** Matching type */ - ValueList rulenames; /** Names of the rules */ + UserTemplate(std::string name, const ValueList& rules, match_type mode): + name(name), + type(mode), + rulenames(rules) + { + } + + std::string name; /** Name of the user */ + match_type type; /** Matching type */ + ValueList rulenames; /** Names of the rules */ }; typedef std::tr1::shared_ptr SUserTemplate;