added checking for where clause on queries

This commit is contained in:
Markus Makela
2014-10-13 13:48:07 +03:00
parent 75117f2482
commit 9abe270da8
3 changed files with 200 additions and 124 deletions

View File

@ -1449,7 +1449,7 @@ static void parsing_info_set_plain_str(
* @return string representing the query type value * @return string representing the query type value
*/ */
char* skygw_get_qtype_str( char* skygw_get_qtype_str(
skygw_query_type_t qtype) skygw_query_type_t qtype)
{ {
int t1 = (int)qtype; int t1 = (int)qtype;
int t2 = 1; int t2 = 1;
@ -1461,26 +1461,72 @@ char* skygw_get_qtype_str(
* t1 is completely cleared. * t1 is completely cleared.
*/ */
while (t1 != 0) while (t1 != 0)
{ {
if (t1&t2) if (t1&t2)
{ {
t = (skygw_query_type_t)t2; t = (skygw_query_type_t)t2;
if (qtype_str == NULL) if (qtype_str == NULL)
{ {
qtype_str = strdup(STRQTYPE(t)); qtype_str = strdup(STRQTYPE(t));
} }
else else
{ {
size_t len = strlen(STRQTYPE(t)); size_t len = strlen(STRQTYPE(t));
/** reallocate space for delimiter, new string and termination */ /** reallocate space for delimiter, new string and termination */
qtype_str = (char *)realloc(qtype_str, strlen(qtype_str)+1+len+1); qtype_str = (char *)realloc(qtype_str, strlen(qtype_str)+1+len+1);
snprintf(qtype_str+strlen(qtype_str), 1+len+1, "|%s", STRQTYPE(t)); snprintf(qtype_str+strlen(qtype_str), 1+len+1, "|%s", STRQTYPE(t));
} }
/** Remove found value from t1 */ /** Remove found value from t1 */
t1 &= ~t2; t1 &= ~t2;
}
t2 <<= 1;
} }
t2 <<= 1;
}
return qtype_str; return qtype_str;
} }
skygw_query_op_t query_classifier_get_operation(GWBUF* querybuf)
{
LEX* lex = get_lex(querybuf);
skygw_query_op_t operation;
if(lex){
switch(lex->sql_command){
case SQLCOM_SELECT:
operation = QUERY_OP_SELECT;
break;
case SQLCOM_CREATE_TABLE:
operation = QUERY_OP_CREATE_TABLE;
break;
case SQLCOM_CREATE_INDEX:
operation = QUERY_OP_CREATE_INDEX;
break;
case SQLCOM_ALTER_TABLE:
operation = QUERY_OP_ALTER_TABLE;
break;
case SQLCOM_UPDATE:
operation = QUERY_OP_UPDATE;
break;
case SQLCOM_INSERT:
operation = QUERY_OP_INSERT;
break;
case SQLCOM_INSERT_SELECT:
operation = QUERY_OP_INSERT_SELECT;
break;
case SQLCOM_DELETE:
operation = QUERY_OP_DELETE;
break;
case SQLCOM_TRUNCATE:
operation = QUERY_OP_TRUNCATE;
break;
case SQLCOM_DROP_TABLE:
operation = QUERY_OP_DROP_TABLE;
break;
case SQLCOM_DROP_INDEX:
operation = QUERY_OP_DROP_INDEX;
break;
default:
operation = QUERY_OP_UNDEFINED;
}
}
return operation;
}

View File

@ -60,6 +60,12 @@ typedef enum {
QUERY_TYPE_SHOW_TABLES = 0x400000 /*< Show list of tables */ QUERY_TYPE_SHOW_TABLES = 0x400000 /*< Show list of tables */
} skygw_query_type_t; } skygw_query_type_t;
typedef enum {
QUERY_OP_UNDEFINED, QUERY_OP_SELECT, QUERY_OP_CREATE_TABLE, QUERY_OP_CREATE_INDEX,
QUERY_OP_ALTER_TABLE, QUERY_OP_UPDATE, QUERY_OP_INSERT, QUERY_OP_INSERT_SELECT,
QUERY_OP_DELETE, QUERY_OP_TRUNCATE, QUERY_OP_DROP_TABLE, QUERY_OP_DROP_INDEX,
}skygw_query_op_t;
typedef struct parsing_info_st { typedef struct parsing_info_st {
#if defined(SS_DEBUG) #if defined(SS_DEBUG)
@ -81,7 +87,7 @@ typedef struct parsing_info_st {
* classify the query. * classify the query.
*/ */
skygw_query_type_t query_classifier_get_type(GWBUF* querybuf); skygw_query_type_t query_classifier_get_type(GWBUF* querybuf);
skygw_query_op_t query_classifier_get_operation(GWBUF* querybuf);
/** Free THD context and close MYSQL */ /** Free THD context and close MYSQL */
char* skygw_query_classifier_get_stmtname(MYSQL* mysql); char* skygw_query_classifier_get_stmtname(MYSQL* mysql);
char* skygw_get_created_table_name(GWBUF* querybuf); char* skygw_get_created_table_name(GWBUF* querybuf);

View File

@ -22,6 +22,9 @@
* *
* A filter that acts as a firewall, blocking queries that do not meet the set requirements. * A filter that acts as a firewall, blocking queries that do not meet the set requirements.
*/ */
#include <my_config.h>
#include <stdint.h>
#include <ctype.h>
#include <stdio.h> #include <stdio.h>
#include <fcntl.h> #include <fcntl.h>
#include <filter.h> #include <filter.h>
@ -30,12 +33,12 @@
#include <modutil.h> #include <modutil.h>
#include <log_manager.h> #include <log_manager.h>
#include <query_classifier.h> #include <query_classifier.h>
#include <mysql_client_server_protocol.h>
#include <spinlock.h> #include <spinlock.h>
#include <session.h> #include <session.h>
#include <plugin.h> #include <plugin.h>
#include <stdint.h>
#include <skygw_types.h> #include <skygw_types.h>
#include <ctype.h>
MODULE_INFO info = { MODULE_INFO info = {
MODULE_API_FILTER, MODULE_API_FILTER,
@ -44,6 +47,89 @@ MODULE_INFO info = {
"Firewall Filter" "Firewall Filter"
}; };
static char *version_str = "V1.0.0";
/*
* The filter entry points
*/
static FILTER *createInstance(char **options, FILTER_PARAMETER **);
static void *newSession(FILTER *instance, SESSION *session);
static void closeSession(FILTER *instance, void *session);
static void freeSession(FILTER *instance, void *session);
static void setDownstream(FILTER *instance, void *fsession, DOWNSTREAM *downstream);
static int routeQuery(FILTER *instance, void *fsession, GWBUF *queue);
static void diagnostic(FILTER *instance, void *fsession, DCB *dcb);
static FILTER_OBJECT MyObject = {
createInstance,
newSession,
closeSession,
freeSession,
setDownstream,
NULL,
routeQuery,
NULL,
diagnostic,
};
/**
* Query types
*/
#define QUERY_TYPES 5
enum querytype_t{
ALL,
SELECT,
INSERT,
UPDATE,
DELETE
};
/**
* Generic linked list of string values
*/
typedef struct item_t{
struct item_t* next;
char* value;
}ITEM;
/**
* A link in a list of IP adresses and subnet masks
*/
typedef struct iprange_t{
struct iprange_t* next;
uint32_t ip;
uint32_t mask;
}IPRANGE;
/**
* The Firewall filter instance.
*/
typedef struct {
ITEM* columns;
ITEM* users;
IPRANGE* networks;
int column_count, column_size, user_count, user_size;
bool require_where[QUERY_TYPES];
bool block_wildcard, whitelist_users,whitelist_networks,def_op;
} FW_INSTANCE;
/**
* The session structure for Firewall filter.
*/
typedef struct {
DOWNSTREAM down;
UPSTREAM up;
SESSION* session;
} FW_SESSION;
/** /**
* Utility function to check if a string contains a valid IP address. * Utility function to check if a string contains a valid IP address.
* The string handled as a null-terminated string. * The string handled as a null-terminated string.
@ -221,85 +307,6 @@ uint32_t strtosubmask(char* str)
return ~mask; return ~mask;
} }
static char *version_str = "V1.0.0";
/*
* The filter entry points
*/
static FILTER *createInstance(char **options, FILTER_PARAMETER **);
static void *newSession(FILTER *instance, SESSION *session);
static void closeSession(FILTER *instance, void *session);
static void freeSession(FILTER *instance, void *session);
static void setDownstream(FILTER *instance, void *fsession, DOWNSTREAM *downstream);
static int routeQuery(FILTER *instance, void *fsession, GWBUF *queue);
static void diagnostic(FILTER *instance, void *fsession, DCB *dcb);
static FILTER_OBJECT MyObject = {
createInstance,
newSession,
closeSession,
freeSession,
setDownstream,
NULL,
routeQuery,
NULL,
diagnostic,
};
/**
* Query types
*/
enum querytype_t{
ALL,
SELECT,
INSERT,
UPDATE,
DELETE
};
/**
* Generic linked list of string values
*/
typedef struct item_t{
struct item_t* next;
char* value;
}ITEM;
/**
* A link in a list of IP adresses and subnet masks
*/
typedef struct iprange_t{
struct iprange_t* next;
uint32_t ip;
uint32_t mask;
}IPRANGE;
/**
* The Firewall filter instance.
*/
typedef struct {
ITEM* columns;
ITEM* users;
IPRANGE* networks;
int column_count, column_size, user_count, user_size;
bool require_where[QUERY_TYPES];
bool block_wildcard, whitelist_users,whitelist_networks;
} FW_INSTANCE;
/**
* The session structure for Firewall filter.
*/
typedef struct {
DOWNSTREAM down;
UPSTREAM up;
SESSION* session;
} FW_SESSION;
/** /**
* Implementation of the mandatory version entry point * Implementation of the mandatory version entry point
@ -451,6 +458,7 @@ createInstance(char **options, FILTER_PARAMETER **params)
return NULL; return NULL;
} }
int i; int i;
my_instance->def_op = true;
for(i = 0;params[i];i++){ for(i = 0;params[i];i++){
if(strstr(params[i]->name,"rule")){ if(strstr(params[i]->name,"rule")){
parse_rule(strip_tags(params[i]->value),my_instance); parse_rule(strip_tags(params[i]->value),my_instance);
@ -526,31 +534,26 @@ setDownstream(FILTER *instance, void *session, DOWNSTREAM *downstream)
my_session->down = *downstream; my_session->down = *downstream;
} }
/**
* Checks if the packet contains an empty query error
* and if the session blocked the last query
* @param buf Buffer to inspect
* @param session Filter session object
* @return true if the error is the right one and the previous query was blocked
*/
bool is_dummy(GWBUF* buf,FW_SESSION* session)
{
return(*((unsigned char*)buf->start + 4) == 0xff &&
*((unsigned char*)buf->start + 5) == 0x29 &&
*((unsigned char*)buf->start + 6) == 0x04);
}
/** /**
* Generates a dummy error packet for the client. * Generates a dummy error packet for the client.
* @return The dummy packet or NULL if an error occurred * @return The dummy packet or NULL if an error occurred
*/ */
GWBUF* gen_dummy_error() GWBUF* gen_dummy_error(FW_SESSION* session)
{ {
GWBUF* buf; GWBUF* buf;
const char* errmsg = "Access denied."; char errmsg[512];
unsigned int errlen = strlen(errmsg), DCB* dcb = session->session->client;
pktlen = errlen + 9; MYSQL_session* mysql_session = (MYSQL_session*)session->session->data;
unsigned int errlen, pktlen;
sprintf(errmsg,"Access denied for user '%s'@'%s' to database '%s' ",
dcb->user,
dcb->remote,
mysql_session->db);
errlen = strlen(errmsg);
pktlen = errlen + 9;
buf = gwbuf_alloc(13 + errlen); buf = gwbuf_alloc(13 + errlen);
if(buf){ if(buf){
strcpy(buf->start + 7,"#HY000"); strcpy(buf->start + 7,"#HY000");
memcpy(buf->start + 13,errmsg,errlen); memcpy(buf->start + 13,errmsg,errlen);
@ -625,14 +628,30 @@ routeQuery(FILTER *instance, void *session, GWBUF *queue)
match = false; match = false;
modutil_extract_SQL(queue, &query, &len); modutil_extract_SQL(queue, &query, &len);
where = skygw_get_where_clause(queue); where = skygw_get_where_clause(queue);
skygw_query_op_t queryop = query_classifier_get_operation(queue);
if(my_instance->block_wildcard &&
if(where == NULL){
if(my_instance->require_where[ALL] ||
(my_instance->require_where[SELECT] && queryop == QUERY_OP_SELECT) ||
(my_instance->require_where[UPADTE] && queryop == QUERY_OP_UPDATE) ||
(my_instance->require_where[INSERT] && queryop == QUERY_OP_INSERT) ||
(my_instance->require_where[DELETE] && queryop == QUERY_OP_DELETE)){
match = true;
accept = false;
skygw_log_write(LOGFILE_TRACE, "query does not have a where clause, blocking it: %.*s",len,query);
}
}
if(!match &&
my_instance->block_wildcard &&
((where && strchr(where,'*') != NULL) || ((where && strchr(where,'*') != NULL) ||
(memchr(query,'*',len) != NULL))){ (memchr(query,'*',len) != NULL))){
match = true; match = true;
accept = false; accept = false;
skygw_log_write(LOGFILE_TRACE, "query contains wildcard, blocking it: %.*s",len,query); skygw_log_write(LOGFILE_TRACE, "query contains wildcard, blocking it: %.*s",len,query);
} }
if(!match){ if(!match){
if(where == NULL){ if(where == NULL){
where = malloc(sizeof(char)*len+1); where = malloc(sizeof(char)*len+1);
@ -653,6 +672,11 @@ routeQuery(FILTER *instance, void *session, GWBUF *queue)
} }
} }
/**If no rules matched, do the default operation. (allow by default)*/
if(!match){
accept = my_instance->def_op;
}
if(accept){ if(accept){
return my_session->down.routeQuery(my_session->down.instance, return my_session->down.routeQuery(my_session->down.instance,
@ -660,7 +684,7 @@ routeQuery(FILTER *instance, void *session, GWBUF *queue)
}else{ }else{
gwbuf_free(queue); gwbuf_free(queue);
GWBUF* forward = gen_dummy_error(); GWBUF* forward = gen_dummy_error(my_session);
dcb->func.write(dcb,forward); dcb->func.write(dcb,forward);
//gwbuf_free(forward); //gwbuf_free(forward);
return 0; return 0;