diff --git a/server/modules/filter/tee.c b/server/modules/filter/tee.c index c15cc9420..f4ee04685 100644 --- a/server/modules/filter/tee.c +++ b/server/modules/filter/tee.c @@ -107,7 +107,6 @@ static void setDownstream(FILTER *instance, void *fsession, DOWNSTREAM *downstre static int routeQuery(FILTER *instance, void *fsession, GWBUF *queue); static void diagnostic(FILTER *instance, void *fsession, DCB *dcb); - static FILTER_OBJECT MyObject = { createInstance, newSession, @@ -153,6 +152,33 @@ typedef struct { } TEE_SESSION; static int packet_is_required(GWBUF *queue); +static int detect_loops(TEE_INSTANCE *instance, HASHTABLE* ht, SERVICE* session); + +static int hkfn( + void* key) +{ + if(key == NULL){ + return 0; + } + unsigned int hash = 0,c = 0; + char* ptr = (char*)key; + while((c = *ptr++)){ + hash = c + (hash << 6) + (hash << 16) - hash; + } + return *(int *)key; +} + +static int hcfn( + void* v1, + void* v2) +{ + char* i1 = (char*) v1; + char* i2 = (char*) v2; + + return strcmp(i1,i2); +} + + /** * Implementation of the mandatory version entry point * @@ -313,7 +339,20 @@ char *remote, *userName; my_session = NULL; goto retblock; } - + + HASHTABLE* ht = hashtable_alloc(100,hkfn,hcfn); + bool is_loop = detect_loops(my_instance,ht,session->service); + hashtable_free(ht); + + if(is_loop) + { + LOGIF(LE, (skygw_log_write_flush(LOGFILE_ERROR, + "Error : %s: Recursive use of tee filter in service.", + session->service->name))); + my_session = NULL; + goto retblock; + } + if ((my_session = calloc(1, sizeof(TEE_SESSION))) != NULL) { my_session->active = 1; @@ -373,6 +412,7 @@ char *remote, *userName; goto retblock; } + ses->ses_is_child = true; my_session->branch_session = ses; my_session->branch_dcb = dcb; @@ -641,3 +681,53 @@ int i; return 1; return 0; } + +/** + * Detects possible loops in the query cloning chain. + */ +int detect_loops(TEE_INSTANCE *instance,HASHTABLE* ht, SERVICE* service) +{ + SERVICE* svc = service; + bool recurse = true; + int i; + + if(ht == NULL) + { + return -1; + } + + if(hashtable_add(ht,(void*)service->name,(void*)true) == 0) + { + return true; + } + + for(i = 0;in_filters;i++) + { + if(strcmp(svc->filters[i]->module,"tee") == 0) + { + /* + * Found a Tee filter, recurse down its path + * if the service name isn't already in the hashtable. + */ + + TEE_INSTANCE* ninst = (TEE_INSTANCE*)svc->filters[i]->filter; + if(ninst == NULL) + { + /** + * This tee instance hasn't been initialized yet and full + * resolution of recursion cannot be done now. + */ + continue; + } + SERVICE* tgt = ninst->service; + + if(detect_loops((TEE_INSTANCE*)svc->filters[i]->filter,ht,tgt)) + { + return true; + } + + } + } + + return false; +}