/* * Copyright (c) 2016 MariaDB Corporation Ab * * Use of this software is governed by the Business Source License included * in the LICENSE.TXT file and at www.mariadb.com/bsl11. * * Change Date: 2024-11-16 * * On the date above, in accordance with the Business Source License, use * of this software will be governed by version 2 or later of the General * Public License. */ #include #include #include #include #ifdef EPOLLRDHUP #define ERROR_EVENTS (EPOLLRDHUP | EPOLLHUP | EPOLLERR) #else #define ERROR_EVENTS (EPOLLHUP | EPOLLERR) #endif static const uint32_t poll_events = EPOLLIN | EPOLLOUT | EPOLLET | ERROR_EVENTS; LocalClient::LocalClient(MYSQL_session* session, MySQLProtocol* proto, int fd) : m_state(VC_WAITING_HANDSHAKE) , m_sock(fd) , m_expected_bytes(0) , m_client(*session) , m_protocol(*proto) , m_self_destruct(false) { MXB_POLL_DATA::handler = LocalClient::poll_handler; m_protocol.owner_dcb = nullptr; m_protocol.stored_query = nullptr; } LocalClient::~LocalClient() { if (m_state != VC_ERROR) { close(); } } bool LocalClient::queue_query(GWBUF* buffer) { GWBUF* my_buf = NULL; if (m_state != VC_ERROR && (my_buf = gwbuf_deep_clone(buffer))) { m_queue.push_back(my_buf); if (m_state == VC_OK) { drain_queue(); } } return my_buf != NULL; } void LocalClient::self_destruct() { GWBUF* buffer = mysql_create_com_quit(NULL, 0); queue_query(buffer); gwbuf_free(buffer); m_self_destruct = true; } void LocalClient::close() { mxb::Worker* worker = mxb::Worker::get_current(); mxb_assert(worker); worker->remove_fd(m_sock); ::close(m_sock); } void LocalClient::error() { if (m_state != VC_ERROR) { close(); m_state = VC_ERROR; } } void LocalClient::process(uint32_t events) { if (events & EPOLLIN) { GWBUF* buf = read_complete_packet(); if (buf) { if (m_state == VC_WAITING_HANDSHAKE) { if (gw_decode_mysql_server_handshake(&m_protocol, GWBUF_DATA(buf) + MYSQL_HEADER_LEN) == 0) { GWBUF* response = gw_generate_auth_response(&m_client, &m_protocol, false, false, 0); m_queue.push_front(response); m_state = VC_RESPONSE_SENT; } else { error(); } } else if (m_state == VC_RESPONSE_SENT) { if (mxs_mysql_is_ok_packet(buf)) { m_state = VC_OK; } else { error(); } } gwbuf_free(buf); } } if (events & EPOLLOUT) { /** Queue is drained */ } if (events & ERROR_EVENTS) { error(); } if (m_queue.size() && m_state != VC_ERROR && m_state != VC_WAITING_HANDSHAKE) { drain_queue(); } else if (m_state == VC_ERROR && m_self_destruct) { delete this; } } GWBUF* LocalClient::read_complete_packet() { GWBUF* rval = NULL; while (true) { uint8_t buffer[1024]; int rc = read(m_sock, buffer, sizeof(buffer)); if (rc == -1) { if (errno != EAGAIN && errno != EWOULDBLOCK) { MXS_ERROR("Failed to read from backend: %d, %s", errno, mxs_strerror(errno)); error(); } break; } mxs::Buffer chunk(buffer, rc); m_partial.append(chunk); size_t len = m_partial.length(); if (m_expected_bytes == 0 && len >= 3) { mxs::Buffer::iterator iter = m_partial.begin(); m_expected_bytes = MYSQL_HEADER_LEN; m_expected_bytes += *iter++; m_expected_bytes += (*iter++ << 8); m_expected_bytes += (*iter++ << 16); } if (len >= m_expected_bytes) { /** Read complete packet. Reset expected byte count and make * the buffer contiguous. */ m_expected_bytes = 0; m_partial.make_contiguous(); rval = m_partial.release(); break; } } return rval; } void LocalClient::drain_queue() { bool more = true; while (m_queue.size() && more) { /** Grab a buffer from the queue */ GWBUF* buf = m_queue.front().release(); m_queue.pop_front(); while (buf) { int rc = write(m_sock, GWBUF_DATA(buf), GWBUF_LENGTH(buf)); if (rc > 0) { buf = gwbuf_consume(buf, rc); } else { if (rc == -1 && errno != EAGAIN && errno != EWOULDBLOCK) { MXS_ERROR("Failed to write to backend: %d, %s", errno, mxs_strerror(errno)); error(); } m_queue.push_front(buf); more = false; break; } } } } uint32_t LocalClient::poll_handler(MXB_POLL_DATA* data, MXB_WORKER* worker, uint32_t events) { LocalClient* client = static_cast(data); client->process(events); return 0; } namespace { using namespace maxbase; int connect_socket(const Host& host) { int fd = -1; struct sockaddr* sock_addr = nullptr; socklen_t sock_len = 0; switch (host.type()) { case Host::Type::Invalid: break; case Host::Type::UnixDomainSocket: { struct sockaddr_un addr; sock_addr = reinterpret_cast(&addr); sock_len = sizeof(addr); fd = open_unix_socket(MXS_SOCKET_NETWORK, &addr, host.address().c_str()); } break; default: { struct sockaddr_storage addr; sock_addr = reinterpret_cast(&addr); sock_len = sizeof(addr); fd = open_network_socket(MXS_SOCKET_NETWORK, &addr, host.address().c_str(), host.port()); } break; } if (fd >= 0 && sock_addr && sock_len) { bool ok = connect(fd, sock_addr, sock_len) == 0 || errno == EINPROGRESS; if (!ok) { ::close(fd); fd = -1; } } return fd; } } LocalClient* LocalClient::create(MYSQL_session* session, MySQLProtocol* proto, const char* ip, uint64_t port) { LocalClient* rval = NULL; int fd = connect_socket(maxbase::Host(ip, port)); if (fd >= 0) { LocalClient* relay = new(std::nothrow) LocalClient(session, proto, fd); if (relay) { mxb::Worker* worker = mxb::Worker::get_current(); if (worker->add_fd(fd, poll_events, (MXB_POLL_DATA*)relay)) { rval = relay; } else { relay->m_state = VC_ERROR; delete rval; rval = NULL; } } } if (rval == NULL && fd >= 0) { ::close(fd); } return rval; } LocalClient* LocalClient::create(MYSQL_session* session, MySQLProtocol* proto, SERVICE* service) { LocalClient* rval = NULL; for (const auto& listener : listener_find_by_service(service)) { if (listener->port() > 0) { /** Pick the first network listener */ rval = create(session, proto, "127.0.0.1", listener->port()); break; } } return rval; } LocalClient* LocalClient::create(MYSQL_session* session, MySQLProtocol* proto, SERVER* server) { return create(session, proto, server->address, server->port); }