Revert "Fix race between destroying SctpTransport and receiving notification on timer thread."

This reverts commit a88fe7be146b9b85575504d4d5193c007f2e3de4.

Reason for revert: Breaks downstream test, still investigating.

Original change's description:
> Fix race between destroying SctpTransport and receiving notification on timer thread.
>
> This gets rid of the SctpTransportMap::Retrieve method and forces
> everything to go through PostToTransportThread, which behaves safely
> with relation to the transport's destruction.
>
> Bug: webrtc:12467
> Change-Id: Id4a723c2c985be2a368d2cc5c5e62deb04c509ab
> Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/208800
> Reviewed-by: Niels Moller <nisse@webrtc.org>
> Commit-Queue: Taylor <deadbeef@webrtc.org>
> Cr-Commit-Position: refs/heads/master@{#33364}

TBR=nisse@webrtc.org

Bug: webrtc:12467
Change-Id: Ib5d815a2cbca4feb25f360bff7ed62c02d1910a0
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/209820
Reviewed-by: Taylor <deadbeef@webrtc.org>
Commit-Queue: Taylor <deadbeef@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#33386}
This commit is contained in:
Taylor
2021-03-05 01:05:38 +00:00
committed by Commit Bot
parent 83be84bb74
commit 8a38b1cf68
3 changed files with 95 additions and 107 deletions

View File

@ -20,7 +20,6 @@ enum PreservedErrno {
// Successful return value from usrsctp callbacks. Is not actually used by
// usrsctp, but all example programs for usrsctp use 1 as their return value.
constexpr int kSctpSuccessReturn = 1;
constexpr int kSctpErrorReturn = 0;
} // namespace
@ -28,6 +27,7 @@ constexpr int kSctpErrorReturn = 0;
#include <stdio.h>
#include <usrsctp.h>
#include <functional>
#include <memory>
#include <unordered_map>
@ -252,20 +252,31 @@ class SctpTransportMap {
return map_.erase(id) > 0;
}
// Must be called on the transport's network thread to protect against
// simultaneous deletion/deregistration of the transport; if that's not
// guaranteed, use ExecuteWithLock.
SctpTransport* Retrieve(uintptr_t id) const {
webrtc::MutexLock lock(&lock_);
SctpTransport* transport = RetrieveWhileHoldingLock(id);
if (transport) {
RTC_DCHECK_RUN_ON(transport->network_thread());
}
return transport;
}
// Posts |action| to the network thread of the transport identified by |id|
// and returns true if found, all while holding a lock to protect against the
// transport being simultaneously deleted/deregistered, or returns false if
// not found.
template <typename F>
bool PostToTransportThread(uintptr_t id, F action) const {
bool PostToTransportThread(uintptr_t id,
std::function<void(SctpTransport*)> action) const {
webrtc::MutexLock lock(&lock_);
SctpTransport* transport = RetrieveWhileHoldingLock(id);
if (!transport) {
return false;
}
transport->network_thread_->PostTask(ToQueuedTask(
transport->task_safety_,
[transport, action{std::move(action)}]() { action(transport); }));
transport->task_safety_, [transport, action]() { action(transport); }));
return true;
}
@ -418,7 +429,7 @@ class SctpTransport::UsrSctpWrapper {
if (!found) {
RTC_LOG(LS_ERROR)
<< "OnSctpOutboundPacket: Failed to get transport for socket ID "
<< addr << "; possibly was already destroyed.";
<< addr;
return EINVAL;
}
@ -436,49 +447,28 @@ class SctpTransport::UsrSctpWrapper {
struct sctp_rcvinfo rcv,
int flags,
void* ulp_info) {
struct DeleteByFree {
void operator()(void* p) const { free(p); }
};
std::unique_ptr<void, DeleteByFree> owned_data(data, DeleteByFree());
absl::optional<uintptr_t> id = GetTransportIdFromSocket(sock);
if (!id) {
SctpTransport* transport = GetTransportFromSocket(sock);
if (!transport) {
RTC_LOG(LS_ERROR)
<< "OnSctpInboundPacket: Failed to get transport ID from socket "
<< sock;
return kSctpErrorReturn;
<< "OnSctpInboundPacket: Failed to get transport for socket " << sock
<< "; possibly was already destroyed.";
free(data);
return 0;
}
if (!g_transport_map_) {
RTC_LOG(LS_ERROR)
<< "OnSctpInboundPacket called after usrsctp uninitialized?";
return kSctpErrorReturn;
}
// PostsToTransportThread protects against the transport being
// simultaneously deregistered/deleted, since this callback may come from
// the SCTP timer thread and thus race with the network thread.
bool found = g_transport_map_->PostToTransportThread(
*id, [owned_data{std::move(owned_data)}, length, rcv,
flags](SctpTransport* transport) {
transport->OnDataOrNotificationFromSctp(owned_data.get(), length, rcv,
flags);
});
if (!found) {
RTC_LOG(LS_ERROR)
<< "OnSctpInboundPacket: Failed to get transport for socket ID "
<< *id << "; possibly was already destroyed.";
return kSctpErrorReturn;
}
return kSctpSuccessReturn;
// Sanity check that both methods of getting the SctpTransport pointer
// yield the same result.
RTC_CHECK_EQ(transport, static_cast<SctpTransport*>(ulp_info));
int result =
transport->OnDataOrNotificationFromSctp(data, length, rcv, flags);
free(data);
return result;
}
static absl::optional<uintptr_t> GetTransportIdFromSocket(
struct socket* sock) {
absl::optional<uintptr_t> ret;
static SctpTransport* GetTransportFromSocket(struct socket* sock) {
struct sockaddr* addrs = nullptr;
int naddrs = usrsctp_getladdrs(sock, 0, &addrs);
if (naddrs <= 0 || addrs[0].sa_family != AF_CONN) {
return ret;
return nullptr;
}
// usrsctp_getladdrs() returns the addresses bound to this socket, which
// contains the SctpTransport id as sconn_addr. Read the id,
@ -487,10 +477,17 @@ class SctpTransport::UsrSctpWrapper {
// id of the transport that created them, so [0] is as good as any other.
struct sockaddr_conn* sconn =
reinterpret_cast<struct sockaddr_conn*>(&addrs[0]);
ret = reinterpret_cast<uintptr_t>(sconn->sconn_addr);
if (!g_transport_map_) {
RTC_LOG(LS_ERROR)
<< "GetTransportFromSocket called after usrsctp uninitialized?";
usrsctp_freeladdrs(addrs);
return nullptr;
}
SctpTransport* transport = g_transport_map_->Retrieve(
reinterpret_cast<uintptr_t>(sconn->sconn_addr));
usrsctp_freeladdrs(addrs);
return ret;
return transport;
}
// TODO(crbug.com/webrtc/11899): This is a legacy callback signature, remove
@ -499,26 +496,14 @@ class SctpTransport::UsrSctpWrapper {
// Fired on our I/O thread. SctpTransport::OnPacketReceived() gets
// a packet containing acknowledgments, which goes into usrsctp_conninput,
// and then back here.
absl::optional<uintptr_t> id = GetTransportIdFromSocket(sock);
if (!id) {
SctpTransport* transport = GetTransportFromSocket(sock);
if (!transport) {
RTC_LOG(LS_ERROR)
<< "SendThresholdCallback: Failed to get transport ID from socket "
<< sock;
<< "SendThresholdCallback: Failed to get transport for socket "
<< sock << "; possibly was already destroyed.";
return 0;
}
if (!g_transport_map_) {
RTC_LOG(LS_ERROR)
<< "SendThresholdCallback called after usrsctp uninitialized?";
return 0;
}
bool found = g_transport_map_->PostToTransportThread(
*id,
[](SctpTransport* transport) { transport->OnSendThresholdCallback(); });
if (!found) {
RTC_LOG(LS_ERROR)
<< "SendThresholdCallback: Failed to get transport for socket ID "
<< *id << "; possibly was already destroyed.";
}
transport->OnSendThresholdCallback();
return 0;
}
@ -528,26 +513,17 @@ class SctpTransport::UsrSctpWrapper {
// Fired on our I/O thread. SctpTransport::OnPacketReceived() gets
// a packet containing acknowledgments, which goes into usrsctp_conninput,
// and then back here.
absl::optional<uintptr_t> id = GetTransportIdFromSocket(sock);
if (!id) {
SctpTransport* transport = GetTransportFromSocket(sock);
if (!transport) {
RTC_LOG(LS_ERROR)
<< "SendThresholdCallback: Failed to get transport ID from socket "
<< sock;
<< "SendThresholdCallback: Failed to get transport for socket "
<< sock << "; possibly was already destroyed.";
return 0;
}
if (!g_transport_map_) {
RTC_LOG(LS_ERROR)
<< "SendThresholdCallback called after usrsctp uninitialized?";
return 0;
}
bool found = g_transport_map_->PostToTransportThread(
*id,
[](SctpTransport* transport) { transport->OnSendThresholdCallback(); });
if (!found) {
RTC_LOG(LS_ERROR)
<< "SendThresholdCallback: Failed to get transport for socket ID "
<< *id << "; possibly was already destroyed.";
}
// Sanity check that both methods of getting the SctpTransport pointer
// yield the same result.
RTC_CHECK_EQ(transport, static_cast<SctpTransport*>(ulp_info));
transport->OnSendThresholdCallback();
return 0;
}
};
@ -1199,25 +1175,24 @@ void SctpTransport::OnPacketFromSctpToNetwork(
rtc::PacketOptions(), PF_NORMAL);
}
void SctpTransport::InjectDataOrNotificationFromSctpForTesting(
int SctpTransport::InjectDataOrNotificationFromSctpForTesting(
const void* data,
size_t length,
struct sctp_rcvinfo rcv,
int flags) {
OnDataOrNotificationFromSctp(data, length, rcv, flags);
return OnDataOrNotificationFromSctp(data, length, rcv, flags);
}
void SctpTransport::OnDataOrNotificationFromSctp(const void* data,
size_t length,
struct sctp_rcvinfo rcv,
int flags) {
RTC_DCHECK_RUN_ON(network_thread_);
int SctpTransport::OnDataOrNotificationFromSctp(const void* data,
size_t length,
struct sctp_rcvinfo rcv,
int flags) {
// If data is NULL, the SCTP association has been closed.
if (!data) {
RTC_LOG(LS_INFO) << debug_name_
<< "->OnDataOrNotificationFromSctp(...): "
"No data; association closed.";
return;
return kSctpSuccessReturn;
}
// Handle notifications early.
@ -1230,10 +1205,14 @@ void SctpTransport::OnDataOrNotificationFromSctp(const void* data,
<< "->OnDataOrNotificationFromSctp(...): SCTP notification"
<< " length=" << length;
// Copy and dispatch asynchronously
rtc::CopyOnWriteBuffer notification(reinterpret_cast<const uint8_t*>(data),
length);
OnNotificationFromSctp(notification);
return;
network_thread_->PostTask(ToQueuedTask(
task_safety_, [this, notification = std::move(notification)]() {
OnNotificationFromSctp(notification);
}));
return kSctpSuccessReturn;
}
// Log data chunk
@ -1251,7 +1230,7 @@ void SctpTransport::OnDataOrNotificationFromSctp(const void* data,
// Unexpected PPID, dropping
RTC_LOG(LS_ERROR) << "Received an unknown PPID " << ppid
<< " on an SCTP packet. Dropping.";
return;
return kSctpSuccessReturn;
}
// Expect only continuation messages belonging to the same SID. The SCTP
@ -1287,7 +1266,7 @@ void SctpTransport::OnDataOrNotificationFromSctp(const void* data,
if (partial_incoming_message_.size() < kSctpSendBufferSize) {
// We still have space in the buffer. Continue buffering chunks until
// the message is complete before handing it out.
return;
return kSctpSuccessReturn;
} else {
// The sender is exceeding the maximum message size that we announced.
// Spit out a warning but still hand out the partial message. Note that
@ -1301,9 +1280,18 @@ void SctpTransport::OnDataOrNotificationFromSctp(const void* data,
}
}
// Dispatch the complete message and reset the message buffer.
OnDataFromSctpToTransport(params, partial_incoming_message_);
// Dispatch the complete message.
// The ownership of the packet transfers to |invoker_|. Using
// CopyOnWriteBuffer is the most convenient way to do this.
network_thread_->PostTask(webrtc::ToQueuedTask(
task_safety_, [this, params = std::move(params),
message = partial_incoming_message_]() {
OnDataFromSctpToTransport(params, message);
}));
// Reset the message buffer
partial_incoming_message_.Clear();
return kSctpSuccessReturn;
}
void SctpTransport::OnDataFromSctpToTransport(

View File

@ -96,10 +96,10 @@ class SctpTransport : public SctpTransportInternal,
void set_debug_name_for_testing(const char* debug_name) override {
debug_name_ = debug_name;
}
void InjectDataOrNotificationFromSctpForTesting(const void* data,
size_t length,
struct sctp_rcvinfo rcv,
int flags);
int InjectDataOrNotificationFromSctpForTesting(const void* data,
size_t length,
struct sctp_rcvinfo rcv,
int flags);
// Exposed to allow Post call from c-callbacks.
// TODO(deadbeef): Remove this or at least make it return a const pointer.
@ -180,12 +180,12 @@ class SctpTransport : public SctpTransportInternal,
// Called using |invoker_| to send packet on the network.
void OnPacketFromSctpToNetwork(const rtc::CopyOnWriteBuffer& buffer);
// Called on the network thread.
// Called on the SCTP thread.
// Flags are standard socket API flags (RFC 6458).
void OnDataOrNotificationFromSctp(const void* data,
size_t length,
struct sctp_rcvinfo rcv,
int flags);
int OnDataOrNotificationFromSctp(const void* data,
size_t length,
struct sctp_rcvinfo rcv,
int flags);
// Called using |invoker_| to decide what to do with the data.
void OnDataFromSctpToTransport(const ReceiveDataParams& params,
const rtc::CopyOnWriteBuffer& buffer);

View File

@ -282,8 +282,8 @@ TEST_F(SctpTransportTest, MessageInterleavedWithNotification) {
meta.rcv_tsn = 42;
meta.rcv_cumtsn = 42;
chunk.SetData("meow?", 5);
transport1->InjectDataOrNotificationFromSctpForTesting(chunk.data(),
chunk.size(), meta, 0);
EXPECT_EQ(1, transport1->InjectDataOrNotificationFromSctpForTesting(
chunk.data(), chunk.size(), meta, 0));
// Inject a notification in between chunks.
union sctp_notification notification;
@ -292,15 +292,15 @@ TEST_F(SctpTransportTest, MessageInterleavedWithNotification) {
notification.sn_header.sn_type = SCTP_PEER_ADDR_CHANGE;
notification.sn_header.sn_flags = 0;
notification.sn_header.sn_length = sizeof(notification);
transport1->InjectDataOrNotificationFromSctpForTesting(
&notification, sizeof(notification), {0}, MSG_NOTIFICATION);
EXPECT_EQ(1, transport1->InjectDataOrNotificationFromSctpForTesting(
&notification, sizeof(notification), {0}, MSG_NOTIFICATION));
// Inject chunk 2/2
meta.rcv_tsn = 42;
meta.rcv_cumtsn = 43;
chunk.SetData(" rawr!", 6);
transport1->InjectDataOrNotificationFromSctpForTesting(
chunk.data(), chunk.size(), meta, MSG_EOR);
EXPECT_EQ(1, transport1->InjectDataOrNotificationFromSctpForTesting(
chunk.data(), chunk.size(), meta, MSG_EOR));
// Expect the message to contain both chunks.
EXPECT_TRUE_WAIT(ReceivedData(&recv1, 1, "meow? rawr!"), kDefaultTimeout);