diff --git a/query_classifier/query_classifier.cc b/query_classifier/query_classifier.cc index 1ee07135f..5eedc31c0 100644 --- a/query_classifier/query_classifier.cc +++ b/query_classifier/query_classifier.cc @@ -1449,7 +1449,7 @@ static void parsing_info_set_plain_str( * @return string representing the query type value */ char* skygw_get_qtype_str( - skygw_query_type_t qtype) + skygw_query_type_t qtype) { int t1 = (int)qtype; int t2 = 1; @@ -1461,26 +1461,72 @@ char* skygw_get_qtype_str( * t1 is completely cleared. */ while (t1 != 0) - { - if (t1&t2) - { - t = (skygw_query_type_t)t2; + { + if (t1&t2) + { + t = (skygw_query_type_t)t2; - if (qtype_str == NULL) - { - qtype_str = strdup(STRQTYPE(t)); - } - else - { - size_t len = strlen(STRQTYPE(t)); - /** reallocate space for delimiter, new string and termination */ - qtype_str = (char *)realloc(qtype_str, strlen(qtype_str)+1+len+1); - snprintf(qtype_str+strlen(qtype_str), 1+len+1, "|%s", STRQTYPE(t)); - } - /** Remove found value from t1 */ - t1 &= ~t2; + if (qtype_str == NULL) + { + qtype_str = strdup(STRQTYPE(t)); + } + else + { + size_t len = strlen(STRQTYPE(t)); + /** reallocate space for delimiter, new string and termination */ + qtype_str = (char *)realloc(qtype_str, strlen(qtype_str)+1+len+1); + snprintf(qtype_str+strlen(qtype_str), 1+len+1, "|%s", STRQTYPE(t)); + } + /** Remove found value from t1 */ + t1 &= ~t2; + } + t2 <<= 1; } - t2 <<= 1; - } return qtype_str; } +skygw_query_op_t query_classifier_get_operation(GWBUF* querybuf) +{ + LEX* lex = get_lex(querybuf); + skygw_query_op_t operation; + if(lex){ + switch(lex->sql_command){ + case SQLCOM_SELECT: + operation = QUERY_OP_SELECT; + break; + case SQLCOM_CREATE_TABLE: + operation = QUERY_OP_CREATE_TABLE; + break; + case SQLCOM_CREATE_INDEX: + operation = QUERY_OP_CREATE_INDEX; + break; + case SQLCOM_ALTER_TABLE: + operation = QUERY_OP_ALTER_TABLE; + break; + case SQLCOM_UPDATE: + operation = QUERY_OP_UPDATE; + break; + case SQLCOM_INSERT: + operation = QUERY_OP_INSERT; + break; + case SQLCOM_INSERT_SELECT: + operation = QUERY_OP_INSERT_SELECT; + break; + case SQLCOM_DELETE: + operation = QUERY_OP_DELETE; + break; + case SQLCOM_TRUNCATE: + operation = QUERY_OP_TRUNCATE; + break; + case SQLCOM_DROP_TABLE: + operation = QUERY_OP_DROP_TABLE; + break; + case SQLCOM_DROP_INDEX: + operation = QUERY_OP_DROP_INDEX; + break; + + default: + operation = QUERY_OP_UNDEFINED; + } + } + return operation; +} diff --git a/query_classifier/query_classifier.h b/query_classifier/query_classifier.h index 3e6bdf59b..eb0e6d4b9 100644 --- a/query_classifier/query_classifier.h +++ b/query_classifier/query_classifier.h @@ -60,6 +60,12 @@ typedef enum { QUERY_TYPE_SHOW_TABLES = 0x400000 /*< Show list of tables */ } skygw_query_type_t; +typedef enum { + QUERY_OP_UNDEFINED, QUERY_OP_SELECT, QUERY_OP_CREATE_TABLE, QUERY_OP_CREATE_INDEX, + QUERY_OP_ALTER_TABLE, QUERY_OP_UPDATE, QUERY_OP_INSERT, QUERY_OP_INSERT_SELECT, + QUERY_OP_DELETE, QUERY_OP_TRUNCATE, QUERY_OP_DROP_TABLE, QUERY_OP_DROP_INDEX, + +}skygw_query_op_t; typedef struct parsing_info_st { #if defined(SS_DEBUG) @@ -81,7 +87,7 @@ typedef struct parsing_info_st { * classify the query. */ skygw_query_type_t query_classifier_get_type(GWBUF* querybuf); - +skygw_query_op_t query_classifier_get_operation(GWBUF* querybuf); /** Free THD context and close MYSQL */ char* skygw_query_classifier_get_stmtname(MYSQL* mysql); char* skygw_get_created_table_name(GWBUF* querybuf); diff --git a/server/modules/filter/fwfilter.c b/server/modules/filter/fwfilter.c index bb9b4526b..7893b0724 100644 --- a/server/modules/filter/fwfilter.c +++ b/server/modules/filter/fwfilter.c @@ -22,6 +22,9 @@ * * A filter that acts as a firewall, blocking queries that do not meet the set requirements. */ +#include +#include +#include #include #include #include @@ -30,12 +33,12 @@ #include #include #include +#include #include #include #include -#include #include -#include + MODULE_INFO info = { MODULE_API_FILTER, @@ -44,6 +47,89 @@ MODULE_INFO info = { "Firewall Filter" }; +static char *version_str = "V1.0.0"; + +/* + * The filter entry points + */ +static FILTER *createInstance(char **options, FILTER_PARAMETER **); +static void *newSession(FILTER *instance, SESSION *session); +static void closeSession(FILTER *instance, void *session); +static void freeSession(FILTER *instance, void *session); +static void setDownstream(FILTER *instance, void *fsession, DOWNSTREAM *downstream); +static int routeQuery(FILTER *instance, void *fsession, GWBUF *queue); +static void diagnostic(FILTER *instance, void *fsession, DCB *dcb); + + +static FILTER_OBJECT MyObject = { + createInstance, + newSession, + closeSession, + freeSession, + setDownstream, + NULL, + routeQuery, + NULL, + diagnostic, +}; + + + +/** + * Query types + */ + +#define QUERY_TYPES 5 + +enum querytype_t{ + ALL, + SELECT, + INSERT, + UPDATE, + DELETE +}; + +/** + * Generic linked list of string values + */ + +typedef struct item_t{ + struct item_t* next; + char* value; +}ITEM; + +/** + * A link in a list of IP adresses and subnet masks + */ +typedef struct iprange_t{ + struct iprange_t* next; + uint32_t ip; + uint32_t mask; +}IPRANGE; + +/** + * The Firewall filter instance. + */ +typedef struct { + ITEM* columns; + ITEM* users; + IPRANGE* networks; + int column_count, column_size, user_count, user_size; + bool require_where[QUERY_TYPES]; + bool block_wildcard, whitelist_users,whitelist_networks,def_op; + +} FW_INSTANCE; + +/** + * The session structure for Firewall filter. + */ +typedef struct { + DOWNSTREAM down; + UPSTREAM up; + SESSION* session; +} FW_SESSION; + + /** * Utility function to check if a string contains a valid IP address. * The string handled as a null-terminated string. @@ -221,85 +307,6 @@ uint32_t strtosubmask(char* str) return ~mask; } -static char *version_str = "V1.0.0"; - -/* - * The filter entry points - */ -static FILTER *createInstance(char **options, FILTER_PARAMETER **); -static void *newSession(FILTER *instance, SESSION *session); -static void closeSession(FILTER *instance, void *session); -static void freeSession(FILTER *instance, void *session); -static void setDownstream(FILTER *instance, void *fsession, DOWNSTREAM *downstream); -static int routeQuery(FILTER *instance, void *fsession, GWBUF *queue); -static void diagnostic(FILTER *instance, void *fsession, DCB *dcb); - - -static FILTER_OBJECT MyObject = { - createInstance, - newSession, - closeSession, - freeSession, - setDownstream, - NULL, - routeQuery, - NULL, - diagnostic, -}; - - - -/** - * Query types - */ - -enum querytype_t{ - ALL, - SELECT, - INSERT, - UPDATE, - DELETE -}; - -/** - * Generic linked list of string values - */ - -typedef struct item_t{ - struct item_t* next; - char* value; -}ITEM; - -/** - * A link in a list of IP adresses and subnet masks - */ -typedef struct iprange_t{ - struct iprange_t* next; - uint32_t ip; - uint32_t mask; -}IPRANGE; - -/** - * The Firewall filter instance. - */ -typedef struct { - ITEM* columns; - ITEM* users; - IPRANGE* networks; - int column_count, column_size, user_count, user_size; - bool require_where[QUERY_TYPES]; - bool block_wildcard, whitelist_users,whitelist_networks; - -} FW_INSTANCE; - -/** - * The session structure for Firewall filter. - */ -typedef struct { - DOWNSTREAM down; - UPSTREAM up; - SESSION* session; -} FW_SESSION; /** * Implementation of the mandatory version entry point @@ -451,6 +458,7 @@ createInstance(char **options, FILTER_PARAMETER **params) return NULL; } int i; + my_instance->def_op = true; for(i = 0;params[i];i++){ if(strstr(params[i]->name,"rule")){ parse_rule(strip_tags(params[i]->value),my_instance); @@ -526,31 +534,26 @@ setDownstream(FILTER *instance, void *session, DOWNSTREAM *downstream) my_session->down = *downstream; } -/** - * Checks if the packet contains an empty query error - * and if the session blocked the last query - * @param buf Buffer to inspect - * @param session Filter session object - * @return true if the error is the right one and the previous query was blocked - */ -bool is_dummy(GWBUF* buf,FW_SESSION* session) -{ - return(*((unsigned char*)buf->start + 4) == 0xff && - *((unsigned char*)buf->start + 5) == 0x29 && - *((unsigned char*)buf->start + 6) == 0x04); -} - /** * Generates a dummy error packet for the client. * @return The dummy packet or NULL if an error occurred */ -GWBUF* gen_dummy_error() +GWBUF* gen_dummy_error(FW_SESSION* session) { GWBUF* buf; - const char* errmsg = "Access denied."; - unsigned int errlen = strlen(errmsg), - pktlen = errlen + 9; + char errmsg[512]; + DCB* dcb = session->session->client; + MYSQL_session* mysql_session = (MYSQL_session*)session->session->data; + unsigned int errlen, pktlen; + + sprintf(errmsg,"Access denied for user '%s'@'%s' to database '%s' ", + dcb->user, + dcb->remote, + mysql_session->db); + errlen = strlen(errmsg); + pktlen = errlen + 9; buf = gwbuf_alloc(13 + errlen); + if(buf){ strcpy(buf->start + 7,"#HY000"); memcpy(buf->start + 13,errmsg,errlen); @@ -625,14 +628,30 @@ routeQuery(FILTER *instance, void *session, GWBUF *queue) match = false; modutil_extract_SQL(queue, &query, &len); where = skygw_get_where_clause(queue); - - if(my_instance->block_wildcard && + skygw_query_op_t queryop = query_classifier_get_operation(queue); + + if(where == NULL){ + if(my_instance->require_where[ALL] || + (my_instance->require_where[SELECT] && queryop == QUERY_OP_SELECT) || + (my_instance->require_where[UPADTE] && queryop == QUERY_OP_UPDATE) || + (my_instance->require_where[INSERT] && queryop == QUERY_OP_INSERT) || + (my_instance->require_where[DELETE] && queryop == QUERY_OP_DELETE)){ + match = true; + accept = false; + skygw_log_write(LOGFILE_TRACE, "query does not have a where clause, blocking it: %.*s",len,query); + + } + } + + if(!match && + my_instance->block_wildcard && ((where && strchr(where,'*') != NULL) || (memchr(query,'*',len) != NULL))){ match = true; accept = false; skygw_log_write(LOGFILE_TRACE, "query contains wildcard, blocking it: %.*s",len,query); - } + } + if(!match){ if(where == NULL){ where = malloc(sizeof(char)*len+1); @@ -653,6 +672,11 @@ routeQuery(FILTER *instance, void *session, GWBUF *queue) } } + /**If no rules matched, do the default operation. (allow by default)*/ + if(!match){ + accept = my_instance->def_op; + } + if(accept){ return my_session->down.routeQuery(my_session->down.instance, @@ -660,7 +684,7 @@ routeQuery(FILTER *instance, void *session, GWBUF *queue) }else{ gwbuf_free(queue); - GWBUF* forward = gen_dummy_error(); + GWBUF* forward = gen_dummy_error(my_session); dcb->func.write(dcb,forward); //gwbuf_free(forward); return 0;