From effe8f32971025eea9e5bca25492575e5a1c28b8 Mon Sep 17 00:00:00 2001 From: Markus Makela Date: Wed, 15 Oct 2014 15:23:14 +0300 Subject: [PATCH] Switched over to hashtables for users and columns. --- query_classifier/query_classifier.cc | 124 +++++++++++++----- query_classifier/query_classifier.h | 3 +- server/modules/filter/fwfilter.c | 182 +++++++++++++++++---------- 3 files changed, 211 insertions(+), 98 deletions(-) diff --git a/query_classifier/query_classifier.cc b/query_classifier/query_classifier.cc index 5eedc31c0..40d5e7f91 100644 --- a/query_classifier/query_classifier.cc +++ b/query_classifier/query_classifier.cc @@ -1186,60 +1186,122 @@ bool is_drop_table_query(GWBUF* querybuf) lex->sql_command == SQLCOM_DROP_TABLE; } -char* skygw_get_where_clause(GWBUF* buf) +inline void add_str(char** buf, int* buflen, int* bufsize, char* str) +{ + int isize = strlen(str) + 1; + if(*buf == NULL || isize + *buflen >= *bufsize) + { + char *tmp = (char*)calloc((*bufsize) * 2 + isize, sizeof(char)); + if(tmp){ + memcpy(tmp,*buf,*bufsize); + if(*buf){ + free(*buf); + } + *buf = tmp; + *bufsize = (*bufsize) * 2 + isize; + } + } + + if(*buflen > 0){ + strcat(*buf," "); + } + strcat(*buf,str); + *buflen += isize; + +} + + +/** + * Returns all the fields that the query affects. + * @param buf Buffer to parse + * @return Pointer to newly allocated string or NULL if nothing was found + */ +char* skygw_get_affected_fields(GWBUF* buf) { LEX* lex; - unsigned int buffsz = 0,bufflen = 0; + int buffsz = 0,bufflen = 0; char* where = NULL; - Item* item; + Item* item; + Item::Type itype; + if(!query_is_parsed(buf)){ parse_query(buf); } + if((lex = get_lex(buf)) == NULL){ return NULL; } - - lex->current_select = lex->all_selects_list; - + + lex->current_select = lex->all_selects_list; + while(lex->current_select) { + + List_iterator ilist(lex->current_select->item_list); + item = (Item*)ilist.next(); + for (item; item != NULL; item=(Item*)ilist.next()) + { + + itype = item->type(); + if(item->name && itype == Item::FIELD_ITEM){ + add_str(&where,&buffsz,&bufflen,item->name); + } + } + + if(lex->current_select->where){ for (item=lex->current_select->where; item != NULL; item=item->next) { - Item::Type tp = item->type(); - if(item->name && tp == Item::FIELD_ITEM){ - - int isize = strlen(item->name) + 1; - if(where == NULL || isize + bufflen >= buffsz) - { - char *tmp = (char*)calloc(buffsz*2 + isize,sizeof(char)); - if(tmp){ - memcpy(tmp,where,buffsz); - if(where){ - free(where); - } - where = tmp; - buffsz = buffsz*2 + isize; - }else{ - return NULL; - } - } - - if(bufflen > 0){ - strcat(where," "); - } - strcat(where,item->name); - bufflen += isize; - + itype = item->type(); + if(item->name && itype == Item::FIELD_ITEM){ + add_str(&where,&buffsz,&bufflen,item->name); } } } + + if(lex->current_select->having){ + for (item=lex->current_select->having; item != NULL; item=item->next) + { + + itype = item->type(); + if(item->name && itype == Item::FIELD_ITEM){ + add_str(&where,&buffsz,&bufflen,item->name); + } + } + } + lex->current_select = lex->current_select->next_select_in_list(); } return where; } +bool skygw_query_has_clause(GWBUF* buf) +{ + LEX* lex; + bool clause = false; + + if(!query_is_parsed(buf)){ + parse_query(buf); + } + + if((lex = get_lex(buf)) == NULL){ + return false; + } + + lex->current_select = lex->all_selects_list; + + while(lex->current_select) + { + if(lex->current_select->where || lex->current_select->having){ + clause = true; + } + + lex->current_select = lex->current_select->next_select_in_list(); + } + return clause; +} + /* * Replace user-provided literals with question marks. Return a copy of the * querystr with replacements. diff --git a/query_classifier/query_classifier.h b/query_classifier/query_classifier.h index eb0e6d4b9..77230d4b6 100644 --- a/query_classifier/query_classifier.h +++ b/query_classifier/query_classifier.h @@ -100,8 +100,9 @@ bool parse_query (GWBUF* querybuf); parsing_info_t* parsing_info_init(void (*donefun)(void *)); void parsing_info_done(void* ptr); bool query_is_parsed(GWBUF* buf); +bool skygw_query_has_clause(GWBUF* buf); char* skygw_get_qtype_str(skygw_query_type_t qtype); -char* skygw_get_where_clause(GWBUF* buf); +char* skygw_get_affected_fields(GWBUF* buf); EXTERN_C_BLOCK_END diff --git a/server/modules/filter/fwfilter.c b/server/modules/filter/fwfilter.c index 6807d863a..0ea1892b7 100644 --- a/server/modules/filter/fwfilter.c +++ b/server/modules/filter/fwfilter.c @@ -73,21 +73,27 @@ static FILTER_OBJECT MyObject = { diagnostic, }; - +#define QUERY_TYPES 5 /** * Query types */ - -#define QUERY_TYPES 5 - -enum querytype_t{ +typedef enum{ ALL, SELECT, INSERT, UPDATE, DELETE -}; +}querytype_t; + +/** + * Rule types + */ +typedef enum { + RT_UNDEFINED, + RT_USER, + RT_COLUMN +}ruletype_t; /** * Generic linked list of string values @@ -111,8 +117,7 @@ typedef struct iprange_t{ * The Firewall filter instance. */ typedef struct { - ITEM* columns; - ITEM* users; + HASHTABLE* htable; IPRANGE* networks; int column_count, column_size, user_count, user_size; bool require_where[QUERY_TYPES]; @@ -129,6 +134,45 @@ typedef struct { SESSION* session; } FW_SESSION; +static int hashkeyfun(void* key); +static int hashcmpfun (void *, void *); + +static int hashkeyfun( + void* key) +{ + if(key == NULL){ + return 0; + } + unsigned int hash = 0,c = 0; + char* ptr = (char*)key; + while((c = *ptr++)){ + hash = c + (hash << 6) + (hash << 16) - hash; + } + return (int)hash > 0 ? hash : -hash; +} + +static int hashcmpfun( + void* v1, + void* v2) +{ + char* i1 = (char*) v1; + char* i2 = (char*) v2; + + return strcmp(i1,i2); +} + +static void* hstrdup(void* fval) +{ + char* str = (char*)fval; + return strdup(str); +} + + +static void* hfree(void* fval) +{ + free (fval); + return NULL; +} /** * Utility function to check if a string contains a valid IP address. @@ -308,6 +352,7 @@ uint32_t strtosubmask(char* str) } + /** * Implementation of the mandatory version entry point * @@ -369,11 +414,10 @@ void parse_rule(char* rule, FW_INSTANCE* instance) instance->networks = rng; } - }else{ /**Add usernames or columns*/ + }else{ /**Add rules on usernames or columns*/ char *tok = strtok(ptr," ,\0"); - ITEM* prev = NULL; - bool is_user = false, is_column = false; + bool is_user = false, is_column = false, is_time = false; if(strcmp(tok,"wildcard") == 0){ instance->block_wildcard = block ? true : false; @@ -381,11 +425,9 @@ void parse_rule(char* rule, FW_INSTANCE* instance) } if(strcmp(tok,"users") == 0){/**Adding users*/ - prev = instance->users; instance->whitelist_users = mode; is_user = true; }else if(strcmp(tok,"columns") == 0){/**Adding Columns*/ - prev = instance->columns; is_column = true; } @@ -393,21 +435,17 @@ void parse_rule(char* rule, FW_INSTANCE* instance) if(is_user || is_column){ while(tok){ - - ITEM* item = calloc(1,sizeof(ITEM)); - if(item){ - item->next = prev; - item->value = strdup(tok); - prev = item; - } + + /**Add value to hashtable*/ + + ruletype_t rtype = is_user ? RT_USER : is_column ? RT_COLUMN: RT_UNDEFINED; + hashtable_add(instance->htable, + (void *)tok, + (void *)rtype); + tok = strtok(NULL," ,\0"); } - if(is_user){ - instance->users = prev; - }else if(is_column){ - instance->columns = prev; - } } } @@ -458,7 +496,18 @@ createInstance(char **options, FILTER_PARAMETER **params) return NULL; } int i; + HASHTABLE* ht; + + if((ht = hashtable_alloc(7, hashkeyfun, hashcmpfun)) == NULL){ + skygw_log_write(LOGFILE_ERROR, "Unable to allocate hashtable."); + return NULL; + } + + hashtable_memory_fns(ht,hstrdup,NULL,hfree,NULL); + + my_instance->htable = ht; my_instance->def_op = true; + for(i = 0;params[i];i++){ if(strstr(params[i]->name,"rule")){ parse_rule(strip_tags(params[i]->value),my_instance); @@ -584,23 +633,22 @@ routeQuery(FILTER *instance, void *session, GWBUF *queue) FW_SESSION *my_session = (FW_SESSION *)session; FW_INSTANCE *my_instance = (FW_INSTANCE *)instance; IPRANGE* ipranges = my_instance->networks; - ITEM *users = my_instance->users, *columns = my_instance->columns; bool accept = false, match = false; - char *where,*query; + char *where; uint32_t ip; - int len; + ruletype_t rtype = RT_UNDEFINED; + skygw_query_op_t queryop; DCB* dcb = my_session->session->client; ip = strtoip(dcb->remote); - - while(users){ - if(strcmp(dcb->user,users->value)==0){ - match = true; - accept = my_instance->whitelist_users; - skygw_log_write(LOGFILE_TRACE, "%s@%s was %s.", - dcb->user,dcb->remote,(my_instance->whitelist_users ? "allowed":"blocked")); - break; - } - users = users->next; + + rtype = (ruletype_t)hashtable_fetch(my_instance->htable, dcb->user); + if(rtype == RT_USER){ + match = true; + accept = my_instance->whitelist_users; + skygw_log_write(LOGFILE_TRACE, "Firewall: %s@%s was %s.", + dcb->user, dcb->remote, + (my_instance->whitelist_users ? + "allowed":"blocked")); } if(!match){ @@ -608,7 +656,7 @@ routeQuery(FILTER *instance, void *session, GWBUF *queue) if(ip >= ipranges->ip && ip <= ipranges->ip + ipranges->mask){ match = true; accept = my_instance->whitelist_networks; - skygw_log_write(LOGFILE_TRACE, "%s@%s was %s.", + skygw_log_write(LOGFILE_TRACE, "Firewall: %s@%s was %s.", dcb->user,dcb->remote,(my_instance->whitelist_networks ? "allowed":"blocked")); break; } @@ -626,11 +674,11 @@ routeQuery(FILTER *instance, void *session, GWBUF *queue) if(skygw_is_real_query(queue)){ match = false; - modutil_extract_SQL(queue, &query, &len); - where = skygw_get_where_clause(queue); - skygw_query_op_t queryop = query_classifier_get_operation(queue); + + if(!skygw_query_has_clause(queue)){ + + queryop = query_classifier_get_operation(queue); - if(where == NULL){ if(my_instance->require_where[ALL] || (my_instance->require_where[SELECT] && queryop == QUERY_OP_SELECT) || (my_instance->require_where[UPDATE] && queryop == QUERY_OP_UPDATE) || @@ -638,37 +686,39 @@ routeQuery(FILTER *instance, void *session, GWBUF *queue) (my_instance->require_where[DELETE] && queryop == QUERY_OP_DELETE)){ match = true; accept = false; - skygw_log_write(LOGFILE_TRACE, "query does not have a where clause, blocking it: %.*s",len,query); - + skygw_log_write(LOGFILE_TRACE, "Firewall: query does not have a where clause or a having clause, blocking it: %.*s",GWBUF_LENGTH(queue),(char*)(queue->start + 5)); } } - if(!match && - my_instance->block_wildcard && - ((where && strchr(where,'*') != NULL) || - (memchr(query,'*',len) != NULL))){ - match = true; - accept = false; - skygw_log_write(LOGFILE_TRACE, "query contains wildcard, blocking it: %.*s",len,query); - } - if(!match){ - if(where == NULL){ - where = malloc(sizeof(char)*len+1); - memcpy(where,query,len); - memset(where+len,0,1); - } - while(columns){ - if(strstr(where,columns->value)){ + + where = skygw_get_affected_fields(queue); + + if(my_instance->block_wildcard && + where && strchr(where,'*') != NULL) + { match = true; accept = false; - skygw_log_write(LOGFILE_TRACE, "query contains a forbidden column %s, blocking it: %.*s",columns->value,len,query); - break; + skygw_log_write(LOGFILE_TRACE, "Firewall: query contains wildcard, blocking it: %.*s",GWBUF_LENGTH(queue),(char*)(queue->start + 5)); } - columns = columns->next; - } + else if(where) + { + char* tok = strtok(where," "); + + while(tok){ + rtype = (ruletype_t)hashtable_fetch(my_instance->htable, tok); + if(rtype == RT_COLUMN){ + match = true; + accept = false; + skygw_log_write(LOGFILE_TRACE, "Firewall: query contains a forbidden column %s, blocking it: %.*s",tok,GWBUF_LENGTH(queue),(char*)(queue->start + 5)); + } + tok = strtok(NULL," "); + } + + } + free(where); } - free(where); + } }