dcsctp: Expire timers just before triggering them

In real life, when a Timeout expires, the caller is supposed to call
DcSctpSocket::HandleTimeout directly, as the Timeout that just expired
is stopped (it just expired), but the Timer still believes it's running.
The system is not in a consistent state.

In tests, all timeouts were evaluated at the same time, which, if two
timeouts expired at the same time, would put them both as "not running",
and with their timers believing they were running. So if you would do
any operation on a timer whose timeout had just expired, the timeout
would assert saying that "you can't stop a stopped timeout" or similar.

This isn't relevant in non-test scenarios.

Solved by expiring timeouts one by one.

Bug: webrtc:12614
Change-Id: I79d006f4d3e96854d77cec3eb0080aa23b8569cb
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/217560
Reviewed-by: Florent Castelli <orphis@webrtc.org>
Commit-Queue: Victor Boivie <boivie@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#33925}
This commit is contained in:
Victor Boivie
2021-05-05 14:00:50 +02:00
committed by WebRTC LUCI CQ
parent fc88df81f6
commit 1d2fa9a1c3
6 changed files with 53 additions and 33 deletions

View File

@ -208,13 +208,19 @@ class DcSctpSocketTest : public testing::Test {
} while (delivered_packet);
}
void RunTimers(MockDcSctpSocketCallbacks& cb, DcSctpSocket& socket) {
for (;;) {
absl::optional<TimeoutID> timeout_id = cb.GetNextExpiredTimeout();
if (!timeout_id.has_value()) {
break;
}
socket.HandleTimeout(*timeout_id);
}
}
void RunTimers() {
for (const auto timeout_id : cb_a_.RunTimers()) {
sock_a_.HandleTimeout(timeout_id);
}
for (const auto timeout_id : cb_z_.RunTimers()) {
sock_z_.HandleTimeout(timeout_id);
}
RunTimers(cb_a_, sock_a_);
RunTimers(cb_z_, sock_z_);
}
const DcSctpOptions options_;
@ -1025,9 +1031,7 @@ TEST_F(DcSctpSocketTest, PassingHighWatermarkWillOnlyAcceptCumAckTsn) {
// The receiver might have moved into delayed ack mode.
cb_z2.AdvanceTime(options.rto_initial);
for (const auto timeout_id : cb_z2.RunTimers()) {
sock_z2.HandleTimeout(timeout_id);
}
RunTimers(cb_z2, sock_z2);
EXPECT_THAT(
cb_z2.ConsumeSentPacket(),
@ -1066,9 +1070,7 @@ TEST_F(DcSctpSocketTest, PassingHighWatermarkWillOnlyAcceptCumAckTsn) {
// The receiver might have moved into delayed ack mode.
cb_z2.AdvanceTime(options.rto_initial);
for (const auto timeout_id : cb_z2.RunTimers()) {
sock_z2.HandleTimeout(timeout_id);
}
RunTimers(cb_z2, sock_z2);
EXPECT_THAT(
cb_z2.ConsumeSentPacket(),

View File

@ -45,6 +45,17 @@ class HeartbeatHandlerTest : public testing::Test {
timer_manager_([this]() { return callbacks_.CreateTimeout(); }),
handler_("log: ", options_, &context_, &timer_manager_) {}
void AdvanceTime(DurationMs duration) {
callbacks_.AdvanceTime(duration);
for (;;) {
absl::optional<TimeoutID> timeout_id = callbacks_.GetNextExpiredTimeout();
if (!timeout_id.has_value()) {
break;
}
timer_manager_.HandleTimeout(*timeout_id);
}
}
const DcSctpOptions options_;
NiceMock<MockDcSctpSocketCallbacks> callbacks_;
NiceMock<MockContext> context_;
@ -75,10 +86,7 @@ TEST_F(HeartbeatHandlerTest, RepliesToHeartbeatRequests) {
}
TEST_F(HeartbeatHandlerTest, SendsHeartbeatRequestsOnIdleChannel) {
callbacks_.AdvanceTime(options_.heartbeat_interval);
for (TimeoutID id : callbacks_.RunTimers()) {
timer_manager_.HandleTimeout(id);
}
AdvanceTime(options_.heartbeat_interval);
// Grab the request, and make a response.
std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket();
@ -101,22 +109,15 @@ TEST_F(HeartbeatHandlerTest, SendsHeartbeatRequestsOnIdleChannel) {
}
TEST_F(HeartbeatHandlerTest, IncreasesErrorIfNotAckedInTime) {
callbacks_.AdvanceTime(options_.heartbeat_interval);
DurationMs rto(105);
EXPECT_CALL(context_, current_rto).WillOnce(Return(rto));
for (TimeoutID id : callbacks_.RunTimers()) {
timer_manager_.HandleTimeout(id);
}
AdvanceTime(options_.heartbeat_interval);
// Validate that a request was sent.
EXPECT_THAT(callbacks_.ConsumeSentPacket(), Not(IsEmpty()));
EXPECT_CALL(context_, IncrementTxErrorCounter).Times(1);
callbacks_.AdvanceTime(rto);
for (TimeoutID id : callbacks_.RunTimers()) {
timer_manager_.HandleTimeout(id);
}
AdvanceTime(rto);
}
} // namespace

View File

@ -134,7 +134,9 @@ class MockDcSctpSocketCallbacks : public DcSctpSocketCallbacks {
void AdvanceTime(DurationMs duration_ms) { now_ = now_ + duration_ms; }
void SetTime(TimeMs now) { now_ = now; }
std::vector<TimeoutID> RunTimers() { return timeout_manager_.RunTimers(); }
absl::optional<TimeoutID> GetNextExpiredTimeout() {
return timeout_manager_.GetNextExpiredTimeout();
}
private:
TimeMs now_ = TimeMs(0);

View File

@ -119,8 +119,12 @@ class StreamResetHandlerTest : public testing::Test {
void AdvanceTime(DurationMs duration) {
callbacks_.AdvanceTime(kRto);
for (TimeoutID timeout_id : callbacks_.RunTimers()) {
timer_manager_.HandleTimeout(timeout_id);
for (;;) {
absl::optional<TimeoutID> timeout_id = callbacks_.GetNextExpiredTimeout();
if (!timeout_id.has_value()) {
break;
}
timer_manager_.HandleTimeout(*timeout_id);
}
}

View File

@ -18,6 +18,7 @@
#include <utility>
#include <vector>
#include "absl/types/optional.h"
#include "net/dcsctp/public/timeout.h"
namespace dcsctp {
@ -73,15 +74,20 @@ class FakeTimeoutManager {
return timer;
}
std::vector<TimeoutID> RunTimers() {
// NOTE: This can't return a vector, as calling EvaluateHasExpired requires
// calling socket->HandleTimeout directly afterwards, as the owning Timer
// still believes it's running, and it needs to be updated to set
// Timer::is_running_ to false before you operate on the Timer or Timeout
// again.
absl::optional<TimeoutID> GetNextExpiredTimeout() {
TimeMs now = get_time_();
std::vector<TimeoutID> expired_timers;
for (auto& timer : timers_) {
if (timer->EvaluateHasExpired(now)) {
expired_timers.push_back(timer->timeout_id());
return timer->timeout_id();
}
}
return expired_timers;
return absl::nullopt;
}
private:

View File

@ -32,8 +32,13 @@ class TimerTest : public testing::Test {
void AdvanceTimeAndRunTimers(DurationMs duration) {
now_ = now_ + duration;
for (TimeoutID timeout_id : timeout_manager_.RunTimers()) {
manager_.HandleTimeout(timeout_id);
for (;;) {
absl::optional<TimeoutID> timeout_id =
timeout_manager_.GetNextExpiredTimeout();
if (!timeout_id.has_value()) {
break;
}
manager_.HandleTimeout(*timeout_id);
}
}