Switched over to hashtables for users and columns.

This commit is contained in:
Markus Makela
2014-10-15 15:23:14 +03:00
parent 104e79a591
commit effe8f3297
3 changed files with 211 additions and 98 deletions

View File

@ -1186,15 +1186,48 @@ bool is_drop_table_query(GWBUF* querybuf)
lex->sql_command == SQLCOM_DROP_TABLE; lex->sql_command == SQLCOM_DROP_TABLE;
} }
char* skygw_get_where_clause(GWBUF* buf) inline void add_str(char** buf, int* buflen, int* bufsize, char* str)
{
int isize = strlen(str) + 1;
if(*buf == NULL || isize + *buflen >= *bufsize)
{
char *tmp = (char*)calloc((*bufsize) * 2 + isize, sizeof(char));
if(tmp){
memcpy(tmp,*buf,*bufsize);
if(*buf){
free(*buf);
}
*buf = tmp;
*bufsize = (*bufsize) * 2 + isize;
}
}
if(*buflen > 0){
strcat(*buf," ");
}
strcat(*buf,str);
*buflen += isize;
}
/**
* Returns all the fields that the query affects.
* @param buf Buffer to parse
* @return Pointer to newly allocated string or NULL if nothing was found
*/
char* skygw_get_affected_fields(GWBUF* buf)
{ {
LEX* lex; LEX* lex;
unsigned int buffsz = 0,bufflen = 0; int buffsz = 0,bufflen = 0;
char* where = NULL; char* where = NULL;
Item* item; Item* item;
Item::Type itype;
if(!query_is_parsed(buf)){ if(!query_is_parsed(buf)){
parse_query(buf); parse_query(buf);
} }
if((lex = get_lex(buf)) == NULL){ if((lex = get_lex(buf)) == NULL){
return NULL; return NULL;
} }
@ -1203,43 +1236,72 @@ char* skygw_get_where_clause(GWBUF* buf)
while(lex->current_select) while(lex->current_select)
{ {
List_iterator<Item> ilist(lex->current_select->item_list);
item = (Item*)ilist.next();
for (item; item != NULL; item=(Item*)ilist.next())
{
itype = item->type();
if(item->name && itype == Item::FIELD_ITEM){
add_str(&where,&buffsz,&bufflen,item->name);
}
}
if(lex->current_select->where){ if(lex->current_select->where){
for (item=lex->current_select->where; item != NULL; item=item->next) for (item=lex->current_select->where; item != NULL; item=item->next)
{ {
Item::Type tp = item->type(); itype = item->type();
if(item->name && tp == Item::FIELD_ITEM){ if(item->name && itype == Item::FIELD_ITEM){
add_str(&where,&buffsz,&bufflen,item->name);
}
}
}
int isize = strlen(item->name) + 1; if(lex->current_select->having){
if(where == NULL || isize + bufflen >= buffsz) for (item=lex->current_select->having; item != NULL; item=item->next)
{ {
char *tmp = (char*)calloc(buffsz*2 + isize,sizeof(char));
if(tmp){ itype = item->type();
memcpy(tmp,where,buffsz); if(item->name && itype == Item::FIELD_ITEM){
if(where){ add_str(&where,&buffsz,&bufflen,item->name);
free(where);
} }
where = tmp;
buffsz = buffsz*2 + isize;
}else{
return NULL;
} }
} }
if(bufflen > 0){
strcat(where," ");
}
strcat(where,item->name);
bufflen += isize;
}
}
}
lex->current_select = lex->current_select->next_select_in_list(); lex->current_select = lex->current_select->next_select_in_list();
} }
return where; return where;
} }
bool skygw_query_has_clause(GWBUF* buf)
{
LEX* lex;
bool clause = false;
if(!query_is_parsed(buf)){
parse_query(buf);
}
if((lex = get_lex(buf)) == NULL){
return false;
}
lex->current_select = lex->all_selects_list;
while(lex->current_select)
{
if(lex->current_select->where || lex->current_select->having){
clause = true;
}
lex->current_select = lex->current_select->next_select_in_list();
}
return clause;
}
/* /*
* Replace user-provided literals with question marks. Return a copy of the * Replace user-provided literals with question marks. Return a copy of the
* querystr with replacements. * querystr with replacements.

View File

@ -100,8 +100,9 @@ bool parse_query (GWBUF* querybuf);
parsing_info_t* parsing_info_init(void (*donefun)(void *)); parsing_info_t* parsing_info_init(void (*donefun)(void *));
void parsing_info_done(void* ptr); void parsing_info_done(void* ptr);
bool query_is_parsed(GWBUF* buf); bool query_is_parsed(GWBUF* buf);
bool skygw_query_has_clause(GWBUF* buf);
char* skygw_get_qtype_str(skygw_query_type_t qtype); char* skygw_get_qtype_str(skygw_query_type_t qtype);
char* skygw_get_where_clause(GWBUF* buf); char* skygw_get_affected_fields(GWBUF* buf);
EXTERN_C_BLOCK_END EXTERN_C_BLOCK_END

View File

@ -73,21 +73,27 @@ static FILTER_OBJECT MyObject = {
diagnostic, diagnostic,
}; };
#define QUERY_TYPES 5
/** /**
* Query types * Query types
*/ */
typedef enum{
#define QUERY_TYPES 5
enum querytype_t{
ALL, ALL,
SELECT, SELECT,
INSERT, INSERT,
UPDATE, UPDATE,
DELETE DELETE
}; }querytype_t;
/**
* Rule types
*/
typedef enum {
RT_UNDEFINED,
RT_USER,
RT_COLUMN
}ruletype_t;
/** /**
* Generic linked list of string values * Generic linked list of string values
@ -111,8 +117,7 @@ typedef struct iprange_t{
* The Firewall filter instance. * The Firewall filter instance.
*/ */
typedef struct { typedef struct {
ITEM* columns; HASHTABLE* htable;
ITEM* users;
IPRANGE* networks; IPRANGE* networks;
int column_count, column_size, user_count, user_size; int column_count, column_size, user_count, user_size;
bool require_where[QUERY_TYPES]; bool require_where[QUERY_TYPES];
@ -129,6 +134,45 @@ typedef struct {
SESSION* session; SESSION* session;
} FW_SESSION; } FW_SESSION;
static int hashkeyfun(void* key);
static int hashcmpfun (void *, void *);
static int hashkeyfun(
void* key)
{
if(key == NULL){
return 0;
}
unsigned int hash = 0,c = 0;
char* ptr = (char*)key;
while((c = *ptr++)){
hash = c + (hash << 6) + (hash << 16) - hash;
}
return (int)hash > 0 ? hash : -hash;
}
static int hashcmpfun(
void* v1,
void* v2)
{
char* i1 = (char*) v1;
char* i2 = (char*) v2;
return strcmp(i1,i2);
}
static void* hstrdup(void* fval)
{
char* str = (char*)fval;
return strdup(str);
}
static void* hfree(void* fval)
{
free (fval);
return NULL;
}
/** /**
* Utility function to check if a string contains a valid IP address. * Utility function to check if a string contains a valid IP address.
@ -308,6 +352,7 @@ uint32_t strtosubmask(char* str)
} }
/** /**
* Implementation of the mandatory version entry point * Implementation of the mandatory version entry point
* *
@ -369,11 +414,10 @@ void parse_rule(char* rule, FW_INSTANCE* instance)
instance->networks = rng; instance->networks = rng;
} }
}else{ /**Add usernames or columns*/ }else{ /**Add rules on usernames or columns*/
char *tok = strtok(ptr," ,\0"); char *tok = strtok(ptr," ,\0");
ITEM* prev = NULL; bool is_user = false, is_column = false, is_time = false;
bool is_user = false, is_column = false;
if(strcmp(tok,"wildcard") == 0){ if(strcmp(tok,"wildcard") == 0){
instance->block_wildcard = block ? true : false; instance->block_wildcard = block ? true : false;
@ -381,11 +425,9 @@ void parse_rule(char* rule, FW_INSTANCE* instance)
} }
if(strcmp(tok,"users") == 0){/**Adding users*/ if(strcmp(tok,"users") == 0){/**Adding users*/
prev = instance->users;
instance->whitelist_users = mode; instance->whitelist_users = mode;
is_user = true; is_user = true;
}else if(strcmp(tok,"columns") == 0){/**Adding Columns*/ }else if(strcmp(tok,"columns") == 0){/**Adding Columns*/
prev = instance->columns;
is_column = true; is_column = true;
} }
@ -394,20 +436,16 @@ void parse_rule(char* rule, FW_INSTANCE* instance)
if(is_user || is_column){ if(is_user || is_column){
while(tok){ while(tok){
ITEM* item = calloc(1,sizeof(ITEM)); /**Add value to hashtable*/
if(item){
item->next = prev; ruletype_t rtype = is_user ? RT_USER : is_column ? RT_COLUMN: RT_UNDEFINED;
item->value = strdup(tok); hashtable_add(instance->htable,
prev = item; (void *)tok,
} (void *)rtype);
tok = strtok(NULL," ,\0"); tok = strtok(NULL," ,\0");
} }
if(is_user){
instance->users = prev;
}else if(is_column){
instance->columns = prev;
}
} }
} }
@ -458,7 +496,18 @@ createInstance(char **options, FILTER_PARAMETER **params)
return NULL; return NULL;
} }
int i; int i;
HASHTABLE* ht;
if((ht = hashtable_alloc(7, hashkeyfun, hashcmpfun)) == NULL){
skygw_log_write(LOGFILE_ERROR, "Unable to allocate hashtable.");
return NULL;
}
hashtable_memory_fns(ht,hstrdup,NULL,hfree,NULL);
my_instance->htable = ht;
my_instance->def_op = true; my_instance->def_op = true;
for(i = 0;params[i];i++){ for(i = 0;params[i];i++){
if(strstr(params[i]->name,"rule")){ if(strstr(params[i]->name,"rule")){
parse_rule(strip_tags(params[i]->value),my_instance); parse_rule(strip_tags(params[i]->value),my_instance);
@ -584,23 +633,22 @@ routeQuery(FILTER *instance, void *session, GWBUF *queue)
FW_SESSION *my_session = (FW_SESSION *)session; FW_SESSION *my_session = (FW_SESSION *)session;
FW_INSTANCE *my_instance = (FW_INSTANCE *)instance; FW_INSTANCE *my_instance = (FW_INSTANCE *)instance;
IPRANGE* ipranges = my_instance->networks; IPRANGE* ipranges = my_instance->networks;
ITEM *users = my_instance->users, *columns = my_instance->columns;
bool accept = false, match = false; bool accept = false, match = false;
char *where,*query; char *where;
uint32_t ip; uint32_t ip;
int len; ruletype_t rtype = RT_UNDEFINED;
skygw_query_op_t queryop;
DCB* dcb = my_session->session->client; DCB* dcb = my_session->session->client;
ip = strtoip(dcb->remote); ip = strtoip(dcb->remote);
while(users){ rtype = (ruletype_t)hashtable_fetch(my_instance->htable, dcb->user);
if(strcmp(dcb->user,users->value)==0){ if(rtype == RT_USER){
match = true; match = true;
accept = my_instance->whitelist_users; accept = my_instance->whitelist_users;
skygw_log_write(LOGFILE_TRACE, "%s@%s was %s.", skygw_log_write(LOGFILE_TRACE, "Firewall: %s@%s was %s.",
dcb->user,dcb->remote,(my_instance->whitelist_users ? "allowed":"blocked")); dcb->user, dcb->remote,
break; (my_instance->whitelist_users ?
} "allowed":"blocked"));
users = users->next;
} }
if(!match){ if(!match){
@ -608,7 +656,7 @@ routeQuery(FILTER *instance, void *session, GWBUF *queue)
if(ip >= ipranges->ip && ip <= ipranges->ip + ipranges->mask){ if(ip >= ipranges->ip && ip <= ipranges->ip + ipranges->mask){
match = true; match = true;
accept = my_instance->whitelist_networks; accept = my_instance->whitelist_networks;
skygw_log_write(LOGFILE_TRACE, "%s@%s was %s.", skygw_log_write(LOGFILE_TRACE, "Firewall: %s@%s was %s.",
dcb->user,dcb->remote,(my_instance->whitelist_networks ? "allowed":"blocked")); dcb->user,dcb->remote,(my_instance->whitelist_networks ? "allowed":"blocked"));
break; break;
} }
@ -626,11 +674,11 @@ routeQuery(FILTER *instance, void *session, GWBUF *queue)
if(skygw_is_real_query(queue)){ if(skygw_is_real_query(queue)){
match = false; match = false;
modutil_extract_SQL(queue, &query, &len);
where = skygw_get_where_clause(queue);
skygw_query_op_t queryop = query_classifier_get_operation(queue);
if(where == NULL){ if(!skygw_query_has_clause(queue)){
queryop = query_classifier_get_operation(queue);
if(my_instance->require_where[ALL] || if(my_instance->require_where[ALL] ||
(my_instance->require_where[SELECT] && queryop == QUERY_OP_SELECT) || (my_instance->require_where[SELECT] && queryop == QUERY_OP_SELECT) ||
(my_instance->require_where[UPDATE] && queryop == QUERY_OP_UPDATE) || (my_instance->require_where[UPDATE] && queryop == QUERY_OP_UPDATE) ||
@ -638,38 +686,40 @@ routeQuery(FILTER *instance, void *session, GWBUF *queue)
(my_instance->require_where[DELETE] && queryop == QUERY_OP_DELETE)){ (my_instance->require_where[DELETE] && queryop == QUERY_OP_DELETE)){
match = true; match = true;
accept = false; accept = false;
skygw_log_write(LOGFILE_TRACE, "query does not have a where clause, blocking it: %.*s",len,query); skygw_log_write(LOGFILE_TRACE, "Firewall: query does not have a where clause or a having clause, blocking it: %.*s",GWBUF_LENGTH(queue),(char*)(queue->start + 5));
} }
} }
if(!match &&
my_instance->block_wildcard &&
((where && strchr(where,'*') != NULL) ||
(memchr(query,'*',len) != NULL))){
match = true;
accept = false;
skygw_log_write(LOGFILE_TRACE, "query contains wildcard, blocking it: %.*s",len,query);
}
if(!match){ if(!match){
if(where == NULL){
where = malloc(sizeof(char)*len+1); where = skygw_get_affected_fields(queue);
memcpy(where,query,len);
memset(where+len,0,1); if(my_instance->block_wildcard &&
} where && strchr(where,'*') != NULL)
while(columns){ {
if(strstr(where,columns->value)){
match = true; match = true;
accept = false; accept = false;
skygw_log_write(LOGFILE_TRACE, "query contains a forbidden column %s, blocking it: %.*s",columns->value,len,query); skygw_log_write(LOGFILE_TRACE, "Firewall: query contains wildcard, blocking it: %.*s",GWBUF_LENGTH(queue),(char*)(queue->start + 5));
break;
} }
columns = columns->next; else if(where)
{
char* tok = strtok(where," ");
while(tok){
rtype = (ruletype_t)hashtable_fetch(my_instance->htable, tok);
if(rtype == RT_COLUMN){
match = true;
accept = false;
skygw_log_write(LOGFILE_TRACE, "Firewall: query contains a forbidden column %s, blocking it: %.*s",tok,GWBUF_LENGTH(queue),(char*)(queue->start + 5));
} }
tok = strtok(NULL," ");
}
} }
free(where); free(where);
} }
}
} }
/**If no rules matched, do the default operation. (allow by default)*/ /**If no rules matched, do the default operation. (allow by default)*/