MXS-1346: Use std::unordered_map for storing user definitions

The users are now stored in a unordered_map which removes the need for the
use of HASHTABLE. Altered all functions to use a shared_ptr of a User
instead of a raw pointer. Made parsing of rules exception-safe.
This commit is contained in:
Markus Mäkelä
2017-08-30 18:49:02 +03:00
parent 17e7097b00
commit 824962d59a
2 changed files with 55 additions and 66 deletions

View File

@ -30,7 +30,7 @@ int main(int argc, char **argv)
MXS_NOTICE("Parsing rule file: %s", argv[1]);
RuleList rules;
HASHTABLE *users;
UserMap users;
if (process_rule_file(argv[1], &rules, &users))
{

View File

@ -66,12 +66,11 @@
#include <stdio.h>
#include <string.h>
#include <time.h>
#include <assert.h>
#include <regex.h>
#include <stdlib.h>
#include <string>
#include <list>
#include <tr1/memory>
#include <tr1/unordered_map>
#include <maxscale/filter.h>
#include <maxscale/atomic.h>
@ -302,14 +301,8 @@ struct Rule
typedef std::tr1::shared_ptr<Rule> SRule;
typedef std::list<SRule> RuleList;
thread_local struct
{
int rule_version;
RuleList rules;
HASHTABLE *users;
} this_thread = {};
typedef std::list<std::string> ValueList;
/** Typedef for a list of strings */
typedef std::list<std::string> ValueList;
/**
* A temporary template structure used in the creation of actual users.
@ -352,6 +345,9 @@ struct User
* fails. This is only for rules paired with 'match strict_all'. */
};
typedef std::tr1::shared_ptr<User> SUser;
typedef std::tr1::unordered_map<std::string, SUser> UserMap;
/**
* The Firewall filter instance.
*/
@ -377,10 +373,18 @@ typedef struct
MXS_UPSTREAM up; /*< Next object in the upstream chain */
} FW_SESSION;
/** The rules and users for each thread */
thread_local struct
{
int rule_version;
RuleList rules;
UserMap users;
} this_thread;
bool parse_at_times(const char** tok, char** saveptr, Rule* ruledef);
bool parse_limit_queries(FW_INSTANCE* instance, Rule* ruledef, const char* rule, char** saveptr);
static void rule_free_all(Rule* rule);
static bool process_rule_file(const char* filename, RuleList* rules, HASHTABLE **users);
static bool process_rule_file(const char* filename, RuleList* rules, UserMap* users);
bool replace_rules(FW_INSTANCE* instance);
static void print_rule(Rule *rules, char *dest)
@ -509,24 +513,6 @@ static STRLINK* strlink_reverse_clone(STRLINK* head)
return clone;
}
static void dbfw_user_free(void* fval)
{
User* value = (User*) fval;
delete value;
}
HASHTABLE *dbfw_userlist_create()
{
HASHTABLE *ht = hashtable_alloc(100, hashtable_item_strhash, hashtable_item_strcmp);
if (ht)
{
hashtable_memory_fns(ht, hashtable_item_strdup, NULL, hashtable_item_free, dbfw_user_free);
}
return ht;
}
/**
* Parses a string that contains an IP address and converts the last octet to '%'.
* This modifies the string passed as the parameter.
@ -800,7 +786,7 @@ bool dbfw_reload_rules(const MODULECMD_ARG *argv, json_t** output)
spinlock_release(&inst->lock);
RuleList rules;
HASHTABLE *users = NULL;
UserMap users;
if (rval && access(filename, R_OK) == 0)
{
@ -823,8 +809,6 @@ bool dbfw_reload_rules(const MODULECMD_ARG *argv, json_t** output)
rval = false;
}
hashtable_free(users);
return rval;
}
@ -836,7 +820,7 @@ bool dbfw_show_rules(const MODULECMD_ARG *argv, json_t** output)
dcb_printf(dcb, "Rule, Type, Times Matched\n");
if (this_thread.rules.empty() || !this_thread.users)
if (this_thread.rules.empty() || this_thread.users.empty())
{
if (!replace_rules(inst))
{
@ -862,7 +846,7 @@ bool dbfw_show_rules_json(const MODULECMD_ARG *argv, json_t** output)
json_t* arr = json_array();
if (this_thread.rules.empty() || !this_thread.users)
if (this_thread.rules.empty() || this_thread.users.empty())
{
if (!replace_rules(inst))
{
@ -1416,7 +1400,7 @@ bool define_regex_rule(void* scanner, char* pattern)
* @param rules List of all rules
* @return True on success, false on error.
*/
static bool process_user_templates(HASHTABLE *users, const TemplateList& templates,
static bool process_user_templates(UserMap& users, const TemplateList& templates,
RuleList& rules)
{
bool rval = true;
@ -1430,14 +1414,13 @@ static bool process_user_templates(HASHTABLE *users, const TemplateList& templat
for (TemplateList::const_iterator it = templates.begin(); it != templates.end(); it++)
{
const SUserTemplate& ut = *it;
User *user = (User*)hashtable_fetch(users, (void*)ut->name.c_str());
if (user == NULL)
if (users.find(ut->name) == users.end())
{
user = new User(ut->name);
hashtable_add(users, (void*)user->name.c_str(), user);
users[ut->name] = SUser(new User(ut->name));
}
SUser& user = users[ut->name];
RuleList newrules;
for (ValueList::const_iterator r_it = ut->rulenames.begin();
@ -1484,7 +1467,7 @@ static bool process_user_templates(HASHTABLE *users, const TemplateList& templat
* @param instance Filter instance
* @return True on success, false on error.
*/
static bool process_rule_file(const char* filename, RuleList* rules, HASHTABLE **users)
static bool do_process_rule_file(const char* filename, RuleList* rules, UserMap* users)
{
int rc = 1;
FILE *file = fopen(filename, "r");
@ -1500,22 +1483,21 @@ static bool process_rule_file(const char* filename, RuleList* rules, HASHTABLE *
dbfw_yy_switch_to_buffer(buf, scanner);
/** Parse the rule file */
MXS_EXCEPTION_GUARD(rc = dbfw_yyparse(scanner));
rc = dbfw_yyparse(scanner);
dbfw_yy_delete_buffer(buf, scanner);
dbfw_yylex_destroy(scanner);
fclose(file);
HASHTABLE *new_users = dbfw_userlist_create();
UserMap new_users;
if (rc == 0 && new_users && process_user_templates(new_users, pstack.templates, pstack.rule))
if (rc == 0 && process_user_templates(new_users, pstack.templates, pstack.rule))
{
rules->swap(pstack.rule);
*users = new_users;
users->swap(new_users);
}
else
{
rc = 1;
hashtable_free(new_users);
MXS_ERROR("Failed to process rule file '%s'.", filename);
}
}
@ -1529,6 +1511,13 @@ static bool process_rule_file(const char* filename, RuleList* rules, HASHTABLE *
return rc == 0;
}
static bool process_rule_file(const char* filename, RuleList* rules, UserMap* users)
{
bool rval = false;
MXS_EXCEPTION_GUARD(rval = do_process_rule_file(filename, rules, users));
return rval;
}
/**
* @brief Replace the rule file used by this thread
*
@ -1549,16 +1538,15 @@ bool replace_rules(FW_INSTANCE* instance)
spinlock_release(&instance->lock);
RuleList rules;
HASHTABLE *users;
UserMap users;
if (process_rule_file(filename, &rules, &users))
{
hashtable_free(this_thread.users);
this_thread.rules.swap(rules);
this_thread.users = users;
this_thread.users.swap(users);
rval = true;
}
else if (!this_thread.rules.empty() && this_thread.users)
else if (!this_thread.rules.empty() && !this_thread.users.empty())
{
MXS_ERROR("Failed to parse rules at '%s'. Old rules are still used.", filename);
}
@ -1608,7 +1596,7 @@ createInstance(const char *name, char **options, MXS_CONFIG_PARAMETER *params)
}
RuleList rules;
HASHTABLE *users = NULL;
UserMap users;
my_instance->rulefile = MXS_STRDUP(config_get_string(params, "rules"));
if (!my_instance->rulefile || !process_rule_file(my_instance->rulefile, &rules, &users))
@ -1621,8 +1609,6 @@ createInstance(const char *name, char **options, MXS_CONFIG_PARAMETER *params)
atomic_add(&my_instance->rule_version, 1);
}
hashtable_free(users);
return (MXS_FILTER *) my_instance;
}
@ -2226,7 +2212,7 @@ queryresolved:
* @return True if the query matches at least one of the rules otherwise false
*/
bool check_match_any(FW_INSTANCE* my_instance, FW_SESSION* my_session,
GWBUF *queue, User* user, char** rulename)
GWBUF *queue, SUser user, char** rulename)
{
bool rval = false;
@ -2306,7 +2292,7 @@ void append_string(char** dest, size_t* size, const char* src)
* @return True if the query matches all of the rules otherwise false
*/
bool check_match_all(FW_INSTANCE* my_instance, FW_SESSION* my_session,
GWBUF *queue, User* user, bool strict_all, char** rulename)
GWBUF *queue, SUser user, bool strict_all, char** rulename)
{
bool rval = false;
bool have_active_rule = false;
@ -2360,35 +2346,38 @@ bool check_match_all(FW_INSTANCE* my_instance, FW_SESSION* my_session,
/**
* Retrieve the user specific data for this session
*
* @param hash Hashtable containing the user data
* @param users Map containing the user data
* @param name Username
* @param remote Remove network address
* @return The user data or NULL if it was not found
*/
User* find_user_data(HASHTABLE *hash, const char *name, const char *remote)
SUser find_user_data(const UserMap& users, const char *name, const char *remote)
{
char nameaddr[strlen(name) + strlen(remote) + 2];
snprintf(nameaddr, sizeof(nameaddr), "%s@%s", name, remote);
User* user = (User*) hashtable_fetch(hash, nameaddr);
if (user == NULL)
UserMap::const_iterator it = users.find(nameaddr);
if (it == users.end())
{
char *ip_start = strchr(nameaddr, '@') + 1;
while (user == NULL && next_ip_class(ip_start))
while (it == users.end() && next_ip_class(ip_start))
{
user = (User*) hashtable_fetch(hash, nameaddr);
it = users.find(nameaddr);
}
if (user == NULL)
if (it == users.end())
{
snprintf(nameaddr, sizeof(nameaddr), "%%@%s", remote);
ip_start = strchr(nameaddr, '@') + 1;
while (user == NULL && next_ip_class(ip_start))
while (it == users.end() && next_ip_class(ip_start))
{
user = (User*) hashtable_fetch(hash, nameaddr);
it = users.find(nameaddr);
}
}
}
return user;
return it != users.end() ? it->second : SUser();
}
static bool command_is_mandatory(const GWBUF *buffer)
@ -2470,7 +2459,7 @@ routeQuery(MXS_FILTER *instance, MXS_FILTER_SESSION *session, GWBUF *queue)
ss_dassert(analyzed_queue);
}
User *user = find_user_data(this_thread.users, dcb->user, dcb->remote);
SUser user = find_user_data(this_thread.users, dcb->user, dcb->remote);
bool query_ok = command_is_mandatory(queue);
if (user)