diff --git a/server/modules/protocol/MySQL/mariadb_client.cc b/server/modules/protocol/MySQL/mariadb_client.cc index 39e61299e..ff95265e4 100644 --- a/server/modules/protocol/MySQL/mariadb_client.cc +++ b/server/modules/protocol/MySQL/mariadb_client.cc @@ -14,6 +14,7 @@ #include #include #include +#include #ifdef EPOLLRDHUP #define ERROR_EVENTS (EPOLLRDHUP | EPOLLHUP | EPOLLERR) @@ -231,13 +232,61 @@ uint32_t LocalClient::poll_handler(MXB_POLL_DATA* data, MXB_WORKER* worker, uint return 0; } +namespace +{ +using namespace maxbase; + +int connect_socket(const Host& host) +{ + int fd = -1; + struct sockaddr* sock_addr = nullptr; + socklen_t sock_len = 0; + + switch (host.type()) + { + case Host::Type::Invalid: + break; + + case Host::Type::UnixDomainSocket: + { + struct sockaddr_un addr; + sock_addr = reinterpret_cast(&addr); + sock_len = sizeof(addr); + fd = open_unix_socket(MXS_SOCKET_NETWORK, &addr, host.address().c_str()); + } + break; + + default: + { + struct sockaddr_storage addr; + sock_addr = reinterpret_cast(&addr); + sock_len = sizeof(addr); + fd = open_network_socket(MXS_SOCKET_NETWORK, &addr, host.address().c_str(), host.port()); + } + break; + } + + if (fd >= 0 && sock_addr && sock_len) + { + bool ok = connect(fd, sock_addr, sock_len) == 0 || errno == EINPROGRESS; + if (!ok) + { + ::close(fd); + fd = -1; + } + } + + return fd; +} +} + LocalClient* LocalClient::create(MYSQL_session* session, MySQLProtocol* proto, const char* ip, uint64_t port) { LocalClient* rval = NULL; - sockaddr_storage addr; - int fd = open_network_socket(MXS_SOCKET_NETWORK, &addr, ip, port); - if (fd > 0 && (connect(fd, (struct sockaddr*)&addr, sizeof(addr)) == 0 || errno == EINPROGRESS)) + int fd = connect_socket(maxbase::Host(ip, port)); + + if (fd >= 0) { LocalClient* relay = new(std::nothrow) LocalClient(session, proto, fd); @@ -258,7 +307,7 @@ LocalClient* LocalClient::create(MYSQL_session* session, MySQLProtocol* proto, c } } - if (rval == NULL && fd > 0) + if (rval == NULL && fd >= 0) { ::close(fd); }