diff --git a/net/dcsctp/public/types.h b/net/dcsctp/public/types.h index b87fd4e79a..d516daffe3 100644 --- a/net/dcsctp/public/types.h +++ b/net/dcsctp/public/types.h @@ -12,6 +12,7 @@ #define NET_DCSCTP_PUBLIC_TYPES_H_ #include +#include #include "net/dcsctp/public/strong_alias.h" @@ -85,6 +86,10 @@ class TimeMs : public StrongAlias { value_ -= *d; return *this; } + + static constexpr TimeMs InfiniteFuture() { + return TimeMs(std::numeric_limits::max()); + } }; constexpr inline TimeMs operator+(TimeMs lhs, DurationMs rhs) { diff --git a/net/dcsctp/timer/BUILD.gn b/net/dcsctp/timer/BUILD.gn index e95ab676cd..a0ba5b030e 100644 --- a/net/dcsctp/timer/BUILD.gn +++ b/net/dcsctp/timer/BUILD.gn @@ -29,21 +29,45 @@ rtc_library("timer") { ] } +rtc_library("task_queue_timeout") { + deps = [ + "../../../api:array_view", + "../../../api/task_queue:task_queue", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../../../rtc_base/task_utils:pending_task_safety_flag", + "../../../rtc_base/task_utils:to_queued_task", + "../public:socket", + "../public:strong_alias", + "../public:types", + ] + sources = [ + "task_queue_timeout.cc", + "task_queue_timeout.h", + ] +} + if (rtc_include_tests) { rtc_library("dcsctp_timer_unittests") { testonly = true defines = [] deps = [ + ":task_queue_timeout", ":timer", "../../../api:array_view", "../../../rtc_base:checks", "../../../rtc_base:gunit_helpers", "../../../rtc_base:rtc_base_approved", "../../../test:test_support", + "../../../test/time_controller:time_controller", "../public:socket", ] - sources = [ "timer_test.cc" ] + sources = [ + "task_queue_timeout_test.cc", + "timer_test.cc", + ] absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } } diff --git a/net/dcsctp/timer/task_queue_timeout.cc b/net/dcsctp/timer/task_queue_timeout.cc new file mode 100644 index 0000000000..6d3054eeb8 --- /dev/null +++ b/net/dcsctp/timer/task_queue_timeout.cc @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/timer/task_queue_timeout.h" + +#include "rtc_base/logging.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" +#include "rtc_base/task_utils/to_queued_task.h" + +namespace dcsctp { + +TaskQueueTimeoutFactory::TaskQueueTimeout::TaskQueueTimeout( + TaskQueueTimeoutFactory& parent) + : parent_(parent), + pending_task_safety_flag_(webrtc::PendingTaskSafetyFlag::Create()) {} + +TaskQueueTimeoutFactory::TaskQueueTimeout::~TaskQueueTimeout() { + RTC_DCHECK_RUN_ON(&parent_.thread_checker_); + pending_task_safety_flag_->SetNotAlive(); +} + +void TaskQueueTimeoutFactory::TaskQueueTimeout::Start(DurationMs duration_ms, + TimeoutID timeout_id) { + RTC_DCHECK_RUN_ON(&parent_.thread_checker_); + RTC_DCHECK(timeout_expiration_ == TimeMs::InfiniteFuture()); + timeout_expiration_ = parent_.get_time_() + duration_ms; + timeout_id_ = timeout_id; + + if (timeout_expiration_ >= posted_task_expiration_) { + // There is already a running task, and it's scheduled to expire sooner than + // the new expiration time. Don't do anything; The `timeout_expiration_` has + // already been updated and if the delayed task _does_ expire and the timer + // hasn't been stopped, that will be noticed in the timeout handler, and the + // task will be re-scheduled. Most timers are stopped before they expire. + return; + } + + if (posted_task_expiration_ != TimeMs::InfiniteFuture()) { + RTC_DLOG(LS_VERBOSE) << "New timeout duration is less than scheduled - " + "ghosting old delayed task."; + // There is already a scheduled delayed task, but its expiration time is + // further away than the new expiration, so it can't be used. It will be + // "killed" by replacing the safety flag. This is not expected to happen + // especially often; Mainly when a timer did exponential backoff and + // later recovered. + pending_task_safety_flag_->SetNotAlive(); + pending_task_safety_flag_ = webrtc::PendingTaskSafetyFlag::Create(); + } + + posted_task_expiration_ = timeout_expiration_; + parent_.task_queue_.PostDelayedTask( + webrtc::ToQueuedTask( + pending_task_safety_flag_, + [timeout_id, this]() { + RTC_DLOG(LS_VERBOSE) << "Timout expired: " << timeout_id.value(); + RTC_DCHECK_RUN_ON(&parent_.thread_checker_); + RTC_DCHECK(posted_task_expiration_ != TimeMs::InfiniteFuture()); + posted_task_expiration_ = TimeMs::InfiniteFuture(); + + if (timeout_expiration_ == TimeMs::InfiniteFuture()) { + // The timeout was stopped before it expired. Very common. + } else { + // Note that the timeout might have been restarted, which updated + // `timeout_expiration_` but left the scheduled task running. So + // if it's not quite time to trigger the timeout yet, schedule a + // new delayed task with what's remaining and retry at that point + // in time. + DurationMs remaining = timeout_expiration_ - parent_.get_time_(); + timeout_expiration_ = TimeMs::InfiniteFuture(); + if (*remaining > 0) { + Start(remaining, timeout_id_); + } else { + // It has actually triggered. + RTC_DLOG(LS_VERBOSE) + << "Timout triggered: " << timeout_id.value(); + parent_.on_expired_(timeout_id_); + } + } + }), + duration_ms.value()); +} + +void TaskQueueTimeoutFactory::TaskQueueTimeout::Stop() { + // As the TaskQueue doesn't support deleting a posted task, just mark the + // timeout as not running. + RTC_DCHECK_RUN_ON(&parent_.thread_checker_); + timeout_expiration_ = TimeMs::InfiniteFuture(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/timer/task_queue_timeout.h b/net/dcsctp/timer/task_queue_timeout.h new file mode 100644 index 0000000000..e8d12df592 --- /dev/null +++ b/net/dcsctp/timer/task_queue_timeout.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TIMER_TASK_QUEUE_TIMEOUT_H_ +#define NET_DCSCTP_TIMER_TASK_QUEUE_TIMEOUT_H_ + +#include +#include + +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/public/timeout.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" + +namespace dcsctp { + +// The TaskQueueTimeoutFactory creates `Timeout` instances, which schedules +// itself to be triggered on the provided `task_queue`, which may be a thread, +// an actual TaskQueue or something else which supports posting a delayed task. +// +// Note that each `DcSctpSocket` must have its own `TaskQueueTimeoutFactory`, +// as the `TimeoutID` are not unique among sockets. +// +// This class must outlive any created Timeout that it has created. Note that +// the `DcSctpSocket` will ensure that all Timeouts are deleted when the socket +// is destructed, so this means that this class must outlive the `DcSctpSocket`. +// +// This class, and the timeouts created it, are not thread safe. +class TaskQueueTimeoutFactory { + public: + // The `get_time` function must return the current time, relative to any + // epoch. Whenever a timeout expires, the `on_expired` callback will be + // triggered, and then the client should provided `timeout_id` to + // `DcSctpSocketInterface::HandleTimeout`. + TaskQueueTimeoutFactory(webrtc::TaskQueueBase& task_queue, + std::function get_time, + std::function on_expired) + : task_queue_(task_queue), + get_time_(std::move(get_time)), + on_expired_(std::move(on_expired)) {} + + // Creates an implementation of `Timeout`. + std::unique_ptr CreateTimeout() { + return std::make_unique(*this); + } + + private: + class TaskQueueTimeout : public Timeout { + public: + explicit TaskQueueTimeout(TaskQueueTimeoutFactory& parent); + ~TaskQueueTimeout(); + + void Start(DurationMs duration_ms, TimeoutID timeout_id) override; + void Stop() override; + + private: + TaskQueueTimeoutFactory& parent_; + // A safety flag to ensure that posted tasks to the task queue don't + // reference these object when they go out of scope. Note that this safety + // flag will be re-created if the scheduled-but-not-yet-expired task is not + // to be run. This happens when there is a posted delayed task with an + // expiration time _further away_ than what is now the expected expiration + // time. In this scenario, a new delayed task has to be posted with a + // shorter duration and the old task has to be forgotten. + rtc::scoped_refptr pending_task_safety_flag_; + // The time when the posted delayed task is set to expire. Will be set to + // the infinite future if there is no such task running. + TimeMs posted_task_expiration_ = TimeMs::InfiniteFuture(); + // The time when the timeout expires. It will be set to the infinite future + // if the timeout is not running/not started. + TimeMs timeout_expiration_ = TimeMs::InfiniteFuture(); + // The current timeout ID that will be reported when expired. + TimeoutID timeout_id_ = TimeoutID(0); + }; + + RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker thread_checker_; + webrtc::TaskQueueBase& task_queue_; + const std::function get_time_; + const std::function on_expired_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_TIMER_TASK_QUEUE_TIMEOUT_H_ diff --git a/net/dcsctp/timer/task_queue_timeout_test.cc b/net/dcsctp/timer/task_queue_timeout_test.cc new file mode 100644 index 0000000000..9d3846953b --- /dev/null +++ b/net/dcsctp/timer/task_queue_timeout_test.cc @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/timer/task_queue_timeout.h" + +#include + +#include "rtc_base/gunit.h" +#include "test/gmock.h" +#include "test/time_controller/simulated_time_controller.h" + +namespace dcsctp { +namespace { +using ::testing::MockFunction; + +class TaskQueueTimeoutTest : public testing::Test { + protected: + TaskQueueTimeoutTest() + : time_controller_(webrtc::Timestamp::Millis(1234)), + task_queue_(time_controller_.GetMainThread()), + factory_( + *task_queue_, + [this]() { + return TimeMs(time_controller_.GetClock()->CurrentTime().ms()); + }, + on_expired_.AsStdFunction()) {} + + void AdvanceTime(DurationMs duration) { + time_controller_.AdvanceTime(webrtc::TimeDelta::Millis(*duration)); + } + + MockFunction on_expired_; + webrtc::GlobalSimulatedTimeController time_controller_; + + rtc::Thread* task_queue_; + TaskQueueTimeoutFactory factory_; +}; + +TEST_F(TaskQueueTimeoutTest, StartPostsDelayedTask) { + std::unique_ptr timeout = factory_.CreateTimeout(); + timeout->Start(DurationMs(1000), TimeoutID(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(999)); + + EXPECT_CALL(on_expired_, Call(TimeoutID(1))); + AdvanceTime(DurationMs(1)); +} + +TEST_F(TaskQueueTimeoutTest, StopBeforeExpiringDoesntTrigger) { + std::unique_ptr timeout = factory_.CreateTimeout(); + timeout->Start(DurationMs(1000), TimeoutID(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(999)); + + timeout->Stop(); + + AdvanceTime(DurationMs(1)); + AdvanceTime(DurationMs(1000)); +} + +TEST_F(TaskQueueTimeoutTest, RestartPrologingTimeoutDuration) { + std::unique_ptr timeout = factory_.CreateTimeout(); + timeout->Start(DurationMs(1000), TimeoutID(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(500)); + + timeout->Restart(DurationMs(1000), TimeoutID(2)); + + AdvanceTime(DurationMs(999)); + + EXPECT_CALL(on_expired_, Call(TimeoutID(2))); + AdvanceTime(DurationMs(1)); +} + +TEST_F(TaskQueueTimeoutTest, RestartWithShorterDurationExpiresWhenExpected) { + std::unique_ptr timeout = factory_.CreateTimeout(); + timeout->Start(DurationMs(1000), TimeoutID(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(500)); + + timeout->Restart(DurationMs(200), TimeoutID(2)); + + AdvanceTime(DurationMs(199)); + + EXPECT_CALL(on_expired_, Call(TimeoutID(2))); + AdvanceTime(DurationMs(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(1000)); +} + +TEST_F(TaskQueueTimeoutTest, KilledBeforeExpired) { + std::unique_ptr timeout = factory_.CreateTimeout(); + timeout->Start(DurationMs(1000), TimeoutID(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(500)); + + timeout = nullptr; + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(1000)); +} +} // namespace +} // namespace dcsctp