338 lines
		
	
	
		
			7.8 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			338 lines
		
	
	
		
			7.8 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /*
 | |
|  * Copyright (c) 2016 MariaDB Corporation Ab
 | |
|  *
 | |
|  * Use of this software is governed by the Business Source License included
 | |
|  * in the LICENSE.TXT file and at www.mariadb.com/bsl11.
 | |
|  *
 | |
|  * Change Date: 2024-03-10
 | |
|  *
 | |
|  * On the date above, in accordance with the Business Source License, use
 | |
|  * of this software will be governed by version 2 or later of the General
 | |
|  * Public License.
 | |
|  */
 | |
| 
 | |
| #include <maxscale/protocol/mariadb_client.hh>
 | |
| #include <maxscale/routingworker.hh>
 | |
| #include <maxscale/utils.h>
 | |
| #include <maxbase/host.hh>
 | |
| 
 | |
// Events that signal a broken connection. EPOLLRDHUP (peer closed, or shut down
// the writing half of, the connection) is only added where the headers define it.
#if defined(EPOLLRDHUP)
#define ERROR_EVENTS (EPOLLHUP | EPOLLERR | EPOLLRDHUP)
#else
#define ERROR_EVENTS (EPOLLERR | EPOLLHUP)
#endif

// Mask used when registering a LocalClient socket with the worker:
// edge-triggered read/write readiness plus the error events above.
static const uint32_t poll_events = ERROR_EVENTS | EPOLLET | EPOLLOUT | EPOLLIN;
 | |
| 
 | |
| LocalClient::LocalClient(MYSQL_session* session, MySQLProtocol* proto, int fd)
 | |
|     : m_state(VC_WAITING_HANDSHAKE)
 | |
|     , m_sock(fd)
 | |
|     , m_expected_bytes(0)
 | |
|     , m_client(*session)
 | |
|     , m_protocol(*proto)
 | |
|     , m_self_destruct(false)
 | |
| {
 | |
|     MXB_POLL_DATA::handler = LocalClient::poll_handler;
 | |
|     m_protocol.owner_dcb = nullptr;
 | |
|     m_protocol.stored_query = nullptr;
 | |
| }
 | |
| 
 | |
| LocalClient::~LocalClient()
 | |
| {
 | |
|     if (m_state != VC_ERROR)
 | |
|     {
 | |
|         close();
 | |
|     }
 | |
| }
 | |
| 
 | |
| bool LocalClient::queue_query(GWBUF* buffer)
 | |
| {
 | |
|     GWBUF* my_buf = NULL;
 | |
| 
 | |
|     if (m_state != VC_ERROR && (my_buf = gwbuf_deep_clone(buffer)))
 | |
|     {
 | |
|         m_queue.push_back(my_buf);
 | |
| 
 | |
|         if (m_state == VC_OK)
 | |
|         {
 | |
|             drain_queue();
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     return my_buf != NULL;
 | |
| }
 | |
| 
 | |
| void LocalClient::self_destruct()
 | |
| {
 | |
|     GWBUF* buffer = mysql_create_com_quit(NULL, 0);
 | |
|     queue_query(buffer);
 | |
|     gwbuf_free(buffer);
 | |
|     m_self_destruct = true;
 | |
| }
 | |
| 
 | |
| void LocalClient::close()
 | |
| {
 | |
|     mxb::Worker* worker = mxb::Worker::get_current();
 | |
|     mxb_assert(worker);
 | |
|     worker->remove_fd(m_sock);
 | |
|     ::close(m_sock);
 | |
| }
 | |
| 
 | |
| void LocalClient::error()
 | |
| {
 | |
|     if (m_state != VC_ERROR)
 | |
|     {
 | |
|         close();
 | |
|         m_state = VC_ERROR;
 | |
|     }
 | |
| }
 | |
| 
 | |
| void LocalClient::process(uint32_t events)
 | |
| {
 | |
| 
 | |
|     if (events & EPOLLIN)
 | |
|     {
 | |
|         GWBUF* buf = read_complete_packet();
 | |
| 
 | |
|         if (buf)
 | |
|         {
 | |
|             if (m_state == VC_WAITING_HANDSHAKE)
 | |
|             {
 | |
|                 if (gw_decode_mysql_server_handshake(&m_protocol, GWBUF_DATA(buf) + MYSQL_HEADER_LEN) == 0)
 | |
|                 {
 | |
|                     GWBUF* response = gw_generate_auth_response(&m_client, &m_protocol, false, false, 0);
 | |
|                     m_queue.push_front(response);
 | |
|                     m_state = VC_RESPONSE_SENT;
 | |
|                 }
 | |
|                 else
 | |
|                 {
 | |
|                     error();
 | |
|                 }
 | |
|             }
 | |
|             else if (m_state == VC_RESPONSE_SENT)
 | |
|             {
 | |
|                 if (mxs_mysql_is_ok_packet(buf))
 | |
|                 {
 | |
|                     m_state = VC_OK;
 | |
|                 }
 | |
|                 else
 | |
|                 {
 | |
|                     error();
 | |
|                 }
 | |
|             }
 | |
| 
 | |
|             gwbuf_free(buf);
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     if (events & EPOLLOUT)
 | |
|     {
 | |
|         /** Queue is drained */
 | |
|     }
 | |
| 
 | |
|     if (events & ERROR_EVENTS)
 | |
|     {
 | |
|         error();
 | |
|     }
 | |
| 
 | |
|     if (m_queue.size() && m_state != VC_ERROR && m_state != VC_WAITING_HANDSHAKE)
 | |
|     {
 | |
|         drain_queue();
 | |
|     }
 | |
|     else if (m_state == VC_ERROR && m_self_destruct)
 | |
|     {
 | |
|         delete this;
 | |
|     }
 | |
| }
 | |
| 
 | |
| GWBUF* LocalClient::read_complete_packet()
 | |
| {
 | |
|     GWBUF* rval = NULL;
 | |
| 
 | |
|     while (true)
 | |
|     {
 | |
|         uint8_t buffer[1024];
 | |
|         int rc = read(m_sock, buffer, sizeof(buffer));
 | |
| 
 | |
|         if (rc == -1)
 | |
|         {
 | |
|             if (errno != EAGAIN && errno != EWOULDBLOCK)
 | |
|             {
 | |
|                 MXS_ERROR("Failed to read from backend: %d, %s", errno, mxs_strerror(errno));
 | |
|                 error();
 | |
|             }
 | |
|             break;
 | |
|         }
 | |
| 
 | |
|         mxs::Buffer chunk(buffer, rc);
 | |
|         m_partial.append(chunk);
 | |
|         size_t len = m_partial.length();
 | |
| 
 | |
|         if (m_expected_bytes == 0 && len >= 3)
 | |
|         {
 | |
|             mxs::Buffer::iterator iter = m_partial.begin();
 | |
|             m_expected_bytes = MYSQL_HEADER_LEN;
 | |
|             m_expected_bytes += *iter++;
 | |
|             m_expected_bytes += (*iter++ << 8);
 | |
|             m_expected_bytes += (*iter++ << 16);
 | |
|         }
 | |
| 
 | |
|         if (len >= m_expected_bytes)
 | |
|         {
 | |
|             /** Read complete packet. Reset expected byte count and make
 | |
|              * the buffer contiguous. */
 | |
|             m_expected_bytes = 0;
 | |
|             m_partial.make_contiguous();
 | |
|             rval = m_partial.release();
 | |
|             break;
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     return rval;
 | |
| }
 | |
| 
 | |
| void LocalClient::drain_queue()
 | |
| {
 | |
|     bool more = true;
 | |
| 
 | |
|     while (m_queue.size() && more)
 | |
|     {
 | |
|         /** Grab a buffer from the queue */
 | |
|         GWBUF* buf = m_queue.front().release();
 | |
|         m_queue.pop_front();
 | |
| 
 | |
|         while (buf)
 | |
|         {
 | |
|             int rc = write(m_sock, GWBUF_DATA(buf), GWBUF_LENGTH(buf));
 | |
| 
 | |
|             if (rc > 0)
 | |
|             {
 | |
|                 buf = gwbuf_consume(buf, rc);
 | |
|             }
 | |
|             else
 | |
|             {
 | |
|                 if (rc == -1 && errno != EAGAIN && errno != EWOULDBLOCK)
 | |
|                 {
 | |
|                     MXS_ERROR("Failed to write to backend: %d, %s", errno, mxs_strerror(errno));
 | |
|                     error();
 | |
|                 }
 | |
| 
 | |
|                 m_queue.push_front(buf);
 | |
|                 more = false;
 | |
|                 break;
 | |
|             }
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| uint32_t LocalClient::poll_handler(MXB_POLL_DATA* data, MXB_WORKER* worker, uint32_t events)
 | |
| {
 | |
|     LocalClient* client = static_cast<LocalClient*>(data);
 | |
|     client->process(events);
 | |
|     return 0;
 | |
| }
 | |
| 
 | |
| namespace
 | |
| {
 | |
| using namespace maxbase;
 | |
| 
 | |
| int connect_socket(const Host& host)
 | |
| {
 | |
|     int fd = -1;
 | |
|     struct sockaddr* sock_addr = nullptr;
 | |
|     socklen_t sock_len = 0;
 | |
| 
 | |
|     switch (host.type())
 | |
|     {
 | |
|     case Host::Type::Invalid:
 | |
|         break;
 | |
| 
 | |
|     case Host::Type::UnixDomainSocket:
 | |
|         {
 | |
|             struct sockaddr_un addr;
 | |
|             sock_addr = reinterpret_cast<struct sockaddr*>(&addr);
 | |
|             sock_len = sizeof(addr);
 | |
|             fd = open_unix_socket(MXS_SOCKET_NETWORK, &addr, host.address().c_str());
 | |
|         }
 | |
|         break;
 | |
| 
 | |
|     default:
 | |
|         {
 | |
|             struct sockaddr_storage addr;
 | |
|             sock_addr = reinterpret_cast<struct sockaddr*>(&addr);
 | |
|             sock_len = sizeof(addr);
 | |
|             fd = open_network_socket(MXS_SOCKET_NETWORK, &addr, host.address().c_str(), host.port());
 | |
|         }
 | |
|         break;
 | |
|     }
 | |
| 
 | |
|     if (fd >= 0 && sock_addr && sock_len)
 | |
|     {
 | |
|         bool ok = connect(fd, sock_addr, sock_len) == 0 || errno == EINPROGRESS;
 | |
|         if (!ok)
 | |
|         {
 | |
|             ::close(fd);
 | |
|             fd = -1;
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     return fd;
 | |
| }
 | |
| }
 | |
| 
 | |
| LocalClient* LocalClient::create(MYSQL_session* session, MySQLProtocol* proto, const char* ip, uint64_t port)
 | |
| {
 | |
|     LocalClient* rval = NULL;
 | |
| 
 | |
|     int fd = connect_socket(maxbase::Host(ip, port));
 | |
| 
 | |
|     if (fd >= 0)
 | |
|     {
 | |
|         LocalClient* relay = new(std::nothrow) LocalClient(session, proto, fd);
 | |
| 
 | |
|         if (relay)
 | |
|         {
 | |
|             mxb::Worker* worker = mxb::Worker::get_current();
 | |
| 
 | |
|             if (worker->add_fd(fd, poll_events, (MXB_POLL_DATA*)relay))
 | |
|             {
 | |
|                 rval = relay;
 | |
|             }
 | |
|             else
 | |
|             {
 | |
|                 relay->m_state = VC_ERROR;
 | |
|                 delete rval;
 | |
|                 rval = NULL;
 | |
|             }
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     if (rval == NULL && fd >= 0)
 | |
|     {
 | |
|         ::close(fd);
 | |
|     }
 | |
|     return rval;
 | |
| }
 | |
| 
 | |
| LocalClient* LocalClient::create(MYSQL_session* session, MySQLProtocol* proto, SERVICE* service)
 | |
| {
 | |
|     LocalClient* rval = NULL;
 | |
| 
 | |
|     for (const auto& listener : listener_find_by_service(service))
 | |
|     {
 | |
|         if (listener->port() > 0)
 | |
|         {
 | |
|             /** Pick the first network listener */
 | |
|             rval = create(session, proto, "127.0.0.1", listener->port());
 | |
|             break;
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     return rval;
 | |
| }
 | |
| 
 | |
| LocalClient* LocalClient::create(MYSQL_session* session, MySQLProtocol* proto, SERVER* server)
 | |
| {
 | |
|     return create(session, proto, server->address, server->port);
 | |
| }
 | 
