dcsctp: Ensure callbacks are always triggered

The previous manual way of triggering the deferred callbacks was very
error-prone, and this was also forgotten at a few places.

We can do better.

Using the RAII programming idiom, the callbacks are now ensured to be
called before returning from public methods.

Also added additional debug checks to ensure that there is a
ScopedDeferrer active whenever callbacks are deferred.

Bug: webrtc:13217
Change-Id: I16a8343b52c00fb30acb018d3846acd0a64318e0
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/233242
Commit-Queue: Victor Boivie <boivie@webrtc.org>
Reviewed-by: Florent Castelli <orphis@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35117}
This commit is contained in:
Victor Boivie
2021-09-28 21:38:34 +02:00
committed by WebRTC LUCI CQ
parent 0081f1c331
commit 15a0c880cf
3 changed files with 53 additions and 14 deletions

View File

@ -36,12 +36,19 @@ class MessageDeliverer {
};
} // namespace
void CallbackDeferrer::Prepare() {
RTC_DCHECK(!prepared_);
prepared_ = true;
}
void CallbackDeferrer::TriggerDeferred() {
// Need to swap here. The client may call into the library from within a
// callback, and that might result in adding new callbacks to this instance,
// and the vector can't be modified while iterated on.
RTC_DCHECK(prepared_);
std::vector<std::function<void(DcSctpSocketCallbacks & cb)>> deferred;
deferred.swap(deferred_);
prepared_ = false;
for (auto& cb : deferred) {
cb(underlying_);
@ -70,12 +77,14 @@ uint32_t CallbackDeferrer::GetRandomInt(uint32_t low, uint32_t high) {
}
void CallbackDeferrer::OnMessageReceived(DcSctpMessage message) {
RTC_DCHECK(prepared_);
deferred_.emplace_back(
[deliverer = MessageDeliverer(std::move(message))](
DcSctpSocketCallbacks& cb) mutable { deliverer.Deliver(cb); });
}
void CallbackDeferrer::OnError(ErrorKind error, absl::string_view message) {
RTC_DCHECK(prepared_);
deferred_.emplace_back(
[error, message = std::string(message)](DcSctpSocketCallbacks& cb) {
cb.OnError(error, message);
@ -83,6 +92,7 @@ void CallbackDeferrer::OnError(ErrorKind error, absl::string_view message) {
}
void CallbackDeferrer::OnAborted(ErrorKind error, absl::string_view message) {
RTC_DCHECK(prepared_);
deferred_.emplace_back(
[error, message = std::string(message)](DcSctpSocketCallbacks& cb) {
cb.OnAborted(error, message);
@ -90,14 +100,17 @@ void CallbackDeferrer::OnAborted(ErrorKind error, absl::string_view message) {
}
void CallbackDeferrer::OnConnected() {
RTC_DCHECK(prepared_);
deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnConnected(); });
}
void CallbackDeferrer::OnClosed() {
RTC_DCHECK(prepared_);
deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnClosed(); });
}
void CallbackDeferrer::OnConnectionRestarted() {
RTC_DCHECK(prepared_);
deferred_.emplace_back(
[](DcSctpSocketCallbacks& cb) { cb.OnConnectionRestarted(); });
}
@ -105,6 +118,7 @@ void CallbackDeferrer::OnConnectionRestarted() {
void CallbackDeferrer::OnStreamsResetFailed(
rtc::ArrayView<const StreamID> outgoing_streams,
absl::string_view reason) {
RTC_DCHECK(prepared_);
deferred_.emplace_back(
[streams = std::vector<StreamID>(outgoing_streams.begin(),
outgoing_streams.end()),
@ -115,6 +129,7 @@ void CallbackDeferrer::OnStreamsResetFailed(
void CallbackDeferrer::OnStreamsResetPerformed(
rtc::ArrayView<const StreamID> outgoing_streams) {
RTC_DCHECK(prepared_);
deferred_.emplace_back(
[streams = std::vector<StreamID>(outgoing_streams.begin(),
outgoing_streams.end())](
@ -123,6 +138,7 @@ void CallbackDeferrer::OnStreamsResetPerformed(
void CallbackDeferrer::OnIncomingStreamsReset(
rtc::ArrayView<const StreamID> incoming_streams) {
RTC_DCHECK(prepared_);
deferred_.emplace_back(
[streams = std::vector<StreamID>(incoming_streams.begin(),
incoming_streams.end())](
@ -130,12 +146,14 @@ void CallbackDeferrer::OnIncomingStreamsReset(
}
void CallbackDeferrer::OnBufferedAmountLow(StreamID stream_id) {
RTC_DCHECK(prepared_);
deferred_.emplace_back([stream_id](DcSctpSocketCallbacks& cb) {
cb.OnBufferedAmountLow(stream_id);
});
}
void CallbackDeferrer::OnTotalBufferedAmountLow() {
RTC_DCHECK(prepared_);
deferred_.emplace_back(
[](DcSctpSocketCallbacks& cb) { cb.OnTotalBufferedAmountLow(); });
}

View File

@ -26,7 +26,6 @@
#include "rtc_base/ref_counted_object.h"
namespace dcsctp {
// Defers callbacks until they can be safely triggered.
//
// There are a lot of callbacks from the dcSCTP library to the client,
@ -44,11 +43,22 @@ namespace dcsctp {
// There are a number of exceptions, which is clearly annotated in the API.
class CallbackDeferrer : public DcSctpSocketCallbacks {
public:
class ScopedDeferrer {
public:
explicit ScopedDeferrer(CallbackDeferrer& callback_deferrer)
: callback_deferrer_(callback_deferrer) {
callback_deferrer_.Prepare();
}
~ScopedDeferrer() { callback_deferrer_.TriggerDeferred(); }
private:
CallbackDeferrer& callback_deferrer_;
};
explicit CallbackDeferrer(DcSctpSocketCallbacks& underlying)
: underlying_(underlying) {}
void TriggerDeferred();
// Implementation of DcSctpSocketCallbacks
SendPacketStatus SendPacketWithStatus(
rtc::ArrayView<const uint8_t> data) override;
@ -71,7 +81,11 @@ class CallbackDeferrer : public DcSctpSocketCallbacks {
void OnTotalBufferedAmountLow() override;
private:
void Prepare();
void TriggerDeferred();
DcSctpSocketCallbacks& underlying_;
bool prepared_ = false;
std::vector<std::function<void(DcSctpSocketCallbacks& cb)>> deferred_;
};
} // namespace dcsctp

View File

@ -281,6 +281,8 @@ void DcSctpSocket::MakeConnectionParameters() {
}
void DcSctpSocket::Connect() {
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
if (state_ == State::kClosed) {
MakeConnectionParameters();
RTC_DLOG(LS_INFO)
@ -296,10 +298,11 @@ void DcSctpSocket::Connect() {
<< "Called Connect on a socket that is not closed";
}
RTC_DCHECK(IsConsistent());
callbacks_.TriggerDeferred();
}
void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) {
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
if (state_ != State::kClosed) {
callbacks_.OnError(ErrorKind::kUnsupportedOperation,
"Only closed socket can be restored from state");
@ -334,10 +337,11 @@ void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) {
}
RTC_DCHECK(IsConsistent());
callbacks_.TriggerDeferred();
}
void DcSctpSocket::Shutdown() {
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
if (tcb_ != nullptr) {
// https://tools.ietf.org/html/rfc4960#section-9.2
// "Upon receipt of the SHUTDOWN primitive from its upper layer, the
@ -361,10 +365,11 @@ void DcSctpSocket::Shutdown() {
InternalClose(ErrorKind::kNoError, "");
}
RTC_DCHECK(IsConsistent());
callbacks_.TriggerDeferred();
}
void DcSctpSocket::Close() {
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
if (state_ != State::kClosed) {
if (tcb_ != nullptr) {
SctpPacket::Builder b = tcb_->PacketBuilder();
@ -379,7 +384,6 @@ void DcSctpSocket::Close() {
RTC_DLOG(LS_INFO) << log_prefix() << "Called Close on a closed socket";
}
RTC_DCHECK(IsConsistent());
callbacks_.TriggerDeferred();
}
void DcSctpSocket::CloseConnectionBecauseOfTooManyTransmissionErrors() {
@ -411,6 +415,8 @@ void DcSctpSocket::InternalClose(ErrorKind error, absl::string_view message) {
SendStatus DcSctpSocket::Send(DcSctpMessage message,
const SendOptions& send_options) {
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
if (message.payload().empty()) {
callbacks_.OnError(ErrorKind::kProtocolViolation,
"Unable to send empty message");
@ -445,12 +451,13 @@ SendStatus DcSctpSocket::Send(DcSctpMessage message,
}
RTC_DCHECK(IsConsistent());
callbacks_.TriggerDeferred();
return SendStatus::kSuccess;
}
ResetStreamsStatus DcSctpSocket::ResetStreams(
rtc::ArrayView<const StreamID> outgoing_streams) {
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
if (tcb_ == nullptr) {
callbacks_.OnError(ErrorKind::kWrongSequence,
"Can't reset streams as the socket is not connected");
@ -472,7 +479,6 @@ ResetStreamsStatus DcSctpSocket::ResetStreams(
}
RTC_DCHECK(IsConsistent());
callbacks_.TriggerDeferred();
return ResetStreamsStatus::kPerformed;
}
@ -654,6 +660,8 @@ bool DcSctpSocket::ValidatePacket(const SctpPacket& packet) {
}
void DcSctpSocket::HandleTimeout(TimeoutID timeout_id) {
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
timer_manager_.HandleTimeout(timeout_id);
if (tcb_ != nullptr && tcb_->HasTooManyTxErrors()) {
@ -662,10 +670,11 @@ void DcSctpSocket::HandleTimeout(TimeoutID timeout_id) {
}
RTC_DCHECK(IsConsistent());
callbacks_.TriggerDeferred();
}
void DcSctpSocket::ReceivePacket(rtc::ArrayView<const uint8_t> data) {
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
++metrics_.rx_packets_count;
if (packet_observer_ != nullptr) {
@ -681,7 +690,6 @@ void DcSctpSocket::ReceivePacket(rtc::ArrayView<const uint8_t> data) {
callbacks_.OnError(ErrorKind::kParseFailed,
"Failed to parse received SCTP packet");
RTC_DCHECK(IsConsistent());
callbacks_.TriggerDeferred();
return;
}
@ -696,7 +704,6 @@ void DcSctpSocket::ReceivePacket(rtc::ArrayView<const uint8_t> data) {
RTC_DLOG(LS_VERBOSE) << log_prefix()
<< "Packet failed verification tag check - dropping";
RTC_DCHECK(IsConsistent());
callbacks_.TriggerDeferred();
return;
}
@ -714,7 +721,6 @@ void DcSctpSocket::ReceivePacket(rtc::ArrayView<const uint8_t> data) {
}
RTC_DCHECK(IsConsistent());
callbacks_.TriggerDeferred();
}
void DcSctpSocket::DebugPrintOutgoing(rtc::ArrayView<const uint8_t> payload) {
@ -1646,6 +1652,8 @@ HandoverReadinessStatus DcSctpSocket::GetHandoverReadiness() const {
absl::optional<DcSctpSocketHandoverState>
DcSctpSocket::GetHandoverStateAndClose() {
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
if (!GetHandoverReadiness().IsReady()) {
return absl::nullopt;
}
@ -1659,7 +1667,6 @@ DcSctpSocket::GetHandoverStateAndClose() {
tcb_->AddHandoverState(state);
send_queue_.AddHandoverState(state);
InternalClose(ErrorKind::kNoError, "handover");
callbacks_.TriggerDeferred();
}
return std::move(state);