diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn index 77bff8d6c5..49008a2ae3 100644 --- a/rtc_base/BUILD.gn +++ b/rtc_base/BUILD.gn @@ -835,6 +835,7 @@ rtc_library("rtc_base") { "system:no_unique_address", "system:rtc_export", "task_utils:pending_task_safety_flag", + "task_utils:repeating_task", "task_utils:to_queued_task", "third_party/base64", "third_party/sigslot", @@ -1425,6 +1426,7 @@ if (rtc_include_tests) { "memory:fifo_buffer", "synchronization:mutex", "synchronization:synchronization_unittests", + "task_utils:pending_task_safety_flag", "task_utils:to_queued_task", "third_party/sigslot", ] diff --git a/rtc_base/memory/BUILD.gn b/rtc_base/memory/BUILD.gn index 5c3dd0a5d1..838fbc68d4 100644 --- a/rtc_base/memory/BUILD.gn +++ b/rtc_base/memory/BUILD.gn @@ -20,12 +20,15 @@ rtc_library("aligned_malloc") { deps = [ "..:checks" ] } +# Test only utility. +# TODO: Tag with `testonly = true` once all depending targets are correctly +# tagged. rtc_library("fifo_buffer") { visibility = [ - "../../p2p:rtc_p2p", + ":unittests", "..:rtc_base_tests_utils", "..:rtc_base_unittests", - ":unittests", + "../../p2p:rtc_p2p", # This needs to be fixed. ] sources = [ "fifo_buffer.cc", @@ -34,6 +37,8 @@ rtc_library("fifo_buffer") { deps = [ "..:rtc_base", "../synchronization:mutex", + "../task_utils:pending_task_safety_flag", + "../task_utils:to_queued_task", ] } diff --git a/rtc_base/memory/fifo_buffer.cc b/rtc_base/memory/fifo_buffer.cc index 49e926719f..3fbea8dc20 100644 --- a/rtc_base/memory/fifo_buffer.cc +++ b/rtc_base/memory/fifo_buffer.cc @@ -104,7 +104,7 @@ StreamResult FifoBuffer::Read(void* buffer, // if we were full before, and now we're not, post an event if (!was_writable && copy > 0) { - PostEvent(owner_, SE_WRITE, 0); + PostEvent(SE_WRITE, 0); } } return result; @@ -129,7 +129,7 @@ StreamResult FifoBuffer::Write(const void* buffer, // if we didn't have any data to read before, and now we do, post an event if (!was_readable && copy > 0) { - PostEvent(owner_, SE_READ, 0); + PostEvent(SE_READ, 0); } } return result; @@ -155,7 +155,7 @@ void FifoBuffer::ConsumeReadData(size_t size) { read_position_ = (read_position_ + size) % buffer_length_; data_length_ -= size; if (!was_writable && size > 0) { - PostEvent(owner_, SE_WRITE, 0); + PostEvent(SE_WRITE, 0); } } @@ -185,7 +185,7 @@ void FifoBuffer::ConsumeWriteBuffer(size_t size) { const bool was_readable = (data_length_ > 0); data_length_ += size; if (!was_readable && size > 0) { - PostEvent(owner_, SE_READ, 0); + PostEvent(SE_READ, 0); } } diff --git a/rtc_base/memory/fifo_buffer.h b/rtc_base/memory/fifo_buffer.h index 04c4cbf33b..bf2edf6e24 100644 --- a/rtc_base/memory/fifo_buffer.h +++ b/rtc_base/memory/fifo_buffer.h @@ -15,6 +15,8 @@ #include "rtc_base/stream.h" #include "rtc_base/synchronization/mutex.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" +#include "rtc_base/task_utils/to_queued_task.h" namespace rtc { @@ -98,6 +100,12 @@ class FifoBuffer final : public StreamInterface { bool GetWriteRemaining(size_t* size) const; private: + void PostEvent(int events, int err) { + owner_->PostTask(webrtc::ToQueuedTask(task_safety_, [this, events, err]() { + SignalEvent(this, events, err); + })); + } + // Helper method that implements ReadOffset. Caller must acquire a lock // when calling this method. StreamResult ReadOffsetLocked(void* buffer, @@ -114,6 +122,8 @@ class FifoBuffer final : public StreamInterface { size_t* bytes_written) RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + webrtc::ScopedTaskSafety task_safety_; + // keeps the opened/closed state of the stream StreamState state_ RTC_GUARDED_BY(mutex_); // the allocated buffer @@ -125,7 +135,7 @@ class FifoBuffer final : public StreamInterface { // offset to the readable data size_t read_position_ RTC_GUARDED_BY(mutex_); // stream callbacks are dispatched on this thread - Thread* owner_; + Thread* const owner_; // object lock mutable webrtc::Mutex mutex_; RTC_DISALLOW_COPY_AND_ASSIGN(FifoBuffer); diff --git a/rtc_base/openssl_stream_adapter.cc b/rtc_base/openssl_stream_adapter.cc index 0f8c1fcb13..426100b150 100644 --- a/rtc_base/openssl_stream_adapter.cc +++ b/rtc_base/openssl_stream_adapter.cc @@ -35,6 +35,7 @@ #include "rtc_base/openssl_identity.h" #include "rtc_base/ssl_certificate.h" #include "rtc_base/stream.h" +#include "rtc_base/task_utils/to_queued_task.h" #include "rtc_base/thread.h" #include "rtc_base/time_utils.h" #include "system_wrappers/include/field_trial.h" @@ -51,7 +52,6 @@ namespace rtc { namespace { - // SRTP cipher suite table. |internal_name| is used to construct a // colon-separated profile strings which is needed by // SSL_CTX_set_tlsext_use_srtp(). @@ -284,6 +284,7 @@ bool ShouldAllowLegacyTLSProtocols() { OpenSSLStreamAdapter::OpenSSLStreamAdapter( std::unique_ptr stream) : SSLStreamAdapter(std::move(stream)), + owner_(rtc::Thread::Current()), state_(SSL_NONE), role_(SSL_CLIENT), ssl_read_needs_write_(false), @@ -297,6 +298,7 @@ OpenSSLStreamAdapter::OpenSSLStreamAdapter( support_legacy_tls_protocols_flag_(ShouldAllowLegacyTLSProtocols()) {} OpenSSLStreamAdapter::~OpenSSLStreamAdapter() { + timeout_task_.Stop(); Cleanup(0); } @@ -802,6 +804,33 @@ void OpenSSLStreamAdapter::OnEvent(StreamInterface* stream, } } +void OpenSSLStreamAdapter::PostEvent(int events, int err) { + owner_->PostTask(webrtc::ToQueuedTask( + task_safety_, [this, events, err]() { SignalEvent(this, events, err); })); +} + +void OpenSSLStreamAdapter::SetTimeout(int delay_ms) { + // We need to accept 0 delay here as well as >0 delay, because + // DTLSv1_get_timeout seems to frequently return 0 ms. + RTC_DCHECK_GE(delay_ms, 0); + RTC_DCHECK(!timeout_task_.Running()); + + timeout_task_ = webrtc::RepeatingTaskHandle::DelayedStart( + owner_, webrtc::TimeDelta::Millis(delay_ms), + [flag = task_safety_.flag(), this]() { + if (flag->alive()) { + RTC_LOG(LS_INFO) << "DTLS timeout expired"; + timeout_task_.Stop(); + DTLSv1_handle_timeout(ssl_); + ContinueSSL(); + } else { + RTC_NOTREACHED(); + } + // This callback will never run again (stopped above). + return webrtc::TimeDelta::PlusInfinity(); + }); +} + int OpenSSLStreamAdapter::BeginSSL() { RTC_DCHECK(state_ == SSL_CONNECTING); // The underlying stream has opened. @@ -852,7 +881,7 @@ int OpenSSLStreamAdapter::ContinueSSL() { RTC_DCHECK(state_ == SSL_CONNECTING); // Clear the DTLS timer - Thread::Current()->Clear(this, MSG_TIMEOUT); + timeout_task_.Stop(); const int code = (role_ == SSL_CLIENT) ? SSL_connect(ssl_) : SSL_accept(ssl_); const int ssl_error = SSL_get_error(ssl_, code); @@ -884,9 +913,7 @@ int OpenSSLStreamAdapter::ContinueSSL() { struct timeval timeout; if (DTLSv1_get_timeout(ssl_, &timeout)) { int delay = timeout.tv_sec * 1000 + timeout.tv_usec / 1000; - - Thread::Current()->PostDelayed(RTC_FROM_HERE, delay, this, MSG_TIMEOUT, - 0); + SetTimeout(delay); } } break; @@ -963,18 +990,7 @@ void OpenSSLStreamAdapter::Cleanup(uint8_t alert) { peer_cert_chain_.reset(); // Clear the DTLS timer - Thread::Current()->Clear(this, MSG_TIMEOUT); -} - -void OpenSSLStreamAdapter::OnMessage(Message* msg) { - // Process our own messages and then pass others to the superclass - if (MSG_TIMEOUT == msg->message_id) { - RTC_LOG(LS_INFO) << "DTLS timeout expired"; - DTLSv1_handle_timeout(ssl_); - ContinueSSL(); - } else { - StreamInterface::OnMessage(msg); - } + timeout_task_.Stop(); } SSL_CTX* OpenSSLStreamAdapter::SetupSSLContext() { diff --git a/rtc_base/openssl_stream_adapter.h b/rtc_base/openssl_stream_adapter.h index d4cde84d74..fbfccd6844 100644 --- a/rtc_base/openssl_stream_adapter.h +++ b/rtc_base/openssl_stream_adapter.h @@ -26,6 +26,8 @@ #include "rtc_base/ssl_stream_adapter.h" #include "rtc_base/stream.h" #include "rtc_base/system/rtc_export.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" +#include "rtc_base/task_utils/repeating_task.h" namespace rtc { @@ -145,7 +147,8 @@ class OpenSSLStreamAdapter final : public SSLStreamAdapter { SSL_CLOSED // Clean close }; - enum { MSG_TIMEOUT = MSG_MAX + 1 }; + void PostEvent(int events, int err); + void SetTimeout(int delay_ms); // The following three methods return 0 on success and a negative // error code on failure. The error code may be from OpenSSL or -1 @@ -169,9 +172,6 @@ class OpenSSLStreamAdapter final : public SSLStreamAdapter { void Error(const char* context, int err, uint8_t alert, bool signal); void Cleanup(uint8_t alert); - // Override MessageHandler - void OnMessage(Message* msg) override; - // Flush the input buffers by reading left bytes (for DTLS) void FlushInput(unsigned int left); @@ -192,6 +192,10 @@ class OpenSSLStreamAdapter final : public SSLStreamAdapter { !peer_certificate_digest_value_.empty(); } + rtc::Thread* const owner_; + webrtc::ScopedTaskSafety task_safety_; + webrtc::RepeatingTaskHandle timeout_task_; + SSLState state_; SSLRole role_; int ssl_error_code_; // valid when state_ == SSL_ERROR or SSL_CLOSED diff --git a/rtc_base/ssl_stream_adapter_unittest.cc b/rtc_base/ssl_stream_adapter_unittest.cc index bfbaf0f301..1ba2f3e259 100644 --- a/rtc_base/ssl_stream_adapter_unittest.cc +++ b/rtc_base/ssl_stream_adapter_unittest.cc @@ -26,6 +26,8 @@ #include "rtc_base/ssl_identity.h" #include "rtc_base/ssl_stream_adapter.h" #include "rtc_base/stream.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" +#include "rtc_base/task_utils/to_queued_task.h" #include "test/field_trial.h" using ::testing::Combine; @@ -214,7 +216,15 @@ class SSLDummyStreamBase : public rtc::StreamInterface, out_->Close(); } - protected: + private: + void PostEvent(int events, int err) { + thread_->PostTask(webrtc::ToQueuedTask(task_safety_, [this, events, err]() { + SignalEvent(this, events, err); + })); + } + + webrtc::ScopedTaskSafety task_safety_; + rtc::Thread* const thread_ = rtc::Thread::Current(); SSLStreamAdapterTestBase* test_base_; const std::string side_; rtc::StreamInterface* in_; @@ -276,10 +286,17 @@ class BufferQueueStream : public rtc::StreamInterface { protected: void NotifyReadableForTest() { PostEvent(rtc::SE_READ, 0); } - void NotifyWritableForTest() { PostEvent(rtc::SE_WRITE, 0); } private: + void PostEvent(int events, int err) { + thread_->PostTask(webrtc::ToQueuedTask(task_safety_, [this, events, err]() { + SignalEvent(this, events, err); + })); + } + + rtc::Thread* const thread_ = rtc::Thread::Current(); + webrtc::ScopedTaskSafety task_safety_; rtc::BufferQueue buffer_; }; diff --git a/rtc_base/stream.cc b/rtc_base/stream.cc index 1b0a4d759b..ee72f8d2b8 100644 --- a/rtc_base/stream.cc +++ b/rtc_base/stream.cc @@ -24,7 +24,6 @@ namespace rtc { /////////////////////////////////////////////////////////////////////////////// // StreamInterface /////////////////////////////////////////////////////////////////////////////// -StreamInterface::~StreamInterface() {} StreamResult StreamInterface::WriteAll(const void* data, size_t data_len, @@ -44,29 +43,12 @@ StreamResult StreamInterface::WriteAll(const void* data, return result; } -void StreamInterface::PostEvent(Thread* t, int events, int err) { - t->Post(RTC_FROM_HERE, this, MSG_POST_EVENT, - new StreamEventData(events, err)); -} - -void StreamInterface::PostEvent(int events, int err) { - PostEvent(Thread::Current(), events, err); -} - bool StreamInterface::Flush() { return false; } StreamInterface::StreamInterface() {} -void StreamInterface::OnMessage(Message* msg) { - if (MSG_POST_EVENT == msg->message_id) { - StreamEventData* pe = static_cast(msg->pdata); - SignalEvent(this, pe->events, pe->error); - delete msg->pdata; - } -} - /////////////////////////////////////////////////////////////////////////////// // StreamAdapterInterface /////////////////////////////////////////////////////////////////////////////// diff --git a/rtc_base/stream.h b/rtc_base/stream.h index 940bfb4ba4..9bf11a2405 100644 --- a/rtc_base/stream.h +++ b/rtc_base/stream.h @@ -48,16 +48,9 @@ enum StreamResult { SR_ERROR, SR_SUCCESS, SR_BLOCK, SR_EOS }; // SE_WRITE: Data can be written, so Write is likely to not return SR_BLOCK enum StreamEvent { SE_OPEN = 1, SE_READ = 2, SE_WRITE = 4, SE_CLOSE = 8 }; -struct StreamEventData : public MessageData { - int events, error; - StreamEventData(int ev, int er) : events(ev), error(er) {} -}; - -class RTC_EXPORT StreamInterface : public MessageHandlerAutoCleanup { +class RTC_EXPORT StreamInterface { public: - enum { MSG_POST_EVENT = 0xF1F1, MSG_MAX = MSG_POST_EVENT }; - - ~StreamInterface() override; + virtual ~StreamInterface() {} virtual StreamState GetState() const = 0; @@ -96,13 +89,6 @@ class RTC_EXPORT StreamInterface : public MessageHandlerAutoCleanup { // certain events will be raised in the future. sigslot::signal3 SignalEvent; - // Like calling SignalEvent, but posts a message to the specified thread, - // which will call SignalEvent. This helps unroll the stack and prevent - // re-entrancy. - void PostEvent(Thread* t, int events, int err); - // Like the aforementioned method, but posts to the current thread. - void PostEvent(int events, int err); - // Return true if flush is successful. virtual bool Flush(); @@ -125,9 +111,6 @@ class RTC_EXPORT StreamInterface : public MessageHandlerAutoCleanup { protected: StreamInterface(); - // MessageHandler Interface - void OnMessage(Message* msg) override; - private: RTC_DISALLOW_COPY_AND_ASSIGN(StreamInterface); };