338 lines
7.8 KiB
C++
338 lines
7.8 KiB
C++
/*
|
|
* Copyright (c) 2016 MariaDB Corporation Ab
|
|
*
|
|
* Use of this software is governed by the Business Source License included
|
|
* in the LICENSE.TXT file and at www.mariadb.com/bsl11.
|
|
*
|
|
* Change Date: 2022-01-01
|
|
*
|
|
* On the date above, in accordance with the Business Source License, use
|
|
* of this software will be governed by version 2 or later of the General
|
|
* Public License.
|
|
*/
|
|
|
|
#include <maxscale/protocol/mariadb_client.hh>
|
|
#include <maxscale/routingworker.hh>
|
|
#include <maxscale/utils.h>
|
|
#include <maxbase/host.hh>
|
|
|
|
/**
 * Events that are treated as a failure of the connection. EPOLLRDHUP is
 * included only when the system headers define it.
 */
#if defined(EPOLLRDHUP)
#define ERROR_EVENTS (EPOLLERR | EPOLLHUP | EPOLLRDHUP)
#else
#define ERROR_EVENTS (EPOLLERR | EPOLLHUP)
#endif

/** Event mask used when the socket is added to a worker: edge-triggered
 *  read and write readiness plus the error events above. */
static const uint32_t poll_events = ERROR_EVENTS | EPOLLET | EPOLLOUT | EPOLLIN;
|
|
|
|
/**
 * Construct a client on top of an already opened socket.
 *
 * The client starts in the VC_WAITING_HANDSHAKE state.
 *
 * @param session Session data used to initialise m_client
 * @param proto   Protocol data used to initialise m_protocol
 * @param fd      Socket used for all reads and writes
 */
LocalClient::LocalClient(MYSQL_session* session, MySQLProtocol* proto, int fd)
    : m_state(VC_WAITING_HANDSHAKE)
    , m_sock(fd)
    , m_expected_bytes(0)
    , m_client(*session)
    , m_protocol(*proto)
    , m_self_destruct(false)
{
    // Clear the DCB and stored query pointers of this client's protocol data
    m_protocol.stored_query = nullptr;
    m_protocol.owner_dcb = nullptr;

    // Route the poll events of this object to poll_handler()
    MXB_POLL_DATA::handler = &LocalClient::poll_handler;
}
|
|
|
|
/**
 * Destroy the client and close its socket if that has not been done yet.
 *
 * In the VC_ERROR state the socket is not touched: error() closes it before
 * switching to that state, and create() closes the descriptor itself when it
 * puts a client into VC_ERROR.
 */
LocalClient::~LocalClient()
{
    const bool socket_still_open = (m_state != VC_ERROR);

    if (socket_still_open)
    {
        close();
    }
}
|
|
|
|
/**
 * Queue a copy of a query for sending.
 *
 * The given buffer is deep-cloned; the caller keeps ownership of @c buffer.
 * If the client is in the VC_OK state, the queue is drained right away.
 *
 * @param buffer Buffer to clone and queue
 *
 * @return True if the clone was queued, false if the client is in the
 *         VC_ERROR state or the buffer could not be cloned
 */
bool LocalClient::queue_query(GWBUF* buffer)
{
    if (m_state == VC_ERROR)
    {
        return false;
    }

    GWBUF* copy = gwbuf_deep_clone(buffer);

    if (copy == NULL)
    {
        return false;
    }

    m_queue.push_back(copy);

    if (m_state == VC_OK)
    {
        drain_queue();
    }

    return true;
}
|
|
|
|
void LocalClient::self_destruct()
|
|
{
|
|
GWBUF* buffer = mysql_create_com_quit(NULL, 0);
|
|
queue_query(buffer);
|
|
gwbuf_free(buffer);
|
|
m_self_destruct = true;
|
|
}
|
|
|
|
void LocalClient::close()
|
|
{
|
|
mxb::Worker* worker = mxb::Worker::get_current();
|
|
mxb_assert(worker);
|
|
worker->remove_fd(m_sock);
|
|
::close(m_sock);
|
|
}
|
|
|
|
void LocalClient::error()
|
|
{
|
|
if (m_state != VC_ERROR)
|
|
{
|
|
close();
|
|
m_state = VC_ERROR;
|
|
}
|
|
}
|
|
|
|
void LocalClient::process(uint32_t events)
|
|
{
|
|
|
|
if (events & EPOLLIN)
|
|
{
|
|
GWBUF* buf = read_complete_packet();
|
|
|
|
if (buf)
|
|
{
|
|
if (m_state == VC_WAITING_HANDSHAKE)
|
|
{
|
|
if (gw_decode_mysql_server_handshake(&m_protocol, GWBUF_DATA(buf) + MYSQL_HEADER_LEN) == 0)
|
|
{
|
|
GWBUF* response = gw_generate_auth_response(&m_client, &m_protocol, false, false, 0);
|
|
m_queue.push_front(response);
|
|
m_state = VC_RESPONSE_SENT;
|
|
}
|
|
else
|
|
{
|
|
error();
|
|
}
|
|
}
|
|
else if (m_state == VC_RESPONSE_SENT)
|
|
{
|
|
if (mxs_mysql_is_ok_packet(buf))
|
|
{
|
|
m_state = VC_OK;
|
|
}
|
|
else
|
|
{
|
|
error();
|
|
}
|
|
}
|
|
|
|
gwbuf_free(buf);
|
|
}
|
|
}
|
|
|
|
if (events & EPOLLOUT)
|
|
{
|
|
/** Queue is drained */
|
|
}
|
|
|
|
if (events & ERROR_EVENTS)
|
|
{
|
|
error();
|
|
}
|
|
|
|
if (m_queue.size() && m_state != VC_ERROR && m_state != VC_WAITING_HANDSHAKE)
|
|
{
|
|
drain_queue();
|
|
}
|
|
else if (m_state == VC_ERROR && m_self_destruct)
|
|
{
|
|
delete this;
|
|
}
|
|
}
|
|
|
|
/**
 * Read data from the socket until a complete packet is available.
 *
 * Data is accumulated in m_partial across calls. Once at least three bytes
 * are buffered, the expected size is computed from the little-endian
 * three-byte payload length plus MYSQL_HEADER_LEN.
 *
 * If the read fails with anything other than EAGAIN/EWOULDBLOCK, or if the
 * peer has closed the connection, error() is called.
 *
 * @return A contiguous buffer holding everything read so far once at least
 *         one complete packet is buffered, or NULL if the packet is still
 *         incomplete or the read failed
 */
GWBUF* LocalClient::read_complete_packet()
{
    GWBUF* rval = NULL;

    while (true)
    {
        uint8_t buffer[1024];
        ssize_t rc = read(m_sock, buffer, sizeof(buffer));

        if (rc == -1)
        {
            if (errno == EINTR)
            {
                // Interrupted before any data was read, simply try again
                continue;
            }

            if (errno != EAGAIN && errno != EWOULDBLOCK)
            {
                MXS_ERROR("Failed to read from backend: %d, %s", errno, mxs_strerror(errno));
                error();
            }
            break;
        }
        else if (rc == 0)
        {
            // End of file: the peer closed the connection. No more data will
            // arrive, so retrying would spin forever on a partial packet and
            // an empty buffer must not be passed on as a packet.
            error();
            break;
        }

        mxs::Buffer chunk(buffer, static_cast<size_t>(rc));
        m_partial.append(chunk);
        size_t len = m_partial.length();

        if (m_expected_bytes == 0 && len >= 3)
        {
            // Decode the three-byte little-endian payload length
            mxs::Buffer::iterator iter = m_partial.begin();
            m_expected_bytes = MYSQL_HEADER_LEN;
            m_expected_bytes += *iter++;
            m_expected_bytes += (*iter++ << 8);
            m_expected_bytes += (*iter++ << 16);
        }

        // m_expected_bytes is still zero while the length bytes are missing;
        // in that case the packet cannot be complete.
        if (m_expected_bytes != 0 && len >= m_expected_bytes)
        {
            /** Read complete packet. Reset expected byte count and make
             * the buffer contiguous. */
            m_expected_bytes = 0;
            m_partial.make_contiguous();
            rval = m_partial.release();
            break;
        }
    }

    return rval;
}
|
|
|
|
void LocalClient::drain_queue()
|
|
{
|
|
bool more = true;
|
|
|
|
while (m_queue.size() && more)
|
|
{
|
|
/** Grab a buffer from the queue */
|
|
GWBUF* buf = m_queue.front().release();
|
|
m_queue.pop_front();
|
|
|
|
while (buf)
|
|
{
|
|
int rc = write(m_sock, GWBUF_DATA(buf), GWBUF_LENGTH(buf));
|
|
|
|
if (rc > 0)
|
|
{
|
|
buf = gwbuf_consume(buf, rc);
|
|
}
|
|
else
|
|
{
|
|
if (rc == -1 && errno != EAGAIN && errno != EWOULDBLOCK)
|
|
{
|
|
MXS_ERROR("Failed to write to backend: %d, %s", errno, mxs_strerror(errno));
|
|
error();
|
|
}
|
|
|
|
m_queue.push_front(buf);
|
|
more = false;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/**
 * Poll callback installed by the constructor.
 *
 * Forwards the events to process() of the LocalClient that @c data points to.
 * Note that process() may delete the client.
 *
 * @param data   Poll data, must be a LocalClient
 * @param worker Worker that delivered the events (unused)
 * @param events Bitmask of EPOLL events
 *
 * @return Always 0
 */
uint32_t LocalClient::poll_handler(MXB_POLL_DATA* data, MXB_WORKER* worker, uint32_t events)
{
    static_cast<LocalClient*>(data)->process(events);
    return 0;
}
|
|
|
|
namespace
|
|
{
|
|
using namespace maxbase;
|
|
|
|
int connect_socket(const Host& host)
|
|
{
|
|
int fd = -1;
|
|
struct sockaddr* sock_addr = nullptr;
|
|
socklen_t sock_len = 0;
|
|
|
|
switch (host.type())
|
|
{
|
|
case Host::Type::Invalid:
|
|
break;
|
|
|
|
case Host::Type::UnixDomainSocket:
|
|
{
|
|
struct sockaddr_un addr;
|
|
sock_addr = reinterpret_cast<struct sockaddr*>(&addr);
|
|
sock_len = sizeof(addr);
|
|
fd = open_unix_socket(MXS_SOCKET_NETWORK, &addr, host.address().c_str());
|
|
}
|
|
break;
|
|
|
|
default:
|
|
{
|
|
struct sockaddr_storage addr;
|
|
sock_addr = reinterpret_cast<struct sockaddr*>(&addr);
|
|
sock_len = sizeof(addr);
|
|
fd = open_network_socket(MXS_SOCKET_NETWORK, &addr, host.address().c_str(), host.port());
|
|
}
|
|
break;
|
|
}
|
|
|
|
if (fd >= 0 && sock_addr && sock_len)
|
|
{
|
|
bool ok = connect(fd, sock_addr, sock_len) == 0 || errno == EINPROGRESS;
|
|
if (!ok)
|
|
{
|
|
::close(fd);
|
|
fd = -1;
|
|
}
|
|
}
|
|
|
|
return fd;
|
|
}
|
|
}
|
|
|
|
/**
 * Create a client connected to the given address.
 *
 * The socket is connected with connect_socket() and added to the current
 * worker with the poll_events mask.
 *
 * @param session Session data passed to the LocalClient constructor
 * @param proto   Protocol data passed to the LocalClient constructor
 * @param ip      Address to connect to
 * @param port    Port to connect to
 *
 * @return The new client, or NULL if connecting, allocating the client or
 *         adding the socket to the worker failed
 */
LocalClient* LocalClient::create(MYSQL_session* session, MySQLProtocol* proto, const char* ip, uint64_t port)
{
    LocalClient* rval = NULL;

    int fd = connect_socket(maxbase::Host(ip, port));

    if (fd >= 0)
    {
        LocalClient* relay = new(std::nothrow) LocalClient(session, proto, fd);

        if (relay)
        {
            mxb::Worker* worker = mxb::Worker::get_current();

            if (worker->add_fd(fd, poll_events, static_cast<MXB_POLL_DATA*>(relay)))
            {
                rval = relay;
            }
            else
            {
                // The socket was never added to the worker. Marking the
                // client as failed keeps the destructor from removing and
                // closing it; the descriptor is closed below instead.
                relay->m_state = VC_ERROR;
                delete relay;
            }
        }
    }

    if (rval == NULL && fd >= 0)
    {
        ::close(fd);
    }

    return rval;
}
|
|
|
|
/**
 * Create a client connected to a listener of the given service.
 *
 * The first listener of the service whose port is greater than zero is used
 * and the connection is made to 127.0.0.1 on that port.
 *
 * @param session Session data passed on to create()
 * @param proto   Protocol data passed on to create()
 * @param service Service whose listeners are searched
 *
 * @return The new client, or NULL if the service has no such listener or the
 *         client could not be created
 */
LocalClient* LocalClient::create(MYSQL_session* session, MySQLProtocol* proto, SERVICE* service)
{
    for (const auto& listener : listener_find_by_service(service))
    {
        if (!(listener->port() > 0))
        {
            // Not a network listener, keep looking
            continue;
        }

        /** Only the first network listener is tried */
        return create(session, proto, "127.0.0.1", listener->port());
    }

    return NULL;
}
|
|
|
|
/**
 * Create a client connected to the given server.
 *
 * @param session Session data passed on to create()
 * @param proto   Protocol data passed on to create()
 * @param server  Server whose address and port are used for the connection
 *
 * @return The new client or NULL if it could not be created
 */
LocalClient* LocalClient::create(MYSQL_session* session, MySQLProtocol* proto, SERVER* server)
{
    const auto& target_address = server->address;
    const auto& target_port = server->port;

    return create(session, proto, target_address, target_port);
}
|