Reland "Fix race between destroying SctpTransport and receiving notification on timer thread."

This reverts commit 8a38b1cf681cd77f0d59a68fb45d8dedbd7d4cee.

Reason for reland: Problem was identified; has something to do with
the unique_ptr with the custom deleter.

Original change's description:
> Revert "Fix race between destroying SctpTransport and receiving notification on timer thread."
>
> This reverts commit a88fe7be146b9b85575504d4d5193c007f2e3de4.
>
> Reason for revert: Breaks downstream test, still investigating.
>
> Original change's description:
> > Fix race between destroying SctpTransport and receiving notification on timer thread.
> >
> > This gets rid of the SctpTransportMap::Retrieve method and forces
> > everything to go through PostToTransportThread, which behaves safely
> > with relation to the transport's destruction.
> >
> > Bug: webrtc:12467
> > Change-Id: Id4a723c2c985be2a368d2cc5c5e62deb04c509ab
> > Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/208800
> > Reviewed-by: Niels Moller <nisse@webrtc.org>
> > Commit-Queue: Taylor <deadbeef@webrtc.org>
> > Cr-Commit-Position: refs/heads/master@{#33364}
>
> TBR=nisse@webrtc.org
>
> Bug: webrtc:12467
> Change-Id: Ib5d815a2cbca4feb25f360bff7ed62c02d1910a0
> Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/209820
> Reviewed-by: Taylor <deadbeef@webrtc.org>
> Commit-Queue: Taylor <deadbeef@webrtc.org>
> Cr-Commit-Position: refs/heads/master@{#33386}

Bug: webrtc:12467
Change-Id: I5f9fcd6df7a211e6edfa64577fc953833f4d9b79
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/210040
Reviewed-by: Niels Moller <nisse@webrtc.org>
Reviewed-by: Florent Castelli <orphis@webrtc.org>
Commit-Queue: Taylor <deadbeef@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#33427}
This commit is contained in:
Taylor Brandstetter
2021-03-08 15:36:53 -08:00
committed by Commit Bot
parent 34fdc92119
commit b2e71b8b35
3 changed files with 119 additions and 95 deletions

View File

@ -20,6 +20,7 @@ enum PreservedErrno {
// Successful return value from usrsctp callbacks. Is not actually used by // Successful return value from usrsctp callbacks. Is not actually used by
// usrsctp, but all example programs for usrsctp use 1 as their return value. // usrsctp, but all example programs for usrsctp use 1 as their return value.
constexpr int kSctpSuccessReturn = 1; constexpr int kSctpSuccessReturn = 1;
constexpr int kSctpErrorReturn = 0;
} // namespace } // namespace
@ -27,7 +28,6 @@ constexpr int kSctpSuccessReturn = 1;
#include <stdio.h> #include <stdio.h>
#include <usrsctp.h> #include <usrsctp.h>
#include <functional>
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
@ -96,6 +96,21 @@ enum {
// Should only be modified by UsrSctpWrapper. // Should only be modified by UsrSctpWrapper.
ABSL_CONST_INIT cricket::SctpTransportMap* g_transport_map_ = nullptr; ABSL_CONST_INIT cricket::SctpTransportMap* g_transport_map_ = nullptr;
// Helper that will call C's free automatically.
// TODO(b/181900299): Figure out why unique_ptr with a custom deleter is causing
// issues in a certain build environment.
class AutoFreedPointer {
public:
explicit AutoFreedPointer(void* ptr) : ptr_(ptr) {}
AutoFreedPointer(AutoFreedPointer&& o) : ptr_(o.ptr_) { o.ptr_ = nullptr; }
~AutoFreedPointer() { free(ptr_); }
void* get() const { return ptr_; }
private:
void* ptr_;
};
// Helper for logging SCTP messages. // Helper for logging SCTP messages.
#if defined(__GNUC__) #if defined(__GNUC__)
__attribute__((__format__(__printf__, 1, 2))) __attribute__((__format__(__printf__, 1, 2)))
@ -252,31 +267,20 @@ class SctpTransportMap {
return map_.erase(id) > 0; return map_.erase(id) > 0;
} }
// Must be called on the transport's network thread to protect against
// simultaneous deletion/deregistration of the transport; if that's not
// guaranteed, use ExecuteWithLock.
SctpTransport* Retrieve(uintptr_t id) const {
webrtc::MutexLock lock(&lock_);
SctpTransport* transport = RetrieveWhileHoldingLock(id);
if (transport) {
RTC_DCHECK_RUN_ON(transport->network_thread());
}
return transport;
}
// Posts |action| to the network thread of the transport identified by |id| // Posts |action| to the network thread of the transport identified by |id|
// and returns true if found, all while holding a lock to protect against the // and returns true if found, all while holding a lock to protect against the
// transport being simultaneously deleted/deregistered, or returns false if // transport being simultaneously deleted/deregistered, or returns false if
// not found. // not found.
bool PostToTransportThread(uintptr_t id, template <typename F>
std::function<void(SctpTransport*)> action) const { bool PostToTransportThread(uintptr_t id, F action) const {
webrtc::MutexLock lock(&lock_); webrtc::MutexLock lock(&lock_);
SctpTransport* transport = RetrieveWhileHoldingLock(id); SctpTransport* transport = RetrieveWhileHoldingLock(id);
if (!transport) { if (!transport) {
return false; return false;
} }
transport->network_thread_->PostTask(ToQueuedTask( transport->network_thread_->PostTask(ToQueuedTask(
transport->task_safety_, [transport, action]() { action(transport); })); transport->task_safety_,
[transport, action{std::move(action)}]() { action(transport); }));
return true; return true;
} }
@ -429,7 +433,7 @@ class SctpTransport::UsrSctpWrapper {
if (!found) { if (!found) {
RTC_LOG(LS_ERROR) RTC_LOG(LS_ERROR)
<< "OnSctpOutboundPacket: Failed to get transport for socket ID " << "OnSctpOutboundPacket: Failed to get transport for socket ID "
<< addr; << addr << "; possibly was already destroyed.";
return EINVAL; return EINVAL;
} }
@ -447,28 +451,46 @@ class SctpTransport::UsrSctpWrapper {
struct sctp_rcvinfo rcv, struct sctp_rcvinfo rcv,
int flags, int flags,
void* ulp_info) { void* ulp_info) {
SctpTransport* transport = GetTransportFromSocket(sock); AutoFreedPointer owned_data(data);
if (!transport) {
absl::optional<uintptr_t> id = GetTransportIdFromSocket(sock);
if (!id) {
RTC_LOG(LS_ERROR) RTC_LOG(LS_ERROR)
<< "OnSctpInboundPacket: Failed to get transport for socket " << sock << "OnSctpInboundPacket: Failed to get transport ID from socket "
<< "; possibly was already destroyed."; << sock;
free(data); return kSctpErrorReturn;
return 0;
} }
// Sanity check that both methods of getting the SctpTransport pointer
// yield the same result. if (!g_transport_map_) {
RTC_CHECK_EQ(transport, static_cast<SctpTransport*>(ulp_info)); RTC_LOG(LS_ERROR)
int result = << "OnSctpInboundPacket called after usrsctp uninitialized?";
transport->OnDataOrNotificationFromSctp(data, length, rcv, flags); return kSctpErrorReturn;
free(data); }
return result; // PostsToTransportThread protects against the transport being
// simultaneously deregistered/deleted, since this callback may come from
// the SCTP timer thread and thus race with the network thread.
bool found = g_transport_map_->PostToTransportThread(
*id, [owned_data{std::move(owned_data)}, length, rcv,
flags](SctpTransport* transport) {
transport->OnDataOrNotificationFromSctp(owned_data.get(), length, rcv,
flags);
});
if (!found) {
RTC_LOG(LS_ERROR)
<< "OnSctpInboundPacket: Failed to get transport for socket ID "
<< *id << "; possibly was already destroyed.";
return kSctpErrorReturn;
}
return kSctpSuccessReturn;
} }
static SctpTransport* GetTransportFromSocket(struct socket* sock) { static absl::optional<uintptr_t> GetTransportIdFromSocket(
struct socket* sock) {
absl::optional<uintptr_t> ret;
struct sockaddr* addrs = nullptr; struct sockaddr* addrs = nullptr;
int naddrs = usrsctp_getladdrs(sock, 0, &addrs); int naddrs = usrsctp_getladdrs(sock, 0, &addrs);
if (naddrs <= 0 || addrs[0].sa_family != AF_CONN) { if (naddrs <= 0 || addrs[0].sa_family != AF_CONN) {
return nullptr; return ret;
} }
// usrsctp_getladdrs() returns the addresses bound to this socket, which // usrsctp_getladdrs() returns the addresses bound to this socket, which
// contains the SctpTransport id as sconn_addr. Read the id, // contains the SctpTransport id as sconn_addr. Read the id,
@ -477,17 +499,10 @@ class SctpTransport::UsrSctpWrapper {
// id of the transport that created them, so [0] is as good as any other. // id of the transport that created them, so [0] is as good as any other.
struct sockaddr_conn* sconn = struct sockaddr_conn* sconn =
reinterpret_cast<struct sockaddr_conn*>(&addrs[0]); reinterpret_cast<struct sockaddr_conn*>(&addrs[0]);
if (!g_transport_map_) { ret = reinterpret_cast<uintptr_t>(sconn->sconn_addr);
RTC_LOG(LS_ERROR)
<< "GetTransportFromSocket called after usrsctp uninitialized?";
usrsctp_freeladdrs(addrs);
return nullptr;
}
SctpTransport* transport = g_transport_map_->Retrieve(
reinterpret_cast<uintptr_t>(sconn->sconn_addr));
usrsctp_freeladdrs(addrs); usrsctp_freeladdrs(addrs);
return transport; return ret;
} }
// TODO(crbug.com/webrtc/11899): This is a legacy callback signature, remove // TODO(crbug.com/webrtc/11899): This is a legacy callback signature, remove
@ -496,14 +511,26 @@ class SctpTransport::UsrSctpWrapper {
// Fired on our I/O thread. SctpTransport::OnPacketReceived() gets // Fired on our I/O thread. SctpTransport::OnPacketReceived() gets
// a packet containing acknowledgments, which goes into usrsctp_conninput, // a packet containing acknowledgments, which goes into usrsctp_conninput,
// and then back here. // and then back here.
SctpTransport* transport = GetTransportFromSocket(sock); absl::optional<uintptr_t> id = GetTransportIdFromSocket(sock);
if (!transport) { if (!id) {
RTC_LOG(LS_ERROR) RTC_LOG(LS_ERROR)
<< "SendThresholdCallback: Failed to get transport for socket " << "SendThresholdCallback: Failed to get transport ID from socket "
<< sock << "; possibly was already destroyed."; << sock;
return 0; return 0;
} }
transport->OnSendThresholdCallback(); if (!g_transport_map_) {
RTC_LOG(LS_ERROR)
<< "SendThresholdCallback called after usrsctp uninitialized?";
return 0;
}
bool found = g_transport_map_->PostToTransportThread(
*id,
[](SctpTransport* transport) { transport->OnSendThresholdCallback(); });
if (!found) {
RTC_LOG(LS_ERROR)
<< "SendThresholdCallback: Failed to get transport for socket ID "
<< *id << "; possibly was already destroyed.";
}
return 0; return 0;
} }
@ -513,17 +540,26 @@ class SctpTransport::UsrSctpWrapper {
// Fired on our I/O thread. SctpTransport::OnPacketReceived() gets // Fired on our I/O thread. SctpTransport::OnPacketReceived() gets
// a packet containing acknowledgments, which goes into usrsctp_conninput, // a packet containing acknowledgments, which goes into usrsctp_conninput,
// and then back here. // and then back here.
SctpTransport* transport = GetTransportFromSocket(sock); absl::optional<uintptr_t> id = GetTransportIdFromSocket(sock);
if (!transport) { if (!id) {
RTC_LOG(LS_ERROR) RTC_LOG(LS_ERROR)
<< "SendThresholdCallback: Failed to get transport for socket " << "SendThresholdCallback: Failed to get transport ID from socket "
<< sock << "; possibly was already destroyed."; << sock;
return 0; return 0;
} }
// Sanity check that both methods of getting the SctpTransport pointer if (!g_transport_map_) {
// yield the same result. RTC_LOG(LS_ERROR)
RTC_CHECK_EQ(transport, static_cast<SctpTransport*>(ulp_info)); << "SendThresholdCallback called after usrsctp uninitialized?";
transport->OnSendThresholdCallback(); return 0;
}
bool found = g_transport_map_->PostToTransportThread(
*id,
[](SctpTransport* transport) { transport->OnSendThresholdCallback(); });
if (!found) {
RTC_LOG(LS_ERROR)
<< "SendThresholdCallback: Failed to get transport for socket ID "
<< *id << "; possibly was already destroyed.";
}
return 0; return 0;
} }
}; };
@ -1175,24 +1211,25 @@ void SctpTransport::OnPacketFromSctpToNetwork(
rtc::PacketOptions(), PF_NORMAL); rtc::PacketOptions(), PF_NORMAL);
} }
int SctpTransport::InjectDataOrNotificationFromSctpForTesting( void SctpTransport::InjectDataOrNotificationFromSctpForTesting(
const void* data, const void* data,
size_t length, size_t length,
struct sctp_rcvinfo rcv, struct sctp_rcvinfo rcv,
int flags) { int flags) {
return OnDataOrNotificationFromSctp(data, length, rcv, flags); OnDataOrNotificationFromSctp(data, length, rcv, flags);
} }
int SctpTransport::OnDataOrNotificationFromSctp(const void* data, void SctpTransport::OnDataOrNotificationFromSctp(const void* data,
size_t length, size_t length,
struct sctp_rcvinfo rcv, struct sctp_rcvinfo rcv,
int flags) { int flags) {
RTC_DCHECK_RUN_ON(network_thread_);
// If data is NULL, the SCTP association has been closed. // If data is NULL, the SCTP association has been closed.
if (!data) { if (!data) {
RTC_LOG(LS_INFO) << debug_name_ RTC_LOG(LS_INFO) << debug_name_
<< "->OnDataOrNotificationFromSctp(...): " << "->OnDataOrNotificationFromSctp(...): "
"No data; association closed."; "No data; association closed.";
return kSctpSuccessReturn; return;
} }
// Handle notifications early. // Handle notifications early.
@ -1205,14 +1242,10 @@ int SctpTransport::OnDataOrNotificationFromSctp(const void* data,
<< "->OnDataOrNotificationFromSctp(...): SCTP notification" << "->OnDataOrNotificationFromSctp(...): SCTP notification"
<< " length=" << length; << " length=" << length;
// Copy and dispatch asynchronously
rtc::CopyOnWriteBuffer notification(reinterpret_cast<const uint8_t*>(data), rtc::CopyOnWriteBuffer notification(reinterpret_cast<const uint8_t*>(data),
length); length);
network_thread_->PostTask(ToQueuedTask( OnNotificationFromSctp(notification);
task_safety_, [this, notification = std::move(notification)]() { return;
OnNotificationFromSctp(notification);
}));
return kSctpSuccessReturn;
} }
// Log data chunk // Log data chunk
@ -1230,7 +1263,7 @@ int SctpTransport::OnDataOrNotificationFromSctp(const void* data,
// Unexpected PPID, dropping // Unexpected PPID, dropping
RTC_LOG(LS_ERROR) << "Received an unknown PPID " << ppid RTC_LOG(LS_ERROR) << "Received an unknown PPID " << ppid
<< " on an SCTP packet. Dropping."; << " on an SCTP packet. Dropping.";
return kSctpSuccessReturn; return;
} }
// Expect only continuation messages belonging to the same SID. The SCTP // Expect only continuation messages belonging to the same SID. The SCTP
@ -1266,7 +1299,7 @@ int SctpTransport::OnDataOrNotificationFromSctp(const void* data,
if (partial_incoming_message_.size() < kSctpSendBufferSize) { if (partial_incoming_message_.size() < kSctpSendBufferSize) {
// We still have space in the buffer. Continue buffering chunks until // We still have space in the buffer. Continue buffering chunks until
// the message is complete before handing it out. // the message is complete before handing it out.
return kSctpSuccessReturn; return;
} else { } else {
// The sender is exceeding the maximum message size that we announced. // The sender is exceeding the maximum message size that we announced.
// Spit out a warning but still hand out the partial message. Note that // Spit out a warning but still hand out the partial message. Note that
@ -1280,18 +1313,9 @@ int SctpTransport::OnDataOrNotificationFromSctp(const void* data,
} }
} }
// Dispatch the complete message. // Dispatch the complete message and reset the message buffer.
// The ownership of the packet transfers to |invoker_|. Using OnDataFromSctpToTransport(params, partial_incoming_message_);
// CopyOnWriteBuffer is the most convenient way to do this.
network_thread_->PostTask(webrtc::ToQueuedTask(
task_safety_, [this, params = std::move(params),
message = partial_incoming_message_]() {
OnDataFromSctpToTransport(params, message);
}));
// Reset the message buffer
partial_incoming_message_.Clear(); partial_incoming_message_.Clear();
return kSctpSuccessReturn;
} }
void SctpTransport::OnDataFromSctpToTransport( void SctpTransport::OnDataFromSctpToTransport(

View File

@ -96,10 +96,10 @@ class SctpTransport : public SctpTransportInternal,
void set_debug_name_for_testing(const char* debug_name) override { void set_debug_name_for_testing(const char* debug_name) override {
debug_name_ = debug_name; debug_name_ = debug_name;
} }
int InjectDataOrNotificationFromSctpForTesting(const void* data, void InjectDataOrNotificationFromSctpForTesting(const void* data,
size_t length, size_t length,
struct sctp_rcvinfo rcv, struct sctp_rcvinfo rcv,
int flags); int flags);
// Exposed to allow Post call from c-callbacks. // Exposed to allow Post call from c-callbacks.
// TODO(deadbeef): Remove this or at least make it return a const pointer. // TODO(deadbeef): Remove this or at least make it return a const pointer.
@ -180,12 +180,12 @@ class SctpTransport : public SctpTransportInternal,
// Called using |invoker_| to send packet on the network. // Called using |invoker_| to send packet on the network.
void OnPacketFromSctpToNetwork(const rtc::CopyOnWriteBuffer& buffer); void OnPacketFromSctpToNetwork(const rtc::CopyOnWriteBuffer& buffer);
// Called on the SCTP thread. // Called on the network thread.
// Flags are standard socket API flags (RFC 6458). // Flags are standard socket API flags (RFC 6458).
int OnDataOrNotificationFromSctp(const void* data, void OnDataOrNotificationFromSctp(const void* data,
size_t length, size_t length,
struct sctp_rcvinfo rcv, struct sctp_rcvinfo rcv,
int flags); int flags);
// Called using |invoker_| to decide what to do with the data. // Called using |invoker_| to decide what to do with the data.
void OnDataFromSctpToTransport(const ReceiveDataParams& params, void OnDataFromSctpToTransport(const ReceiveDataParams& params,
const rtc::CopyOnWriteBuffer& buffer); const rtc::CopyOnWriteBuffer& buffer);

View File

@ -282,8 +282,8 @@ TEST_F(SctpTransportTest, MessageInterleavedWithNotification) {
meta.rcv_tsn = 42; meta.rcv_tsn = 42;
meta.rcv_cumtsn = 42; meta.rcv_cumtsn = 42;
chunk.SetData("meow?", 5); chunk.SetData("meow?", 5);
EXPECT_EQ(1, transport1->InjectDataOrNotificationFromSctpForTesting( transport1->InjectDataOrNotificationFromSctpForTesting(chunk.data(),
chunk.data(), chunk.size(), meta, 0)); chunk.size(), meta, 0);
// Inject a notification in between chunks. // Inject a notification in between chunks.
union sctp_notification notification; union sctp_notification notification;
@ -292,15 +292,15 @@ TEST_F(SctpTransportTest, MessageInterleavedWithNotification) {
notification.sn_header.sn_type = SCTP_PEER_ADDR_CHANGE; notification.sn_header.sn_type = SCTP_PEER_ADDR_CHANGE;
notification.sn_header.sn_flags = 0; notification.sn_header.sn_flags = 0;
notification.sn_header.sn_length = sizeof(notification); notification.sn_header.sn_length = sizeof(notification);
EXPECT_EQ(1, transport1->InjectDataOrNotificationFromSctpForTesting( transport1->InjectDataOrNotificationFromSctpForTesting(
&notification, sizeof(notification), {0}, MSG_NOTIFICATION)); &notification, sizeof(notification), {0}, MSG_NOTIFICATION);
// Inject chunk 2/2 // Inject chunk 2/2
meta.rcv_tsn = 42; meta.rcv_tsn = 42;
meta.rcv_cumtsn = 43; meta.rcv_cumtsn = 43;
chunk.SetData(" rawr!", 6); chunk.SetData(" rawr!", 6);
EXPECT_EQ(1, transport1->InjectDataOrNotificationFromSctpForTesting( transport1->InjectDataOrNotificationFromSctpForTesting(
chunk.data(), chunk.size(), meta, MSG_EOR)); chunk.data(), chunk.size(), meta, MSG_EOR);
// Expect the message to contain both chunks. // Expect the message to contain both chunks.
EXPECT_TRUE_WAIT(ReceivedData(&recv1, 1, "meow? rawr!"), kDefaultTimeout); EXPECT_TRUE_WAIT(ReceivedData(&recv1, 1, "meow? rawr!"), kDefaultTimeout);