diff --git a/include/maxscale/session.h b/include/maxscale/session.h index 162ccec97..03f5f4963 100644 --- a/include/maxscale/session.h +++ b/include/maxscale/session.h @@ -416,10 +416,11 @@ void session_clear_stmt(MXS_SESSION *session); /** * Try to kill a specific session. This function only sends messages to * worker threads without waiting for the result. + * * @param issuer The session where the command originates. * @param target_id Target session id. */ -void session_broadcast_kill_command(MXS_SESSION* issuer, uint32_t target_id); +void session_broadcast_kill_command(MXS_SESSION* issuer, uint64_t target_id); /** * @brief Convert a session to JSON diff --git a/include/maxscale/worker.h b/include/maxscale/worker.h index f323e0912..711c5fbe4 100644 --- a/include/maxscale/worker.h +++ b/include/maxscale/worker.h @@ -117,25 +117,29 @@ bool mxs_worker_post_message(MXS_WORKER* worker, uint32_t msg_id, intptr_t arg1, size_t mxs_worker_broadcast_message(uint32_t msg_id, intptr_t arg1, intptr_t arg2); /** - * Add a session to the current worker's session map. - * @param id With which id to add. Typically session->ses_id. + * Add a session to the current worker's session container. Currently only + * required for some special commands e.g. "KILL " to work. + * * @param session Session to add. * @return true if successful, false if id already existed in map. */ -bool mxs_add_to_session_map(uint32_t id, MXS_SESSION* session); +bool mxs_worker_register_session(MXS_SESSION* session); /** - * Remove a session from the current worker's session map. + * Remove a session from the current worker's session container. Does not actually + * remove anything from an epoll-set or affect the session in any way. + * * @param id Which id to remove. * @return The removed session or NULL if not found. */ -MXS_SESSION* mxs_remove_from_session_map(uint32_t id); +MXS_SESSION* mxs_worker_deregister_session(uint64_t id); /** - * Find a session in the current worker's session map. + * Find a session in the current worker's session container. + * * @param id Which id to find. * @return The found session or NULL if not found. */ -MXS_SESSION* mxs_find_in_session_map(uint32_t id); +MXS_SESSION* mxs_worker_find_session(uint64_t id); MXS_END_DECLS diff --git a/server/core/maxscale/worker.hh b/server/core/maxscale/worker.hh index c3060fc15..900b5f17c 100644 --- a/server/core/maxscale/worker.hh +++ b/server/core/maxscale/worker.hh @@ -69,7 +69,7 @@ public: typedef WORKER_STATISTICS STATISTICS; typedef WorkerTask Task; typedef WorkerDisposableTask DisposableTask; - typedef std::tr1::unordered_map SessionsById; + typedef std::tr1::unordered_map SessionsById; enum state_t { @@ -387,29 +387,28 @@ public: bool post_message(uint32_t msg_id, intptr_t arg1, intptr_t arg2); /** - * Add a session to the sessions hashmap + * Add a session to the session container. * - * @param id Session id, must be unique * @param session The session to add * @return true if successful */ - bool add_to_session_map(SessionsById::key_type id, SessionsById::mapped_type session); + bool register_session(MXS_SESSION* session); /** - * Remove a session from the sessions hashmap + * Remove a session from the session container. * * @param id Session id * @return The removed session, or NULL if not found */ - SessionsById::mapped_type remove_from_session_map(SessionsById::key_type id); + MXS_SESSION* deregister_session(uint64_t id); /** - * Find a session in the sessions hashmap + * Find a session in the session container. * * @param id Session id * @return The found session, or NULL if not found */ - SessionsById::mapped_type find_in_session_map(SessionsById::key_type id); + MXS_SESSION* find_session(uint64_t id); /** * Broadcast a message to all worker. diff --git a/server/core/session.cc b/server/core/session.cc index ec6937d83..26e803452 100644 --- a/server/core/session.cc +++ b/server/core/session.cc @@ -73,35 +73,49 @@ static MXS_SESSION* session_alloc_body(SERVICE* service, DCB* client_dcb, namespace { +/** + * Checks if issuer_user@issuer_host has the privilege to kill the target session. + * Currently just checks that the user and host are the same. + * + * This function should only be called in the worker thread normally handling + * the target session, otherwise target session could be freed while function is + * running. + * + * @param issuer_user User name of command issuer + * @param issuer_host Host/ip of command issuer + * @param target Target session + * @return + */ +bool issuer_can_kill_target(const string& issuer_user, const string& issuer_host, + const MXS_SESSION* target) +{ + DCB* target_dcb = target->client_dcb; + return ((strcmp(issuer_user.c_str(), target_dcb->user) == 0) && + (strcmp(issuer_host.c_str(), target_dcb->remote) == 0)); +} class KillCmdTask : public maxscale::Worker::DisposableTask { -private: - std::string m_issuer_username; - std::string m_issuer_host; - uint64_t m_target_id; - public: KillCmdTask(MXS_SESSION* issuer, uint64_t target_id) + : m_issuer_user(issuer->client_dcb->user) + , m_issuer_host(issuer->client_dcb->remote) + , m_target_id(target_id) { - DCB* issuer_dcb = issuer->client_dcb; - m_issuer_username.assign(issuer_dcb->user); - m_issuer_host.assign(issuer_dcb->remote); - m_target_id = target_id; } + void execute(maxscale::Worker& worker) { - MXS_SESSION* target = worker.find_in_session_map(m_target_id); - if (target) + MXS_SESSION* target = worker.find_session(m_target_id); + if (target && issuer_can_kill_target(m_issuer_user, m_issuer_host, target)) { - DCB* target_dcb = target->client_dcb; - if ((strcmp(m_issuer_username.c_str(), target_dcb->user) == 0) && - (strcmp(m_issuer_host.c_str(), target_dcb->remote) == 0)) - { - poll_fake_hangup_event(target_dcb); - } + poll_fake_hangup_event(target->client_dcb); } } +private: + std::string m_issuer_user; + std::string m_issuer_host; + uint64_t m_target_id; }; } @@ -1036,26 +1050,28 @@ uint32_t session_get_next_id() return atomic_add_uint32(&next_session_id, 1); } -void session_broadcast_kill_command(MXS_SESSION* issuer, uint32_t target_id) +void session_broadcast_kill_command(MXS_SESSION* issuer, uint64_t target_id) { /* First, check if the target id belongs to the current worker. If it does, * send hangup event. Otherwise, use a worker task to send a message to all * workers. */ - MXS_SESSION* target_ses = mxs_find_in_session_map(target_id); - if (target_ses) + MXS_SESSION* target = mxs_worker_find_session(target_id); + if (target && + issuer_can_kill_target(issuer->client_dcb->user, + issuer->client_dcb->remote, + target)) { - if ((strcmp(issuer->client_dcb->user, target_ses->client_dcb->user) == 0) && - (strcmp(issuer->client_dcb->remote, target_ses->client_dcb->remote) == 0)) - { - poll_fake_hangup_event(target_ses->client_dcb); - } + poll_fake_hangup_event(target->client_dcb); } else { - KillCmdTask* kill_task = new KillCmdTask(issuer, target_id); - std::auto_ptr sTask(kill_task); - maxscale::Worker::broadcast(sTask); + KillCmdTask* kill_task = new (std::nothrow) KillCmdTask(issuer, target_id); + if (kill_task) + { + std::auto_ptr sKillTask(kill_task); + maxscale::Worker::broadcast(sKillTask); + } } } diff --git a/server/core/worker.cc b/server/core/worker.cc index e551c2774..9112a65ed 100644 --- a/server/core/worker.cc +++ b/server/core/worker.cc @@ -712,47 +712,43 @@ size_t mxs_worker_broadcast_message(uint32_t msg_id, intptr_t arg1, intptr_t arg return Worker::broadcast_message(msg_id, arg1, arg2); } -bool mxs_add_to_session_map(uint32_t id, MXS_SESSION* session) +bool mxs_worker_register_session(MXS_SESSION* session) { - bool rval = false; Worker* worker = Worker::get_current(); - if (worker) - { - rval = worker->add_to_session_map(id, session); - } - return rval; + ss_dassert(worker); + return worker->register_session(session); } -MXS_SESSION* mxs_remove_from_session_map(uint32_t id) +MXS_SESSION* mxs_worker_deregister_session(uint64_t id) { MXS_SESSION* rval = NULL; Worker* worker = Worker::get_current(); if (worker) { - rval = worker->remove_from_session_map(id); + rval = worker->deregister_session(id); } return rval; } -MXS_SESSION* mxs_find_in_session_map(uint32_t id) +MXS_SESSION* mxs_worker_find_session(uint64_t id) { MXS_SESSION* rval = NULL; Worker* worker = Worker::get_current(); if (worker) { - rval = worker->find_in_session_map(id); + rval = worker->find_session(id); } return rval; } -bool Worker::add_to_session_map(SessionsById::key_type id, SessionsById::mapped_type session) +bool Worker::register_session(MXS_SESSION* session) { - return m_sessions.insert(SessionsById::value_type(id, session)).second; + return m_sessions.insert(SessionsById::value_type(session->ses_id, session)).second; } -Worker::SessionsById::mapped_type Worker::remove_from_session_map(SessionsById::key_type id) +MXS_SESSION* Worker::deregister_session(uint64_t id) { - Worker::SessionsById::mapped_type rval = find_in_session_map(id); + MXS_SESSION* rval = find_session(id); if (rval) { m_sessions.erase(id); @@ -760,9 +756,9 @@ Worker::SessionsById::mapped_type Worker::remove_from_session_map(SessionsById:: return rval; } -Worker::SessionsById::mapped_type Worker::find_in_session_map(SessionsById::key_type id) +MXS_SESSION* Worker::find_session(uint64_t id) { - Worker::SessionsById::mapped_type rval = NULL; + MXS_SESSION* rval = NULL; SessionsById::const_iterator iter = m_sessions.find(id); if (iter != m_sessions.end()) { diff --git a/server/modules/protocol/MySQL/MySQLClient/mysql_client.c b/server/modules/protocol/MySQL/MySQLClient/mysql_client.c index ad47e712d..fa684b277 100644 --- a/server/modules/protocol/MySQL/MySQLClient/mysql_client.c +++ b/server/modules/protocol/MySQL/MySQLClient/mysql_client.c @@ -668,7 +668,7 @@ gw_read_do_authentication(DCB *dcb, GWBUF *read_buffer, int nbytes_read) * normal data handling function instead of this one. */ MXS_SESSION *session = - session_alloc_with_id(dcb->service, dcb, protocol->tid); + session_alloc_with_id(dcb->service, dcb, protocol->tid); if (session != NULL) { @@ -676,7 +676,7 @@ gw_read_do_authentication(DCB *dcb, GWBUF *read_buffer, int nbytes_read) ss_dassert(session->state != SESSION_STATE_ALLOC && session->state != SESSION_STATE_DUMMY); protocol->protocol_auth_state = MXS_AUTH_STATE_COMPLETE; - ss_debug(bool check =) mxs_add_to_session_map(session->ses_id, session); + ss_debug(bool check = ) mxs_worker_register_session(session); ss_dassert(check); mxs_mysql_send_ok(dcb, next_sequence, 0, NULL); } @@ -1258,7 +1258,7 @@ static int gw_client_close(DCB *dcb) ss_dassert(dcb->protocol); mysql_protocol_done(dcb); MXS_SESSION* target = dcb->session; - ss_debug(MXS_SESSION* removed =) mxs_remove_from_session_map(target->ses_id); + ss_debug(MXS_SESSION* removed = ) mxs_worker_deregister_session(target->ses_id); ss_dassert(removed == target); session_close(target); return 1; @@ -1515,12 +1515,12 @@ static bool process_special_commands(DCB* dcb, GWBUF *read_buffer, int nbytes_re /** * Handle COM_PROCESS_KILL */ - else if((proto->current_command == MYSQL_COM_PROCESS_KILL)) + else if ((proto->current_command == MYSQL_COM_PROCESS_KILL)) { /* Make sure we have a complete SQL packet before trying to read the * process id. If not, try again next time. */ unsigned int expected_len = - MYSQL_GET_PAYLOAD_LEN((uint8_t *)GWBUF_DATA(read_buffer)) + MYSQL_HEADER_LEN; + MYSQL_GET_PAYLOAD_LEN((uint8_t *)GWBUF_DATA(read_buffer)) + MYSQL_HEADER_LEN; if (gwbuf_length(read_buffer) < expected_len) { dcb->dcb_readqueue = read_buffer; @@ -1530,9 +1530,9 @@ static bool process_special_commands(DCB* dcb, GWBUF *read_buffer, int nbytes_re { uint8_t bytes[4]; if (gwbuf_copy_data(read_buffer, MYSQL_HEADER_LEN + 1, sizeof(bytes), (uint8_t*)bytes) - == sizeof(bytes)) + == sizeof(bytes)) { - uint32_t process_id = gw_mysql_get_byte4(bytes); + uint64_t process_id = gw_mysql_get_byte4(bytes); // Do not send this packet for routing gwbuf_free(read_buffer); session_broadcast_kill_command(dcb->session, process_id);