diff --git a/media/BUILD.gn b/media/BUILD.gn index 332420aac0..0350c0ad8e 100644 --- a/media/BUILD.gn +++ b/media/BUILD.gn @@ -425,7 +425,7 @@ if (rtc_build_dcsctp) { "../rtc_base:socket", "../rtc_base:stringutils", "../rtc_base:threading", - "../rtc_base/containers:flat_set", + "../rtc_base/containers:flat_map", "../rtc_base/task_utils:pending_task_safety_flag", "../rtc_base/task_utils:to_queued_task", "../rtc_base/third_party/sigslot:sigslot", diff --git a/media/sctp/dcsctp_transport.cc b/media/sctp/dcsctp_transport.cc index 6527f6fcaa..3da6709702 100644 --- a/media/sctp/dcsctp_transport.cc +++ b/media/sctp/dcsctp_transport.cc @@ -212,19 +212,27 @@ bool DcSctpTransport::OpenStream(int sid) { << "): Transport is not started."; return false; } - local_close_.erase(dcsctp::StreamID(static_cast(sid))); return true; } bool DcSctpTransport::ResetStream(int sid) { RTC_LOG(LS_INFO) << debug_name_ << "->ResetStream(" << sid << ")."; if (!socket_) { - RTC_LOG(LS_ERROR) << debug_name_ << "->OpenStream(sid=" << sid + RTC_LOG(LS_ERROR) << debug_name_ << "->ResetStream(sid=" << sid << "): Transport is not started."; return false; } + dcsctp::StreamID streams[1] = {dcsctp::StreamID(static_cast(sid))}; - local_close_.insert(streams[0]); + + StreamClosingState& closing_state = closing_states_[streams[0]]; + if (closing_state.closure_initiated || closing_state.incoming_reset_done || + closing_state.outgoing_reset_done) { + // The closing procedure was already initiated by the remote, don't do + // anything. + return false; + } + closing_state.closure_initiated = true; socket_->ResetStreams(streams); return true; } @@ -484,10 +492,14 @@ void DcSctpTransport::OnStreamsResetPerformed( RTC_LOG(LS_INFO) << debug_name_ << "->OnStreamsResetPerformed(...): Outgoing stream reset" << ", sid=" << stream_id.value(); - if (!local_close_.contains(stream_id)) { - // When the close was not initiated locally, we can signal the end of the - // data channel close procedure when the remote ACKs the reset. + StreamClosingState& closing_state = closing_states_[stream_id]; + closing_state.outgoing_reset_done = true; + + if (closing_state.incoming_reset_done) { + // When the close was not initiated locally, we can signal the end of the + // data channel close procedure when the remote ACKs the reset. SignalClosingProcedureComplete(stream_id.value()); + closing_states_.erase(stream_id); } } } @@ -498,17 +510,23 @@ void DcSctpTransport::OnIncomingStreamsReset( RTC_LOG(LS_INFO) << debug_name_ << "->OnIncomingStreamsReset(...): Incoming stream reset" << ", sid=" << stream_id.value(); - if (!local_close_.contains(stream_id)) { + StreamClosingState& closing_state = closing_states_[stream_id]; + closing_state.incoming_reset_done = true; + + if (!closing_state.closure_initiated) { // When receiving an incoming stream reset event for a non local close // procedure, the transport needs to reset the stream in the other // direction too. dcsctp::StreamID streams[1] = {stream_id}; socket_->ResetStreams(streams); SignalClosingProcedureStartedRemotely(stream_id.value()); - } else { + } + + if (closing_state.outgoing_reset_done) { // The close procedure that was initiated locally is complete when we // receive and incoming reset event. SignalClosingProcedureComplete(stream_id.value()); + closing_states_.erase(stream_id); } } } diff --git a/media/sctp/dcsctp_transport.h b/media/sctp/dcsctp_transport.h index 5e3401d471..c62a28f3c5 100644 --- a/media/sctp/dcsctp_transport.h +++ b/media/sctp/dcsctp_transport.h @@ -25,7 +25,7 @@ #include "net/dcsctp/public/types.h" #include "net/dcsctp/timer/task_queue_timeout.h" #include "p2p/base/packet_transport_internal.h" -#include "rtc_base/containers/flat_set.h" +#include "rtc_base/containers/flat_map.h" #include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/random.h" #include "rtc_base/third_party/sigslot/sigslot.h" @@ -111,7 +111,22 @@ class DcSctpTransport : public cricket::SctpTransportInternal, std::string debug_name_ = "DcSctpTransport"; rtc::CopyOnWriteBuffer receive_buffer_; - flat_set local_close_; + // Used to keep track of the closing state of the data channel. + // Reset needs to happen both ways before signaling the transport + // is closed. + struct StreamClosingState { + // True when the local connection has initiated the reset. + // If a connection receives a reset for a stream that isn't + // already being reset locally, it needs to fire the signal + // SignalClosingProcedureStartedRemotely. + bool closure_initiated = false; + // True when the local connection received OnIncomingStreamsReset + bool incoming_reset_done = false; + // True when the local connection received OnStreamsResetPerformed + bool outgoing_reset_done = false; + }; + + flat_map closing_states_; bool ready_to_send_data_ = false; }; diff --git a/media/sctp/dcsctp_transport_unittest.cc b/media/sctp/dcsctp_transport_unittest.cc index b382dc9548..119922c4df 100644 --- a/media/sctp/dcsctp_transport_unittest.cc +++ b/media/sctp/dcsctp_transport_unittest.cc @@ -91,6 +91,8 @@ TEST(DcSctpTransportTest, OpenSequence) { peer_a.sctp_transport_->Start(5000, 5000, 256 * 1024); } +// Tests that the close sequence invoked from one end results in the stream to +// be reset from both ends and all the proper signals are sent. TEST(DcSctpTransportTest, CloseSequence) { Peer peer_a; Peer peer_b; @@ -100,30 +102,76 @@ TEST(DcSctpTransportTest, CloseSequence) { InSequence sequence; EXPECT_CALL(*peer_a.socket_, ResetStreams(ElementsAre(dcsctp::StreamID(1)))) - .WillOnce(DoAll( - Invoke(peer_b.sctp_transport_.get(), - &dcsctp::DcSctpSocketCallbacks::OnIncomingStreamsReset), - Invoke(peer_a.sctp_transport_.get(), - &dcsctp::DcSctpSocketCallbacks::OnStreamsResetPerformed), - Return(dcsctp::ResetStreamsStatus::kPerformed))); + .WillOnce(Return(dcsctp::ResetStreamsStatus::kPerformed)); EXPECT_CALL(*peer_b.socket_, ResetStreams(ElementsAre(dcsctp::StreamID(1)))) - .WillOnce(DoAll( - Invoke(peer_a.sctp_transport_.get(), - &dcsctp::DcSctpSocketCallbacks::OnIncomingStreamsReset), - Invoke(peer_b.sctp_transport_.get(), - &dcsctp::DcSctpSocketCallbacks::OnStreamsResetPerformed), - Return(dcsctp::ResetStreamsStatus::kPerformed))); + .WillOnce(Return(dcsctp::ResetStreamsStatus::kPerformed)); + EXPECT_CALL(peer_a.observer_, OnSignalClosingProcedureStartedRemotely(1)) + .Times(0); + EXPECT_CALL(peer_b.observer_, OnSignalClosingProcedureStartedRemotely(1)); EXPECT_CALL(peer_a.observer_, OnSignalClosingProcedureComplete(1)); EXPECT_CALL(peer_b.observer_, OnSignalClosingProcedureComplete(1)); - EXPECT_CALL(peer_b.observer_, OnSignalClosingProcedureStartedRemotely(1)); } peer_a.sctp_transport_->Start(5000, 5000, 256 * 1024); peer_b.sctp_transport_->Start(5000, 5000, 256 * 1024); peer_a.sctp_transport_->OpenStream(1); peer_a.sctp_transport_->ResetStream(1); + + // Simulate the callbacks from the stream resets + dcsctp::StreamID streams[1] = {dcsctp::StreamID(1)}; + static_cast(peer_a.sctp_transport_.get()) + ->OnStreamsResetPerformed(streams); + static_cast(peer_b.sctp_transport_.get()) + ->OnIncomingStreamsReset(streams); + static_cast(peer_a.sctp_transport_.get()) + ->OnIncomingStreamsReset(streams); + static_cast(peer_b.sctp_transport_.get()) + ->OnStreamsResetPerformed(streams); +} + +// Tests that the close sequence initiated from both peers at the same time +// terminates properly. Both peers will think they initiated it, so no +// OnSignalClosingProcedureStartedRemotely should be called. +TEST(DcSctpTransportTest, CloseSequenceSimultaneous) { + Peer peer_a; + Peer peer_b; + peer_a.fake_packet_transport_.SetDestination(&peer_b.fake_packet_transport_, + false); + { + InSequence sequence; + + EXPECT_CALL(*peer_a.socket_, ResetStreams(ElementsAre(dcsctp::StreamID(1)))) + .WillOnce(Return(dcsctp::ResetStreamsStatus::kPerformed)); + + EXPECT_CALL(*peer_b.socket_, ResetStreams(ElementsAre(dcsctp::StreamID(1)))) + .WillOnce(Return(dcsctp::ResetStreamsStatus::kPerformed)); + + EXPECT_CALL(peer_a.observer_, OnSignalClosingProcedureStartedRemotely(1)) + .Times(0); + EXPECT_CALL(peer_b.observer_, OnSignalClosingProcedureStartedRemotely(1)) + .Times(0); + EXPECT_CALL(peer_a.observer_, OnSignalClosingProcedureComplete(1)); + EXPECT_CALL(peer_b.observer_, OnSignalClosingProcedureComplete(1)); + } + + peer_a.sctp_transport_->Start(5000, 5000, 256 * 1024); + peer_b.sctp_transport_->Start(5000, 5000, 256 * 1024); + peer_a.sctp_transport_->OpenStream(1); + peer_a.sctp_transport_->ResetStream(1); + peer_b.sctp_transport_->ResetStream(1); + + // Simulate the callbacks from the stream resets + dcsctp::StreamID streams[1] = {dcsctp::StreamID(1)}; + static_cast(peer_a.sctp_transport_.get()) + ->OnStreamsResetPerformed(streams); + static_cast(peer_b.sctp_transport_.get()) + ->OnStreamsResetPerformed(streams); + static_cast(peer_a.sctp_transport_.get()) + ->OnIncomingStreamsReset(streams); + static_cast(peer_b.sctp_transport_.get()) + ->OnIncomingStreamsReset(streams); } } // namespace webrtc diff --git a/pc/data_channel_integrationtest.cc b/pc/data_channel_integrationtest.cc index ea66adaf08..d184a81732 100644 --- a/pc/data_channel_integrationtest.cc +++ b/pc/data_channel_integrationtest.cc @@ -484,7 +484,7 @@ TEST_P(DataChannelIntegrationTest, StressTestUnorderedSctpDataChannel) { // Repeatedly open and close data channels on a peer connection to check that // the channels are properly negotiated and SCTP stream IDs properly recycled. -TEST_P(DataChannelIntegrationTest, StressTestOpenCloseChannel) { +TEST_P(DataChannelIntegrationTest, StressTestOpenCloseChannelNoDelay) { ASSERT_TRUE(CreatePeerConnectionWrappers()); ConnectFakeSignaling(); @@ -511,7 +511,7 @@ TEST_P(DataChannelIntegrationTest, StressTestOpenCloseChannel) { } for (size_t i = 0; i < kChannelCount; ++i) { - EXPECT_EQ_WAIT(caller()->data_channels()[i]->state(), + ASSERT_EQ_WAIT(caller()->data_channels()[i]->state(), DataChannelInterface::DataState::kOpen, kDefaultTimeout); RTC_LOG(LS_INFO) << "Caller Channel " << caller()->data_channels()[i]->label() << " with id " @@ -520,21 +520,106 @@ TEST_P(DataChannelIntegrationTest, StressTestOpenCloseChannel) { ASSERT_EQ_WAIT(callee()->data_channels().size(), kChannelCount, kDefaultTimeout); for (size_t i = 0; i < kChannelCount; ++i) { - EXPECT_EQ_WAIT(callee()->data_channels()[i]->state(), + ASSERT_EQ_WAIT(callee()->data_channels()[i]->state(), DataChannelInterface::DataState::kOpen, kDefaultTimeout); RTC_LOG(LS_INFO) << "Callee Channel " << callee()->data_channels()[i]->label() << " with id " << callee()->data_channels()[i]->id() << " is open."; } + // Closing from both sides to attempt creating races. + // A real application would likely only close from one side. for (size_t i = 0; i < kChannelCount; ++i) { - caller()->data_channels()[i]->Close(); + if (i % 3 == 0) { + callee()->data_channels()[i]->Close(); + caller()->data_channels()[i]->Close(); + } else { + caller()->data_channels()[i]->Close(); + callee()->data_channels()[i]->Close(); + } } for (size_t i = 0; i < kChannelCount; ++i) { - EXPECT_EQ_WAIT(caller()->data_channels()[i]->state(), + ASSERT_EQ_WAIT(caller()->data_channels()[i]->state(), DataChannelInterface::DataState::kClosed, kDefaultTimeout); - EXPECT_EQ_WAIT(callee()->data_channels()[i]->state(), + ASSERT_EQ_WAIT(callee()->data_channels()[i]->state(), + DataChannelInterface::DataState::kClosed, kDefaultTimeout); + } + + caller()->data_channels().clear(); + caller()->data_observers().clear(); + callee()->data_channels().clear(); + callee()->data_observers().clear(); + } +} + +// Repeatedly open and close data channels on a peer connection to check that +// the channels are properly negotiated and SCTP stream IDs properly recycled. +// Some delay is added for better coverage. +TEST_P(DataChannelIntegrationTest, StressTestOpenCloseChannelWithDelay) { + // Simulate some network delay + virtual_socket_server()->set_delay_mean(20); + virtual_socket_server()->set_delay_stddev(5); + virtual_socket_server()->UpdateDelayDistribution(); + + ASSERT_TRUE(CreatePeerConnectionWrappers()); + ConnectFakeSignaling(); + + int channel_id = 0; + const size_t kChannelCount = 8; + const size_t kIterations = 10; + bool has_negotiated = false; + + webrtc::DataChannelInit init; + for (size_t repeats = 0; repeats < kIterations; ++repeats) { + RTC_LOG(LS_INFO) << "Iteration " << (repeats + 1) << "/" << kIterations; + + for (size_t i = 0; i < kChannelCount; ++i) { + rtc::StringBuilder sb; + sb << "channel-" << channel_id++; + caller()->CreateDataChannel(sb.Release(), &init); + } + ASSERT_EQ(caller()->data_channels().size(), kChannelCount); + + if (!has_negotiated) { + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + has_negotiated = true; + } + + for (size_t i = 0; i < kChannelCount; ++i) { + ASSERT_EQ_WAIT(caller()->data_channels()[i]->state(), + DataChannelInterface::DataState::kOpen, kDefaultTimeout); + RTC_LOG(LS_INFO) << "Caller Channel " + << caller()->data_channels()[i]->label() << " with id " + << caller()->data_channels()[i]->id() << " is open."; + } + ASSERT_EQ_WAIT(callee()->data_channels().size(), kChannelCount, + kDefaultTimeout); + for (size_t i = 0; i < kChannelCount; ++i) { + ASSERT_EQ_WAIT(callee()->data_channels()[i]->state(), + DataChannelInterface::DataState::kOpen, kDefaultTimeout); + RTC_LOG(LS_INFO) << "Callee Channel " + << callee()->data_channels()[i]->label() << " with id " + << callee()->data_channels()[i]->id() << " is open."; + } + + // Closing from both sides to attempt creating races. + // A real application would likely only close from one side. + for (size_t i = 0; i < kChannelCount; ++i) { + if (i % 3 == 0) { + callee()->data_channels()[i]->Close(); + caller()->data_channels()[i]->Close(); + } else { + caller()->data_channels()[i]->Close(); + callee()->data_channels()[i]->Close(); + } + } + + for (size_t i = 0; i < kChannelCount; ++i) { + ASSERT_EQ_WAIT(caller()->data_channels()[i]->state(), + DataChannelInterface::DataState::kClosed, kDefaultTimeout); + ASSERT_EQ_WAIT(callee()->data_channels()[i]->state(), DataChannelInterface::DataState::kClosed, kDefaultTimeout); } diff --git a/pc/sctp_data_channel.cc b/pc/sctp_data_channel.cc index 356493658a..9333be96ad 100644 --- a/pc/sctp_data_channel.cc +++ b/pc/sctp_data_channel.cc @@ -255,7 +255,7 @@ uint64_t SctpDataChannel::buffered_amount() const { void SctpDataChannel::Close() { RTC_DCHECK_RUN_ON(signaling_thread_); - if (state_ == kClosed) + if (state_ == kClosing || state_ == kClosed) return; SetState(kClosing); // Will send queued data before beginning the underlying closing procedure.