diff --git a/include/maxscale/dcb.h b/include/maxscale/dcb.h index 578a74e55..a6e08a1d0 100644 --- a/include/maxscale/dcb.h +++ b/include/maxscale/dcb.h @@ -367,21 +367,15 @@ static inline void dcb_readq_set(DCB *dcb, GWBUF *buffer) bool dcb_foreach(bool (*func)(DCB *dcb, void *data), void *data); /** - * @brief Call a function for each connected DCB + * @brief Call a function for each connected DCB on the current worker * - * @note This function can call @c func from multiple thread at one time. + * @param func Function to call. The function should return @c true to continue + * iteration and @c false to stop iteration earlier. The first parameter + * is the current DCB. * - * @param func Function to call. The function should return @c true to continue iteration - * and @c false to stop iteration earlier. The first is a DCB and - * the second is this thread's value in the @c data array that - * the user provided. - * - * @param data Array of user provided data passed as the second parameter to @c func. - * The array must have more space for pointers thann the return - * value of `config_threadcount()`. The value passed to @c func will - * be the value of the array at the index of the current thread's ID. + * @param data User provided data passed as the second parameter to @c func */ -void dcb_foreach_parallel(bool (*func)(DCB *dcb, void *data), void **data); +void dcb_foreach_local(bool (*func)(DCB *dcb, void *data), void *data); /** * @brief Return the port number this DCB is connected to diff --git a/include/maxscale/protocol/mariadb_client.hh b/include/maxscale/protocol/mariadb_client.hh index 161e89a59..9e28d3b18 100644 --- a/include/maxscale/protocol/mariadb_client.hh +++ b/include/maxscale/protocol/mariadb_client.hh @@ -37,8 +37,8 @@ public: * * @return New virtual client or NULL on error */ - static LocalClient* create(MXS_SESSION* session, SERVICE* service); - static LocalClient* create(MXS_SESSION* session, SERVER* server); + static LocalClient* create(MYSQL_session* session, MySQLProtocol* proto, SERVICE* service); + static LocalClient* create(MYSQL_session* session, MySQLProtocol* proto, SERVER* server); /** * Queue a new query for execution @@ -57,8 +57,8 @@ public: void self_destruct(); private: - static LocalClient* create(MXS_SESSION* session, const char* ip, uint64_t port); - LocalClient(MXS_SESSION* session, int fd); + static LocalClient* create(MYSQL_session* session, MySQLProtocol* proto, const char* ip, uint64_t port); + LocalClient(MYSQL_session* session, MySQLProtocol* proto, int fd); static uint32_t poll_handler(struct mxs_poll_data* data, int wid, uint32_t events); void process(uint32_t events); GWBUF* read_complete_packet(); diff --git a/server/core/dcb.cc b/server/core/dcb.cc index 86a7e49ab..3036a3098 100644 --- a/server/core/dcb.cc +++ b/server/core/dcb.cc @@ -2943,39 +2943,17 @@ bool dcb_foreach(bool(*func)(DCB *dcb, void *data), void *data) return task.more(); } -/** Helper class for parallel iteration over all DCBs */ -class ParallelDcbTask : public WorkerTask +void dcb_foreach_local(bool(*func)(DCB *dcb, void *data), void *data) { -public: + int thread_id = Worker::get_current_id(); - ParallelDcbTask(bool(*func)(DCB *, void *), void **data): - m_func(func), - m_data(data) + for (DCB *dcb = this_unit.all_dcbs[thread_id]; dcb; dcb = dcb->thread.next) { - } - - void execute(Worker& worker) - { - int thread_id = worker.id(); - - for (DCB *dcb = this_unit.all_dcbs[thread_id]; dcb; dcb = dcb->thread.next) + if (!func(dcb, data)) { - if (!m_func(dcb, m_data[thread_id])) - { - break; - } + break; } } - -private: - bool(*m_func)(DCB *dcb, void *data); - void** m_data; -}; - -void dcb_foreach_parallel(bool(*func)(DCB *dcb, void *data), void **data) -{ - ParallelDcbTask task(func, data); - Worker::execute_concurrently(task); } int dcb_get_port(const DCB *dcb) diff --git a/server/modules/filter/tee/teesession.cc b/server/modules/filter/tee/teesession.cc index 1696f5af4..5f83cce19 100644 --- a/server/modules/filter/tee/teesession.cc +++ b/server/modules/filter/tee/teesession.cc @@ -101,7 +101,9 @@ TeeSession* TeeSession::create(Tee* my_instance, MXS_SESSION* session) return NULL; } - if ((client = LocalClient::create(session, my_instance->get_service())) == NULL) + if ((client = LocalClient::create((MYSQL_session*)session->client_dcb->data, + (MySQLProtocol*)session->client_dcb->protocol, + my_instance->get_service())) == NULL) { return NULL; } diff --git a/server/modules/protocol/MySQL/mariadb_client.cc b/server/modules/protocol/MySQL/mariadb_client.cc index 74c9cddc5..2fbd8180d 100644 --- a/server/modules/protocol/MySQL/mariadb_client.cc +++ b/server/modules/protocol/MySQL/mariadb_client.cc @@ -25,20 +25,15 @@ static const uint32_t poll_events = EPOLLIN | EPOLLOUT | EPOLLET | ERROR_EVENTS; -LocalClient::LocalClient(MXS_SESSION* session, int fd): +LocalClient::LocalClient(MYSQL_session* session, MySQLProtocol* proto, int fd): m_state(VC_WAITING_HANDSHAKE), m_sock(fd), m_expected_bytes(0), - m_client({}), - m_protocol({}), + m_client(*session), + m_protocol(*proto), m_self_destruct(false) { MXS_POLL_DATA::handler = LocalClient::poll_handler; - MySQLProtocol* client = (MySQLProtocol*)session->client_dcb->protocol; - m_protocol.charset = client->charset; - m_protocol.client_capabilities = client->client_capabilities; - m_protocol.extra_capabilities = client->extra_capabilities; - gw_get_shared_session_auth_info(session->client_dcb, &m_client); } LocalClient::~LocalClient() @@ -237,7 +232,7 @@ uint32_t LocalClient::poll_handler(struct mxs_poll_data* data, int wid, uint32_t return 0; } -LocalClient* LocalClient::create(MXS_SESSION* session, const char* ip, uint64_t port) +LocalClient* LocalClient::create(MYSQL_session* session, MySQLProtocol* proto, const char* ip, uint64_t port) { LocalClient* rval = NULL; sockaddr_storage addr; @@ -245,7 +240,7 @@ LocalClient* LocalClient::create(MXS_SESSION* session, const char* ip, uint64_t if (fd > 0 && (connect(fd, (struct sockaddr*)&addr, sizeof(addr)) == 0 || errno == EINPROGRESS)) { - LocalClient* relay = new (std::nothrow) LocalClient(session, fd); + LocalClient* relay = new (std::nothrow) LocalClient(session, proto, fd); if (relay) { @@ -271,7 +266,7 @@ LocalClient* LocalClient::create(MXS_SESSION* session, const char* ip, uint64_t return rval; } -LocalClient* LocalClient::create(MXS_SESSION* session, SERVICE* service) +LocalClient* LocalClient::create(MYSQL_session* session, MySQLProtocol* proto, SERVICE* service) { LocalClient* rval = NULL; LISTENER_ITERATOR iter; @@ -282,7 +277,7 @@ LocalClient* LocalClient::create(MXS_SESSION* session, SERVICE* service) if (listener->port > 0) { /** Pick the first network listener */ - rval = create(session, "127.0.0.1", service->ports->port); + rval = create(session, proto, "127.0.0.1", service->ports->port); break; } } @@ -290,7 +285,7 @@ LocalClient* LocalClient::create(MXS_SESSION* session, SERVICE* service) return rval; } -LocalClient* LocalClient::create(MXS_SESSION* session, SERVER* server) +LocalClient* LocalClient::create(MYSQL_session* session, MySQLProtocol* proto, SERVER* server) { - return create(session, server->name, server->port); + return create(session, proto, server->name, server->port); } diff --git a/server/modules/protocol/MySQL/mysql_common.cc b/server/modules/protocol/MySQL/mysql_common.cc index 282e4cb38..6121f2ef4 100644 --- a/server/modules/protocol/MySQL/mysql_common.cc +++ b/server/modules/protocol/MySQL/mysql_common.cc @@ -19,7 +19,7 @@ #include #include -#include +#include #include #include @@ -30,6 +30,7 @@ #include #include #include +#include uint8_t null_client_sha1[MYSQL_SCRAMBLE_LEN] = ""; @@ -1586,17 +1587,60 @@ bool mxs_mysql_command_will_respond(uint8_t cmd) cmd != MXS_COM_STMT_CLOSE; } -typedef std::vector< std::pair > TargetList; +namespace +{ + +// Servers and queries to execute on them +typedef std::map TargetList; struct KillInfo { - uint64_t target_id; + typedef bool (*DcbCallback)(DCB *dcb, void *data); + + KillInfo(std::string query, MXS_SESSION* ses, DcbCallback callback): + origin(mxs_worker_get_current_id()), + query_base(query), + protocol(*(MySQLProtocol*)ses->client_dcb->protocol), + cb(callback) + { + gw_get_shared_session_auth_info(ses->client_dcb, &session); + } + + int origin; + std::string query_base; + MYSQL_session session; + MySQLProtocol protocol; + DcbCallback cb; TargetList targets; }; +static bool kill_func(DCB *dcb, void *data); + +struct ConnKillInfo: public KillInfo +{ + ConnKillInfo(uint64_t id, std::string query, MXS_SESSION* ses): + KillInfo(query, ses, kill_func), + target_id(id) + {} + + uint64_t target_id; +}; + +static bool kill_user_func(DCB *dcb, void *data); + +struct UserKillInfo: public KillInfo +{ + UserKillInfo(std::string name, std::string query, MXS_SESSION* ses): + KillInfo(query, ses, kill_user_func), + user(name) + {} + + std::string user; +}; + static bool kill_func(DCB *dcb, void *data) { - KillInfo* info = (KillInfo*)data; + ConnKillInfo* info = static_cast(data); if (dcb->dcb_role == DCB_ROLE_BACKEND_HANDLER && dcb->session->ses_id == info->target_id) @@ -1606,7 +1650,9 @@ static bool kill_func(DCB *dcb, void *data) if (proto->thread_id) { // DCB is connected and we know the thread ID so we can kill it - info->targets.push_back(std::make_pair(dcb->server, proto->thread_id)); + std::stringstream ss; + ss << info->query_base << proto->thread_id; + info->targets[dcb->server] = ss.str(); } else { @@ -1619,82 +1665,29 @@ static bool kill_func(DCB *dcb, void *data) return true; } -void mxs_mysql_execute_kill(MXS_SESSION* issuer, uint64_t target_id, kill_type_t type) -{ - // Gather a list of servers and connection IDs to kill - KillInfo info = {target_id}; - dcb_foreach(kill_func, &info); - - if (info.targets.empty()) - { - // No session found, send an error - std::stringstream err; - err << "Unknown thread id: " << target_id; - mysql_send_standard_error(issuer->client_dcb, 1, 1094, err.str().c_str()); - } - else - { - // Execute the KILL on all of the servers - for (TargetList::iterator it = info.targets.begin(); - it != info.targets.end(); it++) - { - LocalClient* client = LocalClient::create(issuer, it->first); - const char* hard = (type & KT_HARD) ? "HARD " : - (type & KT_SOFT) ? "SOFT " : - ""; - const char* query = (type & KT_QUERY) ? "QUERY " : ""; - std::stringstream ss; - ss << "KILL " << hard << query << it->second; - GWBUF* buffer = modutil_create_query(ss.str().c_str()); - client->queue_query(buffer); - gwbuf_free(buffer); - - // The LocalClient needs to delete itself once the queries are done - client->self_destruct(); - } - mxs_mysql_send_ok(issuer->client_dcb, 1, 0, NULL); - } -} - -typedef std::set ServerSet; - -struct KillUserInfo -{ - std::string user; - ServerSet targets; -}; - - static bool kill_user_func(DCB *dcb, void *data) { - KillUserInfo* info = (KillUserInfo*)data; + UserKillInfo* info = (UserKillInfo*)data; if (dcb->dcb_role == DCB_ROLE_BACKEND_HANDLER && strcasecmp(dcb->session->client_dcb->user, info->user.c_str()) == 0) { - info->targets.insert(dcb->server); + info->targets[dcb->server] = info->query_base; } return true; } -void mxs_mysql_execute_kill_user(MXS_SESSION* issuer, const char* user, kill_type_t type) +static void worker_func(int thread_id, void* data) { - // Gather a list of servers and connection IDs to kill - KillUserInfo info = {user}; - dcb_foreach(kill_user_func, &info); + KillInfo* info = static_cast(data); + dcb_foreach_local(info->cb, info); - // Execute the KILL on all of the servers - for (ServerSet::iterator it = info.targets.begin(); - it != info.targets.end(); it++) + for (TargetList::iterator it = info->targets.begin(); + it != info->targets.end(); it++) { - LocalClient* client = LocalClient::create(issuer, *it); - const char* hard = (type & KT_HARD) ? "HARD " : - (type & KT_SOFT) ? "SOFT " : ""; - const char* query = (type & KT_QUERY) ? "QUERY " : ""; - std::stringstream ss; - ss << "KILL " << hard << query << "USER " << user; - GWBUF* buffer = modutil_create_query(ss.str().c_str()); + LocalClient* client = LocalClient::create(&info->session, &info->protocol, it->first); + GWBUF* buffer = modutil_create_query(it->second.c_str()); client->queue_query(buffer); gwbuf_free(buffer); @@ -1702,5 +1695,43 @@ void mxs_mysql_execute_kill_user(MXS_SESSION* issuer, const char* user, kill_typ client->self_destruct(); } - mxs_mysql_send_ok(issuer->client_dcb, info.targets.size(), 0, NULL); + delete info; +} + +} + +void mxs_mysql_execute_kill(MXS_SESSION* issuer, uint64_t target_id, kill_type_t type) +{ + const char* hard = (type & KT_HARD) ? "HARD " : (type & KT_SOFT) ? "SOFT " : ""; + const char* query = (type & KT_QUERY) ? "QUERY " : ""; + std::stringstream ss; + ss << "KILL " << hard << query; + + for (int i = 0; i < config_threadcount(); i++) + { + MXS_WORKER* worker = mxs_worker_get(i); + ss_dassert(worker); + mxs_worker_post_message(worker, MXS_WORKER_MSG_CALL, (intptr_t)worker_func, + (intptr_t)new ConnKillInfo(target_id, ss.str(), issuer)); + } + + mxs_mysql_send_ok(issuer->client_dcb, 1, 0, NULL); +} + +void mxs_mysql_execute_kill_user(MXS_SESSION* issuer, const char* user, kill_type_t type) +{ + const char* hard = (type & KT_HARD) ? "HARD " : (type & KT_SOFT) ? "SOFT " : ""; + const char* query = (type & KT_QUERY) ? "QUERY " : ""; + std::stringstream ss; + ss << "KILL " << hard << query << "USER " << user; + + for (int i = 0; i < config_threadcount(); i++) + { + MXS_WORKER* worker = mxs_worker_get(i); + ss_dassert(worker); + mxs_worker_post_message(worker, MXS_WORKER_MSG_CALL, (intptr_t)worker_func, + (intptr_t)new UserKillInfo(user, ss.str(), issuer)); + } + + mxs_mysql_send_ok(issuer->client_dcb, 1, 0, NULL); }