diff --git a/server/modules/filter/tee/local_client.cc b/server/modules/filter/tee/local_client.cc new file mode 100644 index 000000000..fb1926db6 --- /dev/null +++ b/server/modules/filter/tee/local_client.cc @@ -0,0 +1,248 @@ +/* + * Copyright (c) 2016 MariaDB Corporation Ab + * + * Use of this software is governed by the Business Source License included + * in the LICENSE.TXT file and at www.mariadb.com/bsl11. + * + * Change Date: 2019-07-01 + * + * On the date above, in accordance with the Business Source License, use + * of this software will be governed by version 2 or later of the General + * Public License. + */ + +#include "local_client.hh" + +#include + +// TODO: Find a way to cleanly expose this +#include "../../../core/maxscale/worker.hh" + +#ifdef EPOLLRDHUP +#define ERROR_EVENTS (EPOLLRDHUP | EPOLLHUP) +#else +#define ERROR_EVENTS EPOLLHUP +#endif + +static const uint32_t poll_events = EPOLLIN | EPOLLOUT | EPOLLET | ERROR_EVENTS; + +LocalClient::LocalClient(MXS_SESSION* session, int fd): + m_state(VC_WAITING_HANDSHAKE), + m_sock(fd), + m_expected_bytes(0), + m_session(session) +{ + MXS_POLL_DATA::handler = LocalClient::poll_handler; + MySQLProtocol* client = (MySQLProtocol*)m_session->client_dcb->protocol; + m_proto = {}; + m_proto.charset = client->charset; + m_proto.client_capabilities = client->client_capabilities; + m_proto.extra_capabilities = client->extra_capabilities; +} + +LocalClient::~LocalClient() +{ + if (m_state != VC_ERROR) + { + close(m_sock); + } +} + +bool LocalClient::query(GWBUF* buffer) +{ + GWBUF* my_buf = gwbuf_clone(buffer); + + if (my_buf) + { + m_queue.push_back(my_buf); + + if (m_state == VC_OK) + { + drain_queue(); + } + } + + return my_buf != NULL; +} + +void LocalClient::error() +{ + close(m_sock); + m_state = VC_ERROR; +} + +void LocalClient::process(uint32_t events) +{ + + if (events & EPOLLIN) + { + GWBUF* buf = read_complete_packet(); + + if (buf) + { + if (m_state == VC_WAITING_HANDSHAKE) + { + if (gw_decode_mysql_server_handshake(&m_proto, GWBUF_DATA(buf) + MYSQL_HEADER_LEN) == 0) + { + GWBUF* response = gw_generate_auth_response(m_session, &m_proto, false, false); + m_queue.push_front(response); + m_state = VC_RESPONSE_SENT; + } + else + { + error(); + } + } + else if (m_state == VC_RESPONSE_SENT) + { + if (mxs_mysql_is_ok_packet(buf)) + { + m_state = VC_OK; + } + else + { + error(); + } + } + + gwbuf_free(buf); + } + } + + if (events & EPOLLOUT) + { + /** Queue is drained */ + } + + if (events & ERROR_EVENTS) + { + error(); + } + + if (m_queue.size() && m_state != VC_ERROR) + { + drain_queue(); + } +} + +GWBUF* LocalClient::read_complete_packet() +{ + GWBUF* rval = NULL; + + while (true) + { + uint8_t buffer[1024]; + int rc = read(m_sock, buffer, sizeof(buffer)); + + if (rc == -1) + { + if (errno != EAGAIN && errno != EWOULDBLOCK) + { + MXS_ERROR("Failed to read from backend: %d, %s", errno, mxs_strerror(errno)); + error(); + } + break; + } + + mxs::Buffer chunk(buffer, rc); + m_partial.append(chunk); + size_t len = m_partial.length(); + + if (m_expected_bytes == 0 && len >= 3) + { + mxs::Buffer::iterator iter = m_partial.begin(); + m_expected_bytes = MYSQL_HEADER_LEN; + m_expected_bytes += *iter++; + m_expected_bytes += (*iter++ << 8); + m_expected_bytes += (*iter++ << 16); + } + + if (len >= m_expected_bytes) + { + /** Read complete packet. Reset expected byte count and make + * the buffer contiguous. */ + m_expected_bytes = 0; + m_partial.make_contiguous(); + rval = m_partial.release(); + break; + } + + } + + return rval; +} + +void LocalClient::drain_queue() +{ + bool more = true; + + while (m_queue.size() && more) + { + /** Grab a buffer from the queue */ + GWBUF* buf = m_queue.front().release(); + m_queue.pop_front(); + + while (buf) + { + int rc = write(m_sock, GWBUF_DATA(buf), GWBUF_LENGTH(buf)); + + if (rc > 0) + { + buf = gwbuf_consume(buf, rc); + } + else + { + if (rc == -1 && errno != EAGAIN && errno != EWOULDBLOCK) + { + MXS_ERROR("Failed to write to backend: %d, %s", errno, mxs_strerror(errno)); + error(); + } + + m_queue.push_front(buf); + more = false; + break; + } + } + } +} + +uint32_t LocalClient::poll_handler(struct mxs_poll_data* data, int wid, uint32_t events) +{ + LocalClient* client = static_cast(data); + client->process(events); + return 0; +} + +LocalClient* LocalClient::create(MXS_SESSION* session, SERVICE* service) +{ + LocalClient* rval = NULL; + + if (service->ports && service->ports->port > 0) + { + sockaddr_storage addr; + int fd = open_network_socket(MXS_SOCKET_NETWORK, &addr, "127.0.0.1", + service->ports->port); + + if (connect(fd, (struct sockaddr*)&addr, sizeof(addr)) == 0 || errno == EINPROGRESS) + { + LocalClient* relay = new (std::nothrow) LocalClient(session, fd); + + if (relay) + { + mxs::Worker* worker = mxs::Worker::get_current(); + + if (worker->add_fd(fd, poll_events, (MXS_POLL_DATA*)relay)) + { + rval = relay; + } + else + { + delete rval; + rval = NULL; + } + } + } + } + + return rval; +} diff --git a/server/modules/filter/tee/local_client.hh b/server/modules/filter/tee/local_client.hh new file mode 100644 index 000000000..6a3b21ea4 --- /dev/null +++ b/server/modules/filter/tee/local_client.hh @@ -0,0 +1,75 @@ +#pragma once +/* + * Copyright (c) 2016 MariaDB Corporation Ab + * + * Use of this software is governed by the Business Source License included + * in the LICENSE.TXT file and at www.mariadb.com/bsl11. + * + * Change Date: 2019-07-01 + * + * On the date above, in accordance with the Business Source License, use + * of this software will be governed by version 2 or later of the General + * Public License. + */ + +#include + +#include + +#include +#include +#include + +/** A DCB-like client abstraction which ignores responses */ +class LocalClient: public MXS_POLL_DATA +{ + LocalClient(const LocalClient&); + LocalClient& operator=(const LocalClient&); + +public: + ~LocalClient(); + + /** + * Create a local client for a service + * + * @param session Client session + * @param service Service to connect to + * + * @return New virtual client or NULL on error + */ + static LocalClient* create(MXS_SESSION* session, SERVICE* service); + + /** + * Queue a new query for execution + * + * @param buffer Buffer containing the query + * + * @return True if query was successfully queued + */ + bool query(GWBUF* buffer); + +private: + LocalClient(MXS_SESSION* session, int fd); + static uint32_t poll_handler(struct mxs_poll_data* data, int wid, uint32_t events); + void process(uint32_t events); + GWBUF* read_complete_packet(); + void drain_queue(); + void error(); + + /** Client states */ + enum vc_state + { + VC_WAITING_HANDSHAKE, // Initial state + VC_RESPONSE_SENT, // Handshake received and response sent + VC_OK, // Authentication is complete, ready for queries + VC_ERROR // Something went wrong + }; + + vc_state m_state; + int m_sock; + mxs::Buffer m_partial; + size_t m_expected_bytes; + std::deque m_queue; + MXS_SESSION* m_session; + MySQLProtocol m_proto; +};