MXS-1346: Use std::unordered_map for storing user definitions

The users are now stored in a unordered_map which removes the need for the
use of HASHTABLE. Altered all functions to use a shared_ptr of a User
instead of a raw pointer. Made parsing of rules exception-safe.
This commit is contained in:
Markus Mäkelä
2017-08-30 18:49:02 +03:00
parent 17e7097b00
commit 824962d59a
2 changed files with 55 additions and 66 deletions

View File

@ -30,7 +30,7 @@ int main(int argc, char **argv)
MXS_NOTICE("Parsing rule file: %s", argv[1]); MXS_NOTICE("Parsing rule file: %s", argv[1]);
RuleList rules; RuleList rules;
HASHTABLE *users; UserMap users;
if (process_rule_file(argv[1], &rules, &users)) if (process_rule_file(argv[1], &rules, &users))
{ {

View File

@ -66,12 +66,11 @@
#include <stdio.h> #include <stdio.h>
#include <string.h> #include <string.h>
#include <time.h> #include <time.h>
#include <assert.h>
#include <regex.h>
#include <stdlib.h> #include <stdlib.h>
#include <string> #include <string>
#include <list> #include <list>
#include <tr1/memory> #include <tr1/memory>
#include <tr1/unordered_map>
#include <maxscale/filter.h> #include <maxscale/filter.h>
#include <maxscale/atomic.h> #include <maxscale/atomic.h>
@ -302,14 +301,8 @@ struct Rule
typedef std::tr1::shared_ptr<Rule> SRule; typedef std::tr1::shared_ptr<Rule> SRule;
typedef std::list<SRule> RuleList; typedef std::list<SRule> RuleList;
thread_local struct /** Typedef for a list of strings */
{ typedef std::list<std::string> ValueList;
int rule_version;
RuleList rules;
HASHTABLE *users;
} this_thread = {};
typedef std::list<std::string> ValueList;
/** /**
* A temporary template structure used in the creation of actual users. * A temporary template structure used in the creation of actual users.
@ -352,6 +345,9 @@ struct User
* fails. This is only for rules paired with 'match strict_all'. */ * fails. This is only for rules paired with 'match strict_all'. */
}; };
typedef std::tr1::shared_ptr<User> SUser;
typedef std::tr1::unordered_map<std::string, SUser> UserMap;
/** /**
* The Firewall filter instance. * The Firewall filter instance.
*/ */
@ -377,10 +373,18 @@ typedef struct
MXS_UPSTREAM up; /*< Next object in the upstream chain */ MXS_UPSTREAM up; /*< Next object in the upstream chain */
} FW_SESSION; } FW_SESSION;
/** The rules and users for each thread */
thread_local struct
{
int rule_version;
RuleList rules;
UserMap users;
} this_thread;
bool parse_at_times(const char** tok, char** saveptr, Rule* ruledef); bool parse_at_times(const char** tok, char** saveptr, Rule* ruledef);
bool parse_limit_queries(FW_INSTANCE* instance, Rule* ruledef, const char* rule, char** saveptr); bool parse_limit_queries(FW_INSTANCE* instance, Rule* ruledef, const char* rule, char** saveptr);
static void rule_free_all(Rule* rule); static void rule_free_all(Rule* rule);
static bool process_rule_file(const char* filename, RuleList* rules, HASHTABLE **users); static bool process_rule_file(const char* filename, RuleList* rules, UserMap* users);
bool replace_rules(FW_INSTANCE* instance); bool replace_rules(FW_INSTANCE* instance);
static void print_rule(Rule *rules, char *dest) static void print_rule(Rule *rules, char *dest)
@ -509,24 +513,6 @@ static STRLINK* strlink_reverse_clone(STRLINK* head)
return clone; return clone;
} }
static void dbfw_user_free(void* fval)
{
User* value = (User*) fval;
delete value;
}
HASHTABLE *dbfw_userlist_create()
{
HASHTABLE *ht = hashtable_alloc(100, hashtable_item_strhash, hashtable_item_strcmp);
if (ht)
{
hashtable_memory_fns(ht, hashtable_item_strdup, NULL, hashtable_item_free, dbfw_user_free);
}
return ht;
}
/** /**
* Parses a string that contains an IP address and converts the last octet to '%'. * Parses a string that contains an IP address and converts the last octet to '%'.
* This modifies the string passed as the parameter. * This modifies the string passed as the parameter.
@ -800,7 +786,7 @@ bool dbfw_reload_rules(const MODULECMD_ARG *argv, json_t** output)
spinlock_release(&inst->lock); spinlock_release(&inst->lock);
RuleList rules; RuleList rules;
HASHTABLE *users = NULL; UserMap users;
if (rval && access(filename, R_OK) == 0) if (rval && access(filename, R_OK) == 0)
{ {
@ -823,8 +809,6 @@ bool dbfw_reload_rules(const MODULECMD_ARG *argv, json_t** output)
rval = false; rval = false;
} }
hashtable_free(users);
return rval; return rval;
} }
@ -836,7 +820,7 @@ bool dbfw_show_rules(const MODULECMD_ARG *argv, json_t** output)
dcb_printf(dcb, "Rule, Type, Times Matched\n"); dcb_printf(dcb, "Rule, Type, Times Matched\n");
if (this_thread.rules.empty() || !this_thread.users) if (this_thread.rules.empty() || this_thread.users.empty())
{ {
if (!replace_rules(inst)) if (!replace_rules(inst))
{ {
@ -862,7 +846,7 @@ bool dbfw_show_rules_json(const MODULECMD_ARG *argv, json_t** output)
json_t* arr = json_array(); json_t* arr = json_array();
if (this_thread.rules.empty() || !this_thread.users) if (this_thread.rules.empty() || this_thread.users.empty())
{ {
if (!replace_rules(inst)) if (!replace_rules(inst))
{ {
@ -1416,7 +1400,7 @@ bool define_regex_rule(void* scanner, char* pattern)
* @param rules List of all rules * @param rules List of all rules
* @return True on success, false on error. * @return True on success, false on error.
*/ */
static bool process_user_templates(HASHTABLE *users, const TemplateList& templates, static bool process_user_templates(UserMap& users, const TemplateList& templates,
RuleList& rules) RuleList& rules)
{ {
bool rval = true; bool rval = true;
@ -1430,14 +1414,13 @@ static bool process_user_templates(HASHTABLE *users, const TemplateList& templat
for (TemplateList::const_iterator it = templates.begin(); it != templates.end(); it++) for (TemplateList::const_iterator it = templates.begin(); it != templates.end(); it++)
{ {
const SUserTemplate& ut = *it; const SUserTemplate& ut = *it;
User *user = (User*)hashtable_fetch(users, (void*)ut->name.c_str());
if (user == NULL) if (users.find(ut->name) == users.end())
{ {
user = new User(ut->name); users[ut->name] = SUser(new User(ut->name));
hashtable_add(users, (void*)user->name.c_str(), user);
} }
SUser& user = users[ut->name];
RuleList newrules; RuleList newrules;
for (ValueList::const_iterator r_it = ut->rulenames.begin(); for (ValueList::const_iterator r_it = ut->rulenames.begin();
@ -1484,7 +1467,7 @@ static bool process_user_templates(HASHTABLE *users, const TemplateList& templat
* @param instance Filter instance * @param instance Filter instance
* @return True on success, false on error. * @return True on success, false on error.
*/ */
static bool process_rule_file(const char* filename, RuleList* rules, HASHTABLE **users) static bool do_process_rule_file(const char* filename, RuleList* rules, UserMap* users)
{ {
int rc = 1; int rc = 1;
FILE *file = fopen(filename, "r"); FILE *file = fopen(filename, "r");
@ -1500,22 +1483,21 @@ static bool process_rule_file(const char* filename, RuleList* rules, HASHTABLE *
dbfw_yy_switch_to_buffer(buf, scanner); dbfw_yy_switch_to_buffer(buf, scanner);
/** Parse the rule file */ /** Parse the rule file */
MXS_EXCEPTION_GUARD(rc = dbfw_yyparse(scanner)); rc = dbfw_yyparse(scanner);
dbfw_yy_delete_buffer(buf, scanner); dbfw_yy_delete_buffer(buf, scanner);
dbfw_yylex_destroy(scanner); dbfw_yylex_destroy(scanner);
fclose(file); fclose(file);
HASHTABLE *new_users = dbfw_userlist_create(); UserMap new_users;
if (rc == 0 && new_users && process_user_templates(new_users, pstack.templates, pstack.rule)) if (rc == 0 && process_user_templates(new_users, pstack.templates, pstack.rule))
{ {
rules->swap(pstack.rule); rules->swap(pstack.rule);
*users = new_users; users->swap(new_users);
} }
else else
{ {
rc = 1; rc = 1;
hashtable_free(new_users);
MXS_ERROR("Failed to process rule file '%s'.", filename); MXS_ERROR("Failed to process rule file '%s'.", filename);
} }
} }
@ -1529,6 +1511,13 @@ static bool process_rule_file(const char* filename, RuleList* rules, HASHTABLE *
return rc == 0; return rc == 0;
} }
static bool process_rule_file(const char* filename, RuleList* rules, UserMap* users)
{
bool rval = false;
MXS_EXCEPTION_GUARD(rval = do_process_rule_file(filename, rules, users));
return rval;
}
/** /**
* @brief Replace the rule file used by this thread * @brief Replace the rule file used by this thread
* *
@ -1549,16 +1538,15 @@ bool replace_rules(FW_INSTANCE* instance)
spinlock_release(&instance->lock); spinlock_release(&instance->lock);
RuleList rules; RuleList rules;
HASHTABLE *users; UserMap users;
if (process_rule_file(filename, &rules, &users)) if (process_rule_file(filename, &rules, &users))
{ {
hashtable_free(this_thread.users);
this_thread.rules.swap(rules); this_thread.rules.swap(rules);
this_thread.users = users; this_thread.users.swap(users);
rval = true; rval = true;
} }
else if (!this_thread.rules.empty() && this_thread.users) else if (!this_thread.rules.empty() && !this_thread.users.empty())
{ {
MXS_ERROR("Failed to parse rules at '%s'. Old rules are still used.", filename); MXS_ERROR("Failed to parse rules at '%s'. Old rules are still used.", filename);
} }
@ -1608,7 +1596,7 @@ createInstance(const char *name, char **options, MXS_CONFIG_PARAMETER *params)
} }
RuleList rules; RuleList rules;
HASHTABLE *users = NULL; UserMap users;
my_instance->rulefile = MXS_STRDUP(config_get_string(params, "rules")); my_instance->rulefile = MXS_STRDUP(config_get_string(params, "rules"));
if (!my_instance->rulefile || !process_rule_file(my_instance->rulefile, &rules, &users)) if (!my_instance->rulefile || !process_rule_file(my_instance->rulefile, &rules, &users))
@ -1621,8 +1609,6 @@ createInstance(const char *name, char **options, MXS_CONFIG_PARAMETER *params)
atomic_add(&my_instance->rule_version, 1); atomic_add(&my_instance->rule_version, 1);
} }
hashtable_free(users);
return (MXS_FILTER *) my_instance; return (MXS_FILTER *) my_instance;
} }
@ -2226,7 +2212,7 @@ queryresolved:
* @return True if the query matches at least one of the rules otherwise false * @return True if the query matches at least one of the rules otherwise false
*/ */
bool check_match_any(FW_INSTANCE* my_instance, FW_SESSION* my_session, bool check_match_any(FW_INSTANCE* my_instance, FW_SESSION* my_session,
GWBUF *queue, User* user, char** rulename) GWBUF *queue, SUser user, char** rulename)
{ {
bool rval = false; bool rval = false;
@ -2306,7 +2292,7 @@ void append_string(char** dest, size_t* size, const char* src)
* @return True if the query matches all of the rules otherwise false * @return True if the query matches all of the rules otherwise false
*/ */
bool check_match_all(FW_INSTANCE* my_instance, FW_SESSION* my_session, bool check_match_all(FW_INSTANCE* my_instance, FW_SESSION* my_session,
GWBUF *queue, User* user, bool strict_all, char** rulename) GWBUF *queue, SUser user, bool strict_all, char** rulename)
{ {
bool rval = false; bool rval = false;
bool have_active_rule = false; bool have_active_rule = false;
@ -2360,35 +2346,38 @@ bool check_match_all(FW_INSTANCE* my_instance, FW_SESSION* my_session,
/** /**
* Retrieve the user specific data for this session * Retrieve the user specific data for this session
* *
* @param hash Hashtable containing the user data * @param users Map containing the user data
* @param name Username * @param name Username
* @param remote Remove network address * @param remote Remove network address
* @return The user data or NULL if it was not found * @return The user data or NULL if it was not found
*/ */
User* find_user_data(HASHTABLE *hash, const char *name, const char *remote) SUser find_user_data(const UserMap& users, const char *name, const char *remote)
{ {
char nameaddr[strlen(name) + strlen(remote) + 2]; char nameaddr[strlen(name) + strlen(remote) + 2];
snprintf(nameaddr, sizeof(nameaddr), "%s@%s", name, remote); snprintf(nameaddr, sizeof(nameaddr), "%s@%s", name, remote);
User* user = (User*) hashtable_fetch(hash, nameaddr); UserMap::const_iterator it = users.find(nameaddr);
if (user == NULL)
if (it == users.end())
{ {
char *ip_start = strchr(nameaddr, '@') + 1; char *ip_start = strchr(nameaddr, '@') + 1;
while (user == NULL && next_ip_class(ip_start)) while (it == users.end() && next_ip_class(ip_start))
{ {
user = (User*) hashtable_fetch(hash, nameaddr); it = users.find(nameaddr);
} }
if (user == NULL) if (it == users.end())
{ {
snprintf(nameaddr, sizeof(nameaddr), "%%@%s", remote); snprintf(nameaddr, sizeof(nameaddr), "%%@%s", remote);
ip_start = strchr(nameaddr, '@') + 1; ip_start = strchr(nameaddr, '@') + 1;
while (user == NULL && next_ip_class(ip_start))
while (it == users.end() && next_ip_class(ip_start))
{ {
user = (User*) hashtable_fetch(hash, nameaddr); it = users.find(nameaddr);
} }
} }
} }
return user;
return it != users.end() ? it->second : SUser();
} }
static bool command_is_mandatory(const GWBUF *buffer) static bool command_is_mandatory(const GWBUF *buffer)
@ -2470,7 +2459,7 @@ routeQuery(MXS_FILTER *instance, MXS_FILTER_SESSION *session, GWBUF *queue)
ss_dassert(analyzed_queue); ss_dassert(analyzed_queue);
} }
User *user = find_user_data(this_thread.users, dcb->user, dcb->remote); SUser user = find_user_data(this_thread.users, dcb->user, dcb->remote);
bool query_ok = command_is_mandatory(queue); bool query_ok = command_is_mandatory(queue);
if (user) if (user)