From 8b7b7b4d3f984b00f796d37120dd8e5dc966246d Mon Sep 17 00:00:00 2001
From: Niclas Antti
Date: Thu, 18 Apr 2019 15:05:51 +0300
Subject: [PATCH] Add a streamable class (Host) that represents an address and port, or a unix domain socket.

---
 maxutils/maxbase/include/maxbase/host.hh | 111 +++++++++++
 maxutils/maxbase/src/CMakeLists.txt      |   1 +
 maxutils/maxbase/src/host.cc             | 225 +++++++++++++++++++++++
 3 files changed, 337 insertions(+)
 create mode 100644 maxutils/maxbase/include/maxbase/host.hh
 create mode 100644 maxutils/maxbase/src/host.cc

diff --git a/maxutils/maxbase/include/maxbase/host.hh b/maxutils/maxbase/include/maxbase/host.hh
new file mode 100644
index 000000000..90bb27744
--- /dev/null
+++ b/maxutils/maxbase/include/maxbase/host.hh
@@ -0,0 +1,111 @@
+/*
+ * Copyright (c) 2019 MariaDB Corporation Ab
+ *
+ * Use of this software is governed by the Business Source License included
+ * in the LICENSE.TXT file and at www.mariadb.com/bsl11.
+ *
+ * Change Date: 2022-01-01
+ *
+ * On the date above, in accordance with the Business Source License, use
+ * of this software will be governed by version 2 or later of the General
+ * Public License.
+ */
+#pragma once
+
+#include <maxbase/ccdefs.hh>
+
+#include <iosfwd>
+#include <string>
+
+/** Host is a streamable class that represents an address and port, or a unix domain socket.
+ */
+namespace maxbase
+{
+
+class Host;
+
+std::ostream& operator<<(std::ostream&, const Host& host);
+std::istream& operator>>(std::istream&, Host& host);
+
+class Host
+{
+public:
+    enum class Type {Invalid, UnixDomainSocket, HostName, IPV4, IPV6};      // to_string() provided
+    static constexpr int DefaultPort = 3306;
+
+    Host() = default;   // type() returns Type::Invalid
+
+    /**
+     * Parses 'input'. If it cannot be parsed or validated, type() returns Type::Invalid.
+     * @param input A string parsed according to this format (the brackets are real brackets):
+     *              unix_domain_socket | addr | addr:port | [addr] | [addr]:port
+     *              'addr' is a plain ipv4, ipv6, host name or unix domain socket.
+     *              An ipv6 address with a port must use the format [ipv6]:port.
+     *              A unix domain socket must start with a forward slash ('/') and must not specify a port.
+     *              The default port is 3306.
+     */
+    explicit Host(const std::string& input);
+
+    /**
+     * Constructs from a separate address and port. If validation fails, type() returns Type::Invalid.
+     * @param addr Plain ipv4, ipv6, host name or unix domain socket (no brackets or port specifiers).
+     * @param port A valid port number. Ignored if 'addr' is a unix domain socket (starts with '/').
+     */
+    Host(const std::string& addr, int port);
+
+    Type               type() const;
+    bool               is_valid() const;
+    const std::string& address() const;
+    int                port() const;
+
+    const std::string& org_input() const;   // for better error messages
+private:
+    void set_type(bool port_string_specified);      // set m_type based on m_address and m_port
+
+    std::string m_address;
+    int         m_port = DefaultPort;
+    Type        m_type = Type::Invalid;
+    std::string m_org_input;
+};
+
+std::string to_string(Host::Type type);
+
+// Inline implementations of the accessors and the comparison operators.
+inline Host::Type Host::type() const
+{
+    return m_type;
+}
+
+inline bool Host::is_valid() const
+{
+    return m_type != Type::Invalid;
+}
+
+inline const std::string& Host::address() const
+{
+    return m_address;
+}
+
+inline int Host::port() const
+{
+    return m_port;
+}
+
+inline const std::string& Host::org_input() const
+{
+    return m_org_input;
+}
+
+inline bool operator==(const Host& l, const Host& r)
+{
+    bool port_ok = (l.port() == r.port())   // the port is not compared between two unix domain sockets
+        || (l.type() == Host::Type::UnixDomainSocket && r.type() == Host::Type::UnixDomainSocket);
+
+    return port_ok && l.address() == r.address() && l.type() == r.type();
+}
+
+inline bool operator!=(const Host& l, const Host& r)
+{
+    return !(l == r);
+}
+}
diff --git a/maxutils/maxbase/src/CMakeLists.txt b/maxutils/maxbase/src/CMakeLists.txt
index 835209e10..26dba3e0f 100644
--- a/maxutils/maxbase/src/CMakeLists.txt
+++ b/maxutils/maxbase/src/CMakeLists.txt
@@ -17,6 +17,7 @@ add_library(maxbase STATIC
   workertask.cc
   average.cc
   random.cc
+  host.cc
   )
 
 if(HAVE_SYSTEMD)
diff --git a/maxutils/maxbase/src/host.cc b/maxutils/maxbase/src/host.cc
new file mode 100644
index 000000000..3cebbbd9c
--- /dev/null
+++ b/maxutils/maxbase/src/host.cc
@@ -0,0 +1,225 @@
+/*
+ * Copyright (c) 2019 MariaDB Corporation Ab
+ *
+ * Use of this software is governed by the Business Source License included
+ * in the LICENSE.TXT file and at www.mariadb.com/bsl11.
+ *
+ * Change Date: 2022-01-01
+ *
+ * On the date above, in accordance with the Business Source License, use
+ * of this software will be governed by version 2 or later of the General
+ * Public License.
+ */
+#include <maxbase/host.hh>
+#include <maxbase/string.hh>
+
+#include <algorithm>
+#include <cctype>
+#include <cstdlib>
+#include <istream>
+#include <ostream>
+#include <vector>
+
+namespace
+{
+// Simple but not exhaustive address validation functions.
+// An ipv4 address "x.x.x.x" cannot be a hostname (where x is a number), but pretty
+// much anything else can. Call is_valid_hostname() last.
+bool is_valid_ipv4(const std::string& ip)
+{
+    // Shape check only: digits and dots, exactly three dots, 7 to 15 characters.
+    return ip.find_first_not_of("0123456789.") == std::string::npos
+           && (ip.length() <= 15 && ip.length() >= 7)
+           && std::count(begin(ip), end(ip), '.') == 3;
+}
+
+bool is_valid_ipv6(const std::string& ip)
+{
+    auto invalid_char = [](unsigned char ch) {      // unsigned: <cctype> is undefined for negative values
+            bool valid = std::isxdigit(ch) || ch == ':' || ch == '.';
+            return !valid;
+        };
+
+    // Shape check only: at least two colons, only hex digits, ':' and '.', 2 to 45 characters.
+    return std::count(begin(ip), end(ip), ':') >= 2
+           && std::none_of(begin(ip), end(ip), invalid_char)
+           && (ip.length() <= 45 && ip.length() >= 2);
+}
+
+bool is_valid_hostname(const std::string& hn)
+{
+    auto invalid_char = [](unsigned char ch) {      // unsigned: <cctype> is undefined for negative values
+            bool valid = std::isalnum(ch) || ch == '_' || ch == '.';
+            return !valid;
+        };
+
+    // The length is checked first so that front() is never called on an empty string.
+    return (hn.length() <= 253 && hn.length() > 0)
+           && hn.front() != '_'
+           && std::none_of(begin(hn), end(hn), invalid_char);
+}
+
+bool is_valid_socket(const std::string& addr)
+{
+    // Can't check the file system, the socket may not have been created yet.
+    // Just not bothering to check much, file names can be almost anything and errors are easy to spot.
+    // The empty() test comes first because front() and back() are undefined on an empty string.
+    return !addr.empty() && addr.front() == '/'
+           && addr.back() != '/';   // avoids the confusing error: Address already in use
+}
+
+bool is_valid_port(int port)
+{
+    return 0 < port && port < (1 << 16);
+}
+
+// Make sure the order here is the same as in Host::Type.
+const std::vector<std::string> host_type_names = {"Invalid", "UnixDomainSocket", "HostName", "IPV4", "IPV6"};
+}
+
+namespace maxbase
+{
+std::string to_string(Host::Type type)
+{
+    size_t i = size_t(type);
+    return i >= host_type_names.size() ? "UNKNOWN" : host_type_names[i];
+}
+
+void Host::set_type(bool port_string_specified)
+{
+    if (is_valid_socket(m_address))
+    {
+        if (!port_string_specified)     // a socket given together with a port string stays Type::Invalid
+        {
+            m_type = Type::UnixDomainSocket;
+        }
+    }
+    else if (is_valid_port(m_port))
+    {
+        if (is_valid_ipv4(m_address))
+        {
+            m_type = Type::IPV4;
+        }
+        else if (is_valid_ipv6(m_address))
+        {
+            m_type = Type::IPV6;
+        }
+        else if (is_valid_hostname(m_address))      // tested last, an ipv4 address also passes this check
+        {
+            m_type = Type::HostName;
+        }
+    }
+}
+
+Host::Host(const std::string& in)
+{
+    m_org_input = in;
+    std::string input = maxbase::trimmed_copy(in);
+
+    if (input.empty())
+    {
+        return;
+    }
+
+    std::string port_part;
+
+    // 'ite' is left pointing into the input if there is an error in parsing. Not exhaustive error checking.
+    auto ite = input.begin();
+
+    if (*ite == '[')
+    {       // expecting [address]:port, where :port is optional
+        auto last = std::find(begin(input), end(input), ']');
+        m_address.assign(++ite, last);
+        if (last != end(input))
+        {
+            if (++last != end(input) && *last == ':' && last + 1 != end(input))
+            {
+                ++last;
+                port_part.assign(last, end(input));
+                last = end(input);
+            }
+            ite = last;
+        }
+    }
+    else
+    {
+        if (is_valid_ipv6(input))
+        {
+            m_address = input;
+            ite = end(input);
+        }
+        else
+        {
+            // expecting address:port, where :port is optional => (hostnames with colons must use [xxx]:port)
+            auto colon = std::find(begin(input), end(input), ':');
+            m_address.assign(begin(input), colon);
+            ite = colon;
+            if (colon != end(input) && ++colon != end(input))
+            {
+                port_part.assign(colon, end(input));
+                ite = end(input);
+            }
+        }
+    }
+
+    if (ite == end(input))      // if all input consumed
+    {
+        if (!port_part.empty())
+        {
+            bool all_digits = std::all_of(begin(port_part), end(port_part),
+                                          [](unsigned char ch) {
+                                              return std::isdigit(ch) != 0;
+                                          });
+            // strtol() saturates where atoi() overflows; the clamp keeps an oversized port invalid.
+            m_port = all_digits ? int(std::min(std::strtol(port_part.c_str(), nullptr, 10), 1L << 16)) : -1;
+        }
+
+        set_type(!port_part.empty());
+    }
+}
+
+Host::Host(const std::string& addr, int port)
+{
+    m_org_input = addr;
+    m_address = addr;
+    m_port = port;
+
+    if (!m_address.empty() && m_address.front() != '[')     // 'addr' must be plain, without brackets
+    {
+        set_type(false);
+    }
+}
+
+std::ostream& operator<<(std::ostream& os, const Host& host)
+{
+    switch (host.type())
+    {
+    case Host::Type::Invalid:
+        os << "INVALID input: '" << host.org_input() << "' parsed to "
+           << host.address() << ":" << host.port();
+        break;
+
+    case Host::Type::UnixDomainSocket:
+        os << host.address();
+        break;
+
+    case Host::Type::HostName:
+    case Host::Type::IPV4:
+        os << host.address() << ':' << host.port();
+        break;
+
+    case Host::Type::IPV6:
+        os << '[' << host.address() << "]:" << host.port();
+        break;
+    }
+    return os;
+}
+
+std::istream& operator>>(std::istream& is, Host& host)
+{
+    std::string input;
+    is >> input;
+    host = Host(input);
+    return is;
+}
+}