diff --git a/media/sctp/sctp_transport.cc b/media/sctp/sctp_transport.cc index 5878f459f4..539eebd50e 100644 --- a/media/sctp/sctp_transport.cc +++ b/media/sctp/sctp_transport.cc @@ -20,6 +20,7 @@ enum PreservedErrno { // Successful return value from usrsctp callbacks. Is not actually used by // usrsctp, but all example programs for usrsctp use 1 as their return value. constexpr int kSctpSuccessReturn = 1; +constexpr int kSctpErrorReturn = 0; } // namespace @@ -27,7 +28,6 @@ constexpr int kSctpSuccessReturn = 1; #include #include -#include #include #include @@ -96,6 +96,21 @@ enum { // Should only be modified by UsrSctpWrapper. ABSL_CONST_INIT cricket::SctpTransportMap* g_transport_map_ = nullptr; +// Helper that will call C's free automatically. +// TODO(b/181900299): Figure out why unique_ptr with a custom deleter is causing +// issues in a certain build environment. +class AutoFreedPointer { + public: + explicit AutoFreedPointer(void* ptr) : ptr_(ptr) {} + AutoFreedPointer(AutoFreedPointer&& o) : ptr_(o.ptr_) { o.ptr_ = nullptr; } + ~AutoFreedPointer() { free(ptr_); } + + void* get() const { return ptr_; } + + private: + void* ptr_; +}; + // Helper for logging SCTP messages. #if defined(__GNUC__) __attribute__((__format__(__printf__, 1, 2))) @@ -252,31 +267,20 @@ class SctpTransportMap { return map_.erase(id) > 0; } - // Must be called on the transport's network thread to protect against - // simultaneous deletion/deregistration of the transport; if that's not - // guaranteed, use ExecuteWithLock. - SctpTransport* Retrieve(uintptr_t id) const { - webrtc::MutexLock lock(&lock_); - SctpTransport* transport = RetrieveWhileHoldingLock(id); - if (transport) { - RTC_DCHECK_RUN_ON(transport->network_thread()); - } - return transport; - } - // Posts |action| to the network thread of the transport identified by |id| // and returns true if found, all while holding a lock to protect against the // transport being simultaneously deleted/deregistered, or returns false if // not found. - bool PostToTransportThread(uintptr_t id, - std::function action) const { + template + bool PostToTransportThread(uintptr_t id, F action) const { webrtc::MutexLock lock(&lock_); SctpTransport* transport = RetrieveWhileHoldingLock(id); if (!transport) { return false; } transport->network_thread_->PostTask(ToQueuedTask( - transport->task_safety_, [transport, action]() { action(transport); })); + transport->task_safety_, + [transport, action{std::move(action)}]() { action(transport); })); return true; } @@ -429,7 +433,7 @@ class SctpTransport::UsrSctpWrapper { if (!found) { RTC_LOG(LS_ERROR) << "OnSctpOutboundPacket: Failed to get transport for socket ID " - << addr; + << addr << "; possibly was already destroyed."; return EINVAL; } @@ -447,28 +451,46 @@ class SctpTransport::UsrSctpWrapper { struct sctp_rcvinfo rcv, int flags, void* ulp_info) { - SctpTransport* transport = GetTransportFromSocket(sock); - if (!transport) { + AutoFreedPointer owned_data(data); + + absl::optional id = GetTransportIdFromSocket(sock); + if (!id) { RTC_LOG(LS_ERROR) - << "OnSctpInboundPacket: Failed to get transport for socket " << sock - << "; possibly was already destroyed."; - free(data); - return 0; + << "OnSctpInboundPacket: Failed to get transport ID from socket " + << sock; + return kSctpErrorReturn; } - // Sanity check that both methods of getting the SctpTransport pointer - // yield the same result. - RTC_CHECK_EQ(transport, static_cast(ulp_info)); - int result = - transport->OnDataOrNotificationFromSctp(data, length, rcv, flags); - free(data); - return result; + + if (!g_transport_map_) { + RTC_LOG(LS_ERROR) + << "OnSctpInboundPacket called after usrsctp uninitialized?"; + return kSctpErrorReturn; + } + // PostsToTransportThread protects against the transport being + // simultaneously deregistered/deleted, since this callback may come from + // the SCTP timer thread and thus race with the network thread. + bool found = g_transport_map_->PostToTransportThread( + *id, [owned_data{std::move(owned_data)}, length, rcv, + flags](SctpTransport* transport) { + transport->OnDataOrNotificationFromSctp(owned_data.get(), length, rcv, + flags); + }); + if (!found) { + RTC_LOG(LS_ERROR) + << "OnSctpInboundPacket: Failed to get transport for socket ID " + << *id << "; possibly was already destroyed."; + return kSctpErrorReturn; + } + return kSctpSuccessReturn; } - static SctpTransport* GetTransportFromSocket(struct socket* sock) { + static absl::optional GetTransportIdFromSocket( + struct socket* sock) { + absl::optional ret; struct sockaddr* addrs = nullptr; int naddrs = usrsctp_getladdrs(sock, 0, &addrs); if (naddrs <= 0 || addrs[0].sa_family != AF_CONN) { - return nullptr; + return ret; } // usrsctp_getladdrs() returns the addresses bound to this socket, which // contains the SctpTransport id as sconn_addr. Read the id, @@ -477,17 +499,10 @@ class SctpTransport::UsrSctpWrapper { // id of the transport that created them, so [0] is as good as any other. struct sockaddr_conn* sconn = reinterpret_cast(&addrs[0]); - if (!g_transport_map_) { - RTC_LOG(LS_ERROR) - << "GetTransportFromSocket called after usrsctp uninitialized?"; - usrsctp_freeladdrs(addrs); - return nullptr; - } - SctpTransport* transport = g_transport_map_->Retrieve( - reinterpret_cast(sconn->sconn_addr)); + ret = reinterpret_cast(sconn->sconn_addr); usrsctp_freeladdrs(addrs); - return transport; + return ret; } // TODO(crbug.com/webrtc/11899): This is a legacy callback signature, remove @@ -496,14 +511,26 @@ class SctpTransport::UsrSctpWrapper { // Fired on our I/O thread. SctpTransport::OnPacketReceived() gets // a packet containing acknowledgments, which goes into usrsctp_conninput, // and then back here. - SctpTransport* transport = GetTransportFromSocket(sock); - if (!transport) { + absl::optional id = GetTransportIdFromSocket(sock); + if (!id) { RTC_LOG(LS_ERROR) - << "SendThresholdCallback: Failed to get transport for socket " - << sock << "; possibly was already destroyed."; + << "SendThresholdCallback: Failed to get transport ID from socket " + << sock; return 0; } - transport->OnSendThresholdCallback(); + if (!g_transport_map_) { + RTC_LOG(LS_ERROR) + << "SendThresholdCallback called after usrsctp uninitialized?"; + return 0; + } + bool found = g_transport_map_->PostToTransportThread( + *id, + [](SctpTransport* transport) { transport->OnSendThresholdCallback(); }); + if (!found) { + RTC_LOG(LS_ERROR) + << "SendThresholdCallback: Failed to get transport for socket ID " + << *id << "; possibly was already destroyed."; + } return 0; } @@ -513,17 +540,26 @@ class SctpTransport::UsrSctpWrapper { // Fired on our I/O thread. SctpTransport::OnPacketReceived() gets // a packet containing acknowledgments, which goes into usrsctp_conninput, // and then back here. - SctpTransport* transport = GetTransportFromSocket(sock); - if (!transport) { + absl::optional id = GetTransportIdFromSocket(sock); + if (!id) { RTC_LOG(LS_ERROR) - << "SendThresholdCallback: Failed to get transport for socket " - << sock << "; possibly was already destroyed."; + << "SendThresholdCallback: Failed to get transport ID from socket " + << sock; return 0; } - // Sanity check that both methods of getting the SctpTransport pointer - // yield the same result. - RTC_CHECK_EQ(transport, static_cast(ulp_info)); - transport->OnSendThresholdCallback(); + if (!g_transport_map_) { + RTC_LOG(LS_ERROR) + << "SendThresholdCallback called after usrsctp uninitialized?"; + return 0; + } + bool found = g_transport_map_->PostToTransportThread( + *id, + [](SctpTransport* transport) { transport->OnSendThresholdCallback(); }); + if (!found) { + RTC_LOG(LS_ERROR) + << "SendThresholdCallback: Failed to get transport for socket ID " + << *id << "; possibly was already destroyed."; + } return 0; } }; @@ -1175,24 +1211,25 @@ void SctpTransport::OnPacketFromSctpToNetwork( rtc::PacketOptions(), PF_NORMAL); } -int SctpTransport::InjectDataOrNotificationFromSctpForTesting( +void SctpTransport::InjectDataOrNotificationFromSctpForTesting( const void* data, size_t length, struct sctp_rcvinfo rcv, int flags) { - return OnDataOrNotificationFromSctp(data, length, rcv, flags); + OnDataOrNotificationFromSctp(data, length, rcv, flags); } -int SctpTransport::OnDataOrNotificationFromSctp(const void* data, - size_t length, - struct sctp_rcvinfo rcv, - int flags) { +void SctpTransport::OnDataOrNotificationFromSctp(const void* data, + size_t length, + struct sctp_rcvinfo rcv, + int flags) { + RTC_DCHECK_RUN_ON(network_thread_); // If data is NULL, the SCTP association has been closed. if (!data) { RTC_LOG(LS_INFO) << debug_name_ << "->OnDataOrNotificationFromSctp(...): " "No data; association closed."; - return kSctpSuccessReturn; + return; } // Handle notifications early. @@ -1205,14 +1242,10 @@ int SctpTransport::OnDataOrNotificationFromSctp(const void* data, << "->OnDataOrNotificationFromSctp(...): SCTP notification" << " length=" << length; - // Copy and dispatch asynchronously rtc::CopyOnWriteBuffer notification(reinterpret_cast(data), length); - network_thread_->PostTask(ToQueuedTask( - task_safety_, [this, notification = std::move(notification)]() { - OnNotificationFromSctp(notification); - })); - return kSctpSuccessReturn; + OnNotificationFromSctp(notification); + return; } // Log data chunk @@ -1230,7 +1263,7 @@ int SctpTransport::OnDataOrNotificationFromSctp(const void* data, // Unexpected PPID, dropping RTC_LOG(LS_ERROR) << "Received an unknown PPID " << ppid << " on an SCTP packet. Dropping."; - return kSctpSuccessReturn; + return; } // Expect only continuation messages belonging to the same SID. The SCTP @@ -1266,7 +1299,7 @@ int SctpTransport::OnDataOrNotificationFromSctp(const void* data, if (partial_incoming_message_.size() < kSctpSendBufferSize) { // We still have space in the buffer. Continue buffering chunks until // the message is complete before handing it out. - return kSctpSuccessReturn; + return; } else { // The sender is exceeding the maximum message size that we announced. // Spit out a warning but still hand out the partial message. Note that @@ -1280,18 +1313,9 @@ int SctpTransport::OnDataOrNotificationFromSctp(const void* data, } } - // Dispatch the complete message. - // The ownership of the packet transfers to |invoker_|. Using - // CopyOnWriteBuffer is the most convenient way to do this. - network_thread_->PostTask(webrtc::ToQueuedTask( - task_safety_, [this, params = std::move(params), - message = partial_incoming_message_]() { - OnDataFromSctpToTransport(params, message); - })); - - // Reset the message buffer + // Dispatch the complete message and reset the message buffer. + OnDataFromSctpToTransport(params, partial_incoming_message_); partial_incoming_message_.Clear(); - return kSctpSuccessReturn; } void SctpTransport::OnDataFromSctpToTransport( diff --git a/media/sctp/sctp_transport.h b/media/sctp/sctp_transport.h index bd166ef332..e357e706ee 100644 --- a/media/sctp/sctp_transport.h +++ b/media/sctp/sctp_transport.h @@ -96,10 +96,10 @@ class SctpTransport : public SctpTransportInternal, void set_debug_name_for_testing(const char* debug_name) override { debug_name_ = debug_name; } - int InjectDataOrNotificationFromSctpForTesting(const void* data, - size_t length, - struct sctp_rcvinfo rcv, - int flags); + void InjectDataOrNotificationFromSctpForTesting(const void* data, + size_t length, + struct sctp_rcvinfo rcv, + int flags); // Exposed to allow Post call from c-callbacks. // TODO(deadbeef): Remove this or at least make it return a const pointer. @@ -180,12 +180,12 @@ class SctpTransport : public SctpTransportInternal, // Called using |invoker_| to send packet on the network. void OnPacketFromSctpToNetwork(const rtc::CopyOnWriteBuffer& buffer); - // Called on the SCTP thread. + // Called on the network thread. // Flags are standard socket API flags (RFC 6458). - int OnDataOrNotificationFromSctp(const void* data, - size_t length, - struct sctp_rcvinfo rcv, - int flags); + void OnDataOrNotificationFromSctp(const void* data, + size_t length, + struct sctp_rcvinfo rcv, + int flags); // Called using |invoker_| to decide what to do with the data. void OnDataFromSctpToTransport(const ReceiveDataParams& params, const rtc::CopyOnWriteBuffer& buffer); diff --git a/media/sctp/sctp_transport_unittest.cc b/media/sctp/sctp_transport_unittest.cc index 98a91225b2..120f4e5a27 100644 --- a/media/sctp/sctp_transport_unittest.cc +++ b/media/sctp/sctp_transport_unittest.cc @@ -282,8 +282,8 @@ TEST_F(SctpTransportTest, MessageInterleavedWithNotification) { meta.rcv_tsn = 42; meta.rcv_cumtsn = 42; chunk.SetData("meow?", 5); - EXPECT_EQ(1, transport1->InjectDataOrNotificationFromSctpForTesting( - chunk.data(), chunk.size(), meta, 0)); + transport1->InjectDataOrNotificationFromSctpForTesting(chunk.data(), + chunk.size(), meta, 0); // Inject a notification in between chunks. union sctp_notification notification; @@ -292,15 +292,15 @@ TEST_F(SctpTransportTest, MessageInterleavedWithNotification) { notification.sn_header.sn_type = SCTP_PEER_ADDR_CHANGE; notification.sn_header.sn_flags = 0; notification.sn_header.sn_length = sizeof(notification); - EXPECT_EQ(1, transport1->InjectDataOrNotificationFromSctpForTesting( - ¬ification, sizeof(notification), {0}, MSG_NOTIFICATION)); + transport1->InjectDataOrNotificationFromSctpForTesting( + ¬ification, sizeof(notification), {0}, MSG_NOTIFICATION); // Inject chunk 2/2 meta.rcv_tsn = 42; meta.rcv_cumtsn = 43; chunk.SetData(" rawr!", 6); - EXPECT_EQ(1, transport1->InjectDataOrNotificationFromSctpForTesting( - chunk.data(), chunk.size(), meta, MSG_EOR)); + transport1->InjectDataOrNotificationFromSctpForTesting( + chunk.data(), chunk.size(), meta, MSG_EOR); // Expect the message to contain both chunks. EXPECT_TRUE_WAIT(ReceivedData(&recv1, 1, "meow? rawr!"), kDefaultTimeout);