diff --git a/include/maxscale/listener.hh b/include/maxscale/listener.hh index f890d3af7..63ad66282 100644 --- a/include/maxscale/listener.hh +++ b/include/maxscale/listener.hh @@ -186,6 +186,16 @@ public: return m_type; } + /** + * Mark authentication as failed + * + * This updates the number of failures that have occurred from this host. If the number of authentications + * exceeds a certain value, any attempts to connect from the remote in quesion will be rejected. + * + * @param remote The address where the connection originated + */ + void mark_auth_as_failed(const std::string& remote); + // Functions that are temporarily public bool create_listener_config(const char* filename); struct users* users() const; diff --git a/server/core/listener.cc b/server/core/listener.cc index 53edaac57..da5963928 100644 --- a/server/core/listener.cc +++ b/server/core/listener.cc @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -40,6 +41,9 @@ #include "internal/modules.hh" #include "internal/session.hh" +using Clock = std::chrono::steady_clock; +using std::chrono::seconds; + static std::list all_listeners; static std::mutex listener_lock; @@ -47,6 +51,56 @@ static RSA* rsa_512 = NULL; static RSA* rsa_1024 = NULL; static RSA* tmp_rsa_callback(SSL* s, int is_export, int keylength); +// TODO: Make these configurable +constexpr int MAX_FAILURES = 25; +constexpr int BLOCK_TIME = 60; + +namespace +{ +class RateLimit +{ +public: + void auth_failed(const std::string& remote) + { + auto& u = m_failures[remote]; + u.last_failure = Clock::now(); + u.failures++; + } + + bool is_blocked(const std::string& remote) + { + bool rval = false; + auto it = m_failures.find(remote); + + if (it != m_failures.end()) + { + auto& u = it->second; + + if (Clock::now() - u.last_failure > seconds(BLOCK_TIME)) + { + u.last_failure = Clock::now(); + u.failures = 0; + } + + rval = u.failures >= MAX_FAILURES; + } + + return rval; + } + +private: + struct Failure + { + Clock::time_point last_failure = Clock::now(); + uint32_t failures = 0; + }; + + std::unordered_map m_failures; +}; + +thread_local RateLimit rate_limit; +} + Listener::Listener(SERVICE* service, const std::string& name, const std::string& address, uint16_t port, const std::string& protocol, const std::string& authenticator, const std::string& auth_opts, void* auth_instance, SSL_LISTENER* ssl) @@ -968,6 +1022,12 @@ static ClientConn accept_one_connection(int fd) } configure_network_socket(conn.fd, conn.addr.ss_family); + + if (rate_limit.is_blocked(conn.host)) + { + close(conn.fd); + conn.fd = -1; + } } else if (errno != EAGAIN && errno != EWOULDBLOCK) { @@ -1177,3 +1237,8 @@ void Listener::accept_connections() } } } + +void Listener::mark_auth_as_failed(const std::string& remote) +{ + rate_limit.auth_failed(remote); +} diff --git a/server/modules/protocol/MySQL/mariadbclient/mysql_client.cc b/server/modules/protocol/MySQL/mariadbclient/mysql_client.cc index 8bbafda32..9caf6aed0 100644 --- a/server/modules/protocol/MySQL/mariadbclient/mysql_client.cc +++ b/server/modules/protocol/MySQL/mariadbclient/mysql_client.cc @@ -818,6 +818,8 @@ static int gw_read_do_authentication(DCB* dcb, GWBUF* read_buffer, int nbytes_re { protocol->protocol_auth_state = MXS_AUTH_STATE_FAILED; mysql_client_auth_error_handling(dcb, auth_val, next_sequence); + mxb_assert(dcb->session->listener); + dcb->session->listener->mark_auth_as_failed(dcb->remote); /** * Close DCB and which will release MYSQL_session */