diff --git a/net/dcsctp/packet/chunk/data_chunk.cc b/net/dcsctp/packet/chunk/data_chunk.cc index cf65f53d29..769be2db91 100644 --- a/net/dcsctp/packet/chunk/data_chunk.cc +++ b/net/dcsctp/packet/chunk/data_chunk.cc @@ -93,7 +93,7 @@ std::string DataChunk::ToString() const { ? "complete" : *options().is_beginning ? "first" : *options().is_end ? "last" : "middle") - << ", tsn=" << *tsn() << ", stream_id=" << *stream_id() + << ", tsn=" << *tsn() << ", sid=" << *stream_id() << ", ssn=" << *ssn() << ", ppid=" << *ppid() << ", length=" << payload().size(); return sb.Release(); } diff --git a/net/dcsctp/packet/chunk/data_chunk_test.cc b/net/dcsctp/packet/chunk/data_chunk_test.cc index 6a5ca82bae..def99ceb23 100644 --- a/net/dcsctp/packet/chunk/data_chunk_test.cc +++ b/net/dcsctp/packet/chunk/data_chunk_test.cc @@ -67,7 +67,7 @@ TEST(DataChunkTest, SerializeAndDeserialize) { EXPECT_THAT(chunk.payload(), ElementsAre(1, 2, 3, 4, 5)); EXPECT_EQ(deserialized.ToString(), - "DATA, type=ordered::middle, tsn=123, stream_id=456, ppid=9090, " + "DATA, type=ordered::middle, tsn=123, sid=456, ssn=789, ppid=9090, " "length=5"); } } // namespace diff --git a/net/dcsctp/public/dcsctp_handover_state.h b/net/dcsctp/public/dcsctp_handover_state.h index 3ad4ab7f3e..370501c0f8 100644 --- a/net/dcsctp/public/dcsctp_handover_state.h +++ b/net/dcsctp/public/dcsctp_handover_state.h @@ -43,6 +43,12 @@ struct DcSctpSocketHandoverState { }; Capabilities capabilities; + struct OutgoingStream { + uint32_t id = 0; + uint32_t next_ssn = 0; + uint32_t next_unordered_mid = 0; + uint32_t next_ordered_mid = 0; + }; struct Transmission { uint32_t next_tsn = 0; uint32_t next_reset_req_sn = 0; @@ -50,6 +56,7 @@ struct DcSctpSocketHandoverState { uint32_t rwnd = 0; uint32_t ssthresh = 0; uint32_t partial_bytes_acked = 0; + std::vector streams; }; Transmission tx; diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc index bfe248c055..7910a3953f 100644 --- a/net/dcsctp/socket/dcsctp_socket.cc +++ b/net/dcsctp/socket/dcsctp_socket.cc @@ -302,6 +302,8 @@ void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) { state.capabilities.message_interleaving; capabilities.reconfig = state.capabilities.reconfig; + send_queue_.RestoreFromState(state); + tcb_ = std::make_unique( timer_manager_, log_prefix_, options_, capabilities, callbacks_, send_queue_, my_verification_tag, TSN(state.my_initial_tsn), @@ -1619,9 +1621,7 @@ HandoverReadinessStatus DcSctpSocket::GetHandoverReadiness() const { if (state_ != State::kClosed && state_ != State::kEstablished) { status.Add(HandoverUnreadinessReason::kWrongConnectionState); } - if (!send_queue_.IsEmpty()) { - status.Add(HandoverUnreadinessReason::kSendQueueNotEmpty); - } + status.Add(send_queue_.GetHandoverReadiness()); if (tcb_) { status.Add(tcb_->GetHandoverReadiness()); } @@ -1641,6 +1641,7 @@ DcSctpSocket::GetHandoverStateAndClose() { } else if (state_ == State::kEstablished) { state.socket_state = DcSctpSocketHandoverState::SocketState::kConnected; tcb_->AddHandoverState(state); + send_queue_.AddHandoverState(state); InternalClose(ErrorKind::kNoError, "handover"); callbacks_.TriggerDeferred(); } diff --git a/net/dcsctp/tx/rr_send_queue.cc b/net/dcsctp/tx/rr_send_queue.cc index 254214e554..eaaf34a0e0 100644 --- a/net/dcsctp/tx/rr_send_queue.cc +++ b/net/dcsctp/tx/rr_send_queue.cc @@ -27,6 +27,19 @@ namespace dcsctp { +RRSendQueue::RRSendQueue(absl::string_view log_prefix, + size_t buffer_size, + std::function on_buffered_amount_low, + size_t total_buffered_amount_low_threshold, + std::function on_total_buffered_amount_low, + const DcSctpSocketHandoverState* handover_state) + : log_prefix_(std::string(log_prefix) + "fcfs: "), + buffer_size_(buffer_size), + on_buffered_amount_low_(std::move(on_buffered_amount_low)), + total_buffered_amount_(std::move(on_total_buffered_amount_low)) { + total_buffered_amount_.SetLowThreshold(total_buffered_amount_low_threshold); +} + bool RRSendQueue::OutgoingStream::HasDataToSend(TimeMs now) { while (!items_.empty()) { RRSendQueue::OutgoingStream::Item& item = items_.front(); @@ -53,6 +66,13 @@ bool RRSendQueue::OutgoingStream::HasDataToSend(TimeMs now) { return false; } +void RRSendQueue::OutgoingStream::AddHandoverState( + DcSctpSocketHandoverState::OutgoingStream& state) const { + state.next_ssn = next_ssn_.value(); + state.next_ordered_mid = next_ordered_mid_.value(); + state.next_unordered_mid = next_unordered_mid_.value(); +} + bool RRSendQueue::IsConsistent() const { size_t total_buffered_amount = 0; for (const auto& stream_entry : streams_) { @@ -433,4 +453,33 @@ RRSendQueue::OutgoingStream& RRSendQueue::GetOrCreateStreamInfo( total_buffered_amount_)) .first->second; } + +HandoverReadinessStatus RRSendQueue::GetHandoverReadiness() const { + HandoverReadinessStatus status; + if (!IsEmpty()) { + status.Add(HandoverUnreadinessReason::kSendQueueNotEmpty); + } + return status; +} + +void RRSendQueue::AddHandoverState(DcSctpSocketHandoverState& state) { + for (const auto& entry : streams_) { + DcSctpSocketHandoverState::OutgoingStream state_stream; + state_stream.id = entry.first.value(); + entry.second.AddHandoverState(state_stream); + state.tx.streams.push_back(std::move(state_stream)); + } +} + +void RRSendQueue::RestoreFromState(const DcSctpSocketHandoverState& state) { + for (const DcSctpSocketHandoverState::OutgoingStream& state_stream : + state.tx.streams) { + StreamID stream_id(state_stream.id); + streams_.emplace(stream_id, OutgoingStream( + [this, stream_id]() { + on_buffered_amount_low_(stream_id); + }, + total_buffered_amount_, &state_stream)); + } +} } // namespace dcsctp diff --git a/net/dcsctp/tx/rr_send_queue.h b/net/dcsctp/tx/rr_send_queue.h index ed077cdef7..94b80d606e 100644 --- a/net/dcsctp/tx/rr_send_queue.h +++ b/net/dcsctp/tx/rr_send_queue.h @@ -47,13 +47,8 @@ class RRSendQueue : public SendQueue { size_t buffer_size, std::function on_buffered_amount_low, size_t total_buffered_amount_low_threshold, - std::function on_total_buffered_amount_low) - : log_prefix_(std::string(log_prefix) + "fcfs: "), - buffer_size_(buffer_size), - on_buffered_amount_low_(std::move(on_buffered_amount_low)), - total_buffered_amount_(std::move(on_total_buffered_amount_low)) { - total_buffered_amount_.SetLowThreshold(total_buffered_amount_low_threshold); - } + std::function on_total_buffered_amount_low, + const DcSctpSocketHandoverState* handover_state = nullptr); // Indicates if the buffer is full. Note that it's up to the caller to ensure // that the buffer is not full prior to adding new items to it. @@ -86,6 +81,10 @@ class RRSendQueue : public SendQueue { size_t buffered_amount_low_threshold(StreamID stream_id) const override; void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override; + HandoverReadinessStatus GetHandoverReadiness() const; + void AddHandoverState(DcSctpSocketHandoverState& state); + void RestoreFromState(const DcSctpSocketHandoverState& state); + private: // Represents a value and a "low threshold" that when the value reaches or // goes under the "low threshold", will trigger `on_threshold_reached` @@ -113,9 +112,14 @@ class RRSendQueue : public SendQueue { // Per-stream information. class OutgoingStream { public: - explicit OutgoingStream(std::function on_buffered_amount_low, - ThresholdWatcher& total_buffered_amount) - : buffered_amount_(std::move(on_buffered_amount_low)), + explicit OutgoingStream( + std::function on_buffered_amount_low, + ThresholdWatcher& total_buffered_amount, + const DcSctpSocketHandoverState::OutgoingStream* state = nullptr) + : next_unordered_mid_(MID(state ? state->next_unordered_mid : 0)), + next_ordered_mid_(MID(state ? state->next_ordered_mid : 0)), + next_ssn_(SSN(state ? state->next_ssn : 0)), + buffered_amount_(std::move(on_buffered_amount_low)), total_buffered_amount_(total_buffered_amount) {} // Enqueues a message to this stream. @@ -150,6 +154,9 @@ class RRSendQueue : public SendQueue { // expired non-partially sent message. bool HasDataToSend(TimeMs now); + void AddHandoverState( + DcSctpSocketHandoverState::OutgoingStream& state) const; + private: // An enqueued message and metadata. struct Item { @@ -181,10 +188,10 @@ class RRSendQueue : public SendQueue { // Streams are pause when they are about to be reset. bool is_paused_ = false; // MIDs are different for unordered and ordered messages sent on a stream. - MID next_unordered_mid_ = MID(0); - MID next_ordered_mid_ = MID(0); + MID next_unordered_mid_; + MID next_ordered_mid_; - SSN next_ssn_ = SSN(0); + SSN next_ssn_; // Enqueued messages, and metadata. std::deque items_;