dcsctp: Ensure callbacks are always triggered
The previous manual way of triggering the deferred callbacks was very error-prone, and this was also forgotten at a few places. We can do better. Using the RAII programming idiom, the callbacks are now ensured to be called before returning from public methods. Also added additional debug checks to ensure that there is a ScopedDeferrer active whenever callbacks are deferred. Bug: webrtc:13217 Change-Id: I16a8343b52c00fb30acb018d3846acd0a64318e0 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/233242 Commit-Queue: Victor Boivie <boivie@webrtc.org> Reviewed-by: Florent Castelli <orphis@webrtc.org> Cr-Commit-Position: refs/heads/main@{#35117}
This commit is contained in:

committed by
WebRTC LUCI CQ

parent
0081f1c331
commit
15a0c880cf
@ -36,12 +36,19 @@ class MessageDeliverer {
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void CallbackDeferrer::Prepare() {
|
||||
RTC_DCHECK(!prepared_);
|
||||
prepared_ = true;
|
||||
}
|
||||
|
||||
void CallbackDeferrer::TriggerDeferred() {
|
||||
// Need to swap here. The client may call into the library from within a
|
||||
// callback, and that might result in adding new callbacks to this instance,
|
||||
// and the vector can't be modified while iterated on.
|
||||
RTC_DCHECK(prepared_);
|
||||
std::vector<std::function<void(DcSctpSocketCallbacks & cb)>> deferred;
|
||||
deferred.swap(deferred_);
|
||||
prepared_ = false;
|
||||
|
||||
for (auto& cb : deferred) {
|
||||
cb(underlying_);
|
||||
@ -70,12 +77,14 @@ uint32_t CallbackDeferrer::GetRandomInt(uint32_t low, uint32_t high) {
|
||||
}
|
||||
|
||||
void CallbackDeferrer::OnMessageReceived(DcSctpMessage message) {
|
||||
RTC_DCHECK(prepared_);
|
||||
deferred_.emplace_back(
|
||||
[deliverer = MessageDeliverer(std::move(message))](
|
||||
DcSctpSocketCallbacks& cb) mutable { deliverer.Deliver(cb); });
|
||||
}
|
||||
|
||||
void CallbackDeferrer::OnError(ErrorKind error, absl::string_view message) {
|
||||
RTC_DCHECK(prepared_);
|
||||
deferred_.emplace_back(
|
||||
[error, message = std::string(message)](DcSctpSocketCallbacks& cb) {
|
||||
cb.OnError(error, message);
|
||||
@ -83,6 +92,7 @@ void CallbackDeferrer::OnError(ErrorKind error, absl::string_view message) {
|
||||
}
|
||||
|
||||
void CallbackDeferrer::OnAborted(ErrorKind error, absl::string_view message) {
|
||||
RTC_DCHECK(prepared_);
|
||||
deferred_.emplace_back(
|
||||
[error, message = std::string(message)](DcSctpSocketCallbacks& cb) {
|
||||
cb.OnAborted(error, message);
|
||||
@ -90,14 +100,17 @@ void CallbackDeferrer::OnAborted(ErrorKind error, absl::string_view message) {
|
||||
}
|
||||
|
||||
void CallbackDeferrer::OnConnected() {
|
||||
RTC_DCHECK(prepared_);
|
||||
deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnConnected(); });
|
||||
}
|
||||
|
||||
void CallbackDeferrer::OnClosed() {
|
||||
RTC_DCHECK(prepared_);
|
||||
deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnClosed(); });
|
||||
}
|
||||
|
||||
void CallbackDeferrer::OnConnectionRestarted() {
|
||||
RTC_DCHECK(prepared_);
|
||||
deferred_.emplace_back(
|
||||
[](DcSctpSocketCallbacks& cb) { cb.OnConnectionRestarted(); });
|
||||
}
|
||||
@ -105,6 +118,7 @@ void CallbackDeferrer::OnConnectionRestarted() {
|
||||
void CallbackDeferrer::OnStreamsResetFailed(
|
||||
rtc::ArrayView<const StreamID> outgoing_streams,
|
||||
absl::string_view reason) {
|
||||
RTC_DCHECK(prepared_);
|
||||
deferred_.emplace_back(
|
||||
[streams = std::vector<StreamID>(outgoing_streams.begin(),
|
||||
outgoing_streams.end()),
|
||||
@ -115,6 +129,7 @@ void CallbackDeferrer::OnStreamsResetFailed(
|
||||
|
||||
void CallbackDeferrer::OnStreamsResetPerformed(
|
||||
rtc::ArrayView<const StreamID> outgoing_streams) {
|
||||
RTC_DCHECK(prepared_);
|
||||
deferred_.emplace_back(
|
||||
[streams = std::vector<StreamID>(outgoing_streams.begin(),
|
||||
outgoing_streams.end())](
|
||||
@ -123,6 +138,7 @@ void CallbackDeferrer::OnStreamsResetPerformed(
|
||||
|
||||
void CallbackDeferrer::OnIncomingStreamsReset(
|
||||
rtc::ArrayView<const StreamID> incoming_streams) {
|
||||
RTC_DCHECK(prepared_);
|
||||
deferred_.emplace_back(
|
||||
[streams = std::vector<StreamID>(incoming_streams.begin(),
|
||||
incoming_streams.end())](
|
||||
@ -130,12 +146,14 @@ void CallbackDeferrer::OnIncomingStreamsReset(
|
||||
}
|
||||
|
||||
void CallbackDeferrer::OnBufferedAmountLow(StreamID stream_id) {
|
||||
RTC_DCHECK(prepared_);
|
||||
deferred_.emplace_back([stream_id](DcSctpSocketCallbacks& cb) {
|
||||
cb.OnBufferedAmountLow(stream_id);
|
||||
});
|
||||
}
|
||||
|
||||
void CallbackDeferrer::OnTotalBufferedAmountLow() {
|
||||
RTC_DCHECK(prepared_);
|
||||
deferred_.emplace_back(
|
||||
[](DcSctpSocketCallbacks& cb) { cb.OnTotalBufferedAmountLow(); });
|
||||
}
|
||||
|
@ -26,7 +26,6 @@
|
||||
#include "rtc_base/ref_counted_object.h"
|
||||
|
||||
namespace dcsctp {
|
||||
|
||||
// Defers callbacks until they can be safely triggered.
|
||||
//
|
||||
// There are a lot of callbacks from the dcSCTP library to the client,
|
||||
@ -44,11 +43,22 @@ namespace dcsctp {
|
||||
// There are a number of exceptions, which is clearly annotated in the API.
|
||||
class CallbackDeferrer : public DcSctpSocketCallbacks {
|
||||
public:
|
||||
class ScopedDeferrer {
|
||||
public:
|
||||
explicit ScopedDeferrer(CallbackDeferrer& callback_deferrer)
|
||||
: callback_deferrer_(callback_deferrer) {
|
||||
callback_deferrer_.Prepare();
|
||||
}
|
||||
|
||||
~ScopedDeferrer() { callback_deferrer_.TriggerDeferred(); }
|
||||
|
||||
private:
|
||||
CallbackDeferrer& callback_deferrer_;
|
||||
};
|
||||
|
||||
explicit CallbackDeferrer(DcSctpSocketCallbacks& underlying)
|
||||
: underlying_(underlying) {}
|
||||
|
||||
void TriggerDeferred();
|
||||
|
||||
// Implementation of DcSctpSocketCallbacks
|
||||
SendPacketStatus SendPacketWithStatus(
|
||||
rtc::ArrayView<const uint8_t> data) override;
|
||||
@ -71,7 +81,11 @@ class CallbackDeferrer : public DcSctpSocketCallbacks {
|
||||
void OnTotalBufferedAmountLow() override;
|
||||
|
||||
private:
|
||||
void Prepare();
|
||||
void TriggerDeferred();
|
||||
|
||||
DcSctpSocketCallbacks& underlying_;
|
||||
bool prepared_ = false;
|
||||
std::vector<std::function<void(DcSctpSocketCallbacks& cb)>> deferred_;
|
||||
};
|
||||
} // namespace dcsctp
|
||||
|
@ -281,6 +281,8 @@ void DcSctpSocket::MakeConnectionParameters() {
|
||||
}
|
||||
|
||||
void DcSctpSocket::Connect() {
|
||||
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
|
||||
|
||||
if (state_ == State::kClosed) {
|
||||
MakeConnectionParameters();
|
||||
RTC_DLOG(LS_INFO)
|
||||
@ -296,10 +298,11 @@ void DcSctpSocket::Connect() {
|
||||
<< "Called Connect on a socket that is not closed";
|
||||
}
|
||||
RTC_DCHECK(IsConsistent());
|
||||
callbacks_.TriggerDeferred();
|
||||
}
|
||||
|
||||
void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) {
|
||||
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
|
||||
|
||||
if (state_ != State::kClosed) {
|
||||
callbacks_.OnError(ErrorKind::kUnsupportedOperation,
|
||||
"Only closed socket can be restored from state");
|
||||
@ -334,10 +337,11 @@ void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) {
|
||||
}
|
||||
|
||||
RTC_DCHECK(IsConsistent());
|
||||
callbacks_.TriggerDeferred();
|
||||
}
|
||||
|
||||
void DcSctpSocket::Shutdown() {
|
||||
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
|
||||
|
||||
if (tcb_ != nullptr) {
|
||||
// https://tools.ietf.org/html/rfc4960#section-9.2
|
||||
// "Upon receipt of the SHUTDOWN primitive from its upper layer, the
|
||||
@ -361,10 +365,11 @@ void DcSctpSocket::Shutdown() {
|
||||
InternalClose(ErrorKind::kNoError, "");
|
||||
}
|
||||
RTC_DCHECK(IsConsistent());
|
||||
callbacks_.TriggerDeferred();
|
||||
}
|
||||
|
||||
void DcSctpSocket::Close() {
|
||||
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
|
||||
|
||||
if (state_ != State::kClosed) {
|
||||
if (tcb_ != nullptr) {
|
||||
SctpPacket::Builder b = tcb_->PacketBuilder();
|
||||
@ -379,7 +384,6 @@ void DcSctpSocket::Close() {
|
||||
RTC_DLOG(LS_INFO) << log_prefix() << "Called Close on a closed socket";
|
||||
}
|
||||
RTC_DCHECK(IsConsistent());
|
||||
callbacks_.TriggerDeferred();
|
||||
}
|
||||
|
||||
void DcSctpSocket::CloseConnectionBecauseOfTooManyTransmissionErrors() {
|
||||
@ -411,6 +415,8 @@ void DcSctpSocket::InternalClose(ErrorKind error, absl::string_view message) {
|
||||
|
||||
SendStatus DcSctpSocket::Send(DcSctpMessage message,
|
||||
const SendOptions& send_options) {
|
||||
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
|
||||
|
||||
if (message.payload().empty()) {
|
||||
callbacks_.OnError(ErrorKind::kProtocolViolation,
|
||||
"Unable to send empty message");
|
||||
@ -445,12 +451,13 @@ SendStatus DcSctpSocket::Send(DcSctpMessage message,
|
||||
}
|
||||
|
||||
RTC_DCHECK(IsConsistent());
|
||||
callbacks_.TriggerDeferred();
|
||||
return SendStatus::kSuccess;
|
||||
}
|
||||
|
||||
ResetStreamsStatus DcSctpSocket::ResetStreams(
|
||||
rtc::ArrayView<const StreamID> outgoing_streams) {
|
||||
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
|
||||
|
||||
if (tcb_ == nullptr) {
|
||||
callbacks_.OnError(ErrorKind::kWrongSequence,
|
||||
"Can't reset streams as the socket is not connected");
|
||||
@ -472,7 +479,6 @@ ResetStreamsStatus DcSctpSocket::ResetStreams(
|
||||
}
|
||||
|
||||
RTC_DCHECK(IsConsistent());
|
||||
callbacks_.TriggerDeferred();
|
||||
return ResetStreamsStatus::kPerformed;
|
||||
}
|
||||
|
||||
@ -654,6 +660,8 @@ bool DcSctpSocket::ValidatePacket(const SctpPacket& packet) {
|
||||
}
|
||||
|
||||
void DcSctpSocket::HandleTimeout(TimeoutID timeout_id) {
|
||||
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
|
||||
|
||||
timer_manager_.HandleTimeout(timeout_id);
|
||||
|
||||
if (tcb_ != nullptr && tcb_->HasTooManyTxErrors()) {
|
||||
@ -662,10 +670,11 @@ void DcSctpSocket::HandleTimeout(TimeoutID timeout_id) {
|
||||
}
|
||||
|
||||
RTC_DCHECK(IsConsistent());
|
||||
callbacks_.TriggerDeferred();
|
||||
}
|
||||
|
||||
void DcSctpSocket::ReceivePacket(rtc::ArrayView<const uint8_t> data) {
|
||||
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
|
||||
|
||||
++metrics_.rx_packets_count;
|
||||
|
||||
if (packet_observer_ != nullptr) {
|
||||
@ -681,7 +690,6 @@ void DcSctpSocket::ReceivePacket(rtc::ArrayView<const uint8_t> data) {
|
||||
callbacks_.OnError(ErrorKind::kParseFailed,
|
||||
"Failed to parse received SCTP packet");
|
||||
RTC_DCHECK(IsConsistent());
|
||||
callbacks_.TriggerDeferred();
|
||||
return;
|
||||
}
|
||||
|
||||
@ -696,7 +704,6 @@ void DcSctpSocket::ReceivePacket(rtc::ArrayView<const uint8_t> data) {
|
||||
RTC_DLOG(LS_VERBOSE) << log_prefix()
|
||||
<< "Packet failed verification tag check - dropping";
|
||||
RTC_DCHECK(IsConsistent());
|
||||
callbacks_.TriggerDeferred();
|
||||
return;
|
||||
}
|
||||
|
||||
@ -714,7 +721,6 @@ void DcSctpSocket::ReceivePacket(rtc::ArrayView<const uint8_t> data) {
|
||||
}
|
||||
|
||||
RTC_DCHECK(IsConsistent());
|
||||
callbacks_.TriggerDeferred();
|
||||
}
|
||||
|
||||
void DcSctpSocket::DebugPrintOutgoing(rtc::ArrayView<const uint8_t> payload) {
|
||||
@ -1646,6 +1652,8 @@ HandoverReadinessStatus DcSctpSocket::GetHandoverReadiness() const {
|
||||
|
||||
absl::optional<DcSctpSocketHandoverState>
|
||||
DcSctpSocket::GetHandoverStateAndClose() {
|
||||
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
|
||||
|
||||
if (!GetHandoverReadiness().IsReady()) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
@ -1659,7 +1667,6 @@ DcSctpSocket::GetHandoverStateAndClose() {
|
||||
tcb_->AddHandoverState(state);
|
||||
send_queue_.AddHandoverState(state);
|
||||
InternalClose(ErrorKind::kNoError, "handover");
|
||||
callbacks_.TriggerDeferred();
|
||||
}
|
||||
|
||||
return std::move(state);
|
||||
|
Reference in New Issue
Block a user