dcsctp: Add timer safeguards and sanity checks
Ensuring that timer durations never go beyond a safe maximum duration and that timer IDs are not re-used. Bug: webrtc:12614 Change-Id: I227a2e9933da16669dc6ea0a39c570892010ba2c Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/215063 Commit-Queue: Victor Boivie <boivie@webrtc.org> Reviewed-by: Tommi <tommi@webrtc.org> Cr-Commit-Position: refs/heads/master@{#33860}
This commit is contained in:

committed by
WebRTC LUCI CQ

parent
769629e02f
commit
5d3bda58fd
@ -14,6 +14,7 @@ rtc_library("timer") {
|
|||||||
"../../../rtc_base",
|
"../../../rtc_base",
|
||||||
"../../../rtc_base:checks",
|
"../../../rtc_base:checks",
|
||||||
"../../../rtc_base:rtc_base_approved",
|
"../../../rtc_base:rtc_base_approved",
|
||||||
|
"../public:strong_alias",
|
||||||
"../public:types",
|
"../public:types",
|
||||||
]
|
]
|
||||||
sources = [
|
sources = [
|
||||||
|
@ -9,7 +9,9 @@
|
|||||||
*/
|
*/
|
||||||
#include "net/dcsctp/timer/timer.h"
|
#include "net/dcsctp/timer/timer.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <limits>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@ -17,11 +19,12 @@
|
|||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "net/dcsctp/public/timeout.h"
|
#include "net/dcsctp/public/timeout.h"
|
||||||
|
#include "rtc_base/checks.h"
|
||||||
|
|
||||||
namespace dcsctp {
|
namespace dcsctp {
|
||||||
namespace {
|
namespace {
|
||||||
TimeoutID MakeTimeoutId(uint32_t timer_id, uint32_t generation) {
|
TimeoutID MakeTimeoutId(TimerID timer_id, TimerGeneration generation) {
|
||||||
return TimeoutID(static_cast<uint64_t>(timer_id) << 32 | generation);
|
return TimeoutID(static_cast<uint64_t>(*timer_id) << 32 | *generation);
|
||||||
}
|
}
|
||||||
|
|
||||||
DurationMs GetBackoffDuration(TimerBackoffAlgorithm algorithm,
|
DurationMs GetBackoffDuration(TimerBackoffAlgorithm algorithm,
|
||||||
@ -30,13 +33,23 @@ DurationMs GetBackoffDuration(TimerBackoffAlgorithm algorithm,
|
|||||||
switch (algorithm) {
|
switch (algorithm) {
|
||||||
case TimerBackoffAlgorithm::kFixed:
|
case TimerBackoffAlgorithm::kFixed:
|
||||||
return base_duration;
|
return base_duration;
|
||||||
case TimerBackoffAlgorithm::kExponential:
|
case TimerBackoffAlgorithm::kExponential: {
|
||||||
return DurationMs(*base_duration * (1 << expiration_count));
|
int32_t duration_ms = *base_duration;
|
||||||
|
|
||||||
|
while (expiration_count > 0 && duration_ms < *Timer::kMaxTimerDuration) {
|
||||||
|
duration_ms *= 2;
|
||||||
|
--expiration_count;
|
||||||
|
}
|
||||||
|
|
||||||
|
return DurationMs(std::min(duration_ms, *Timer::kMaxTimerDuration));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Timer::Timer(uint32_t id,
|
constexpr DurationMs Timer::kMaxTimerDuration;
|
||||||
|
|
||||||
|
Timer::Timer(TimerID id,
|
||||||
absl::string_view name,
|
absl::string_view name,
|
||||||
OnExpired on_expired,
|
OnExpired on_expired,
|
||||||
UnregisterHandler unregister_handler,
|
UnregisterHandler unregister_handler,
|
||||||
@ -59,11 +72,13 @@ void Timer::Start() {
|
|||||||
expiration_count_ = 0;
|
expiration_count_ = 0;
|
||||||
if (!is_running()) {
|
if (!is_running()) {
|
||||||
is_running_ = true;
|
is_running_ = true;
|
||||||
timeout_->Start(duration_, MakeTimeoutId(id_, ++generation_));
|
generation_ = TimerGeneration(*generation_ + 1);
|
||||||
|
timeout_->Start(duration_, MakeTimeoutId(id_, generation_));
|
||||||
} else {
|
} else {
|
||||||
// Timer was running - stop and restart it, to make it expire in `duration_`
|
// Timer was running - stop and restart it, to make it expire in `duration_`
|
||||||
// from now.
|
// from now.
|
||||||
timeout_->Restart(duration_, MakeTimeoutId(id_, ++generation_));
|
generation_ = TimerGeneration(*generation_ + 1);
|
||||||
|
timeout_->Restart(duration_, MakeTimeoutId(id_, generation_));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -75,7 +90,7 @@ void Timer::Stop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Timer::Trigger(uint32_t generation) {
|
void Timer::Trigger(TimerGeneration generation) {
|
||||||
if (is_running_ && generation == generation_) {
|
if (is_running_ && generation == generation_) {
|
||||||
++expiration_count_;
|
++expiration_count_;
|
||||||
if (options_.max_restarts >= 0 &&
|
if (options_.max_restarts >= 0 &&
|
||||||
@ -92,14 +107,15 @@ void Timer::Trigger(uint32_t generation) {
|
|||||||
// Restart it with new duration.
|
// Restart it with new duration.
|
||||||
DurationMs duration = GetBackoffDuration(options_.backoff_algorithm,
|
DurationMs duration = GetBackoffDuration(options_.backoff_algorithm,
|
||||||
duration_, expiration_count_);
|
duration_, expiration_count_);
|
||||||
timeout_->Start(duration, MakeTimeoutId(id_, ++generation_));
|
generation_ = TimerGeneration(*generation_ + 1);
|
||||||
|
timeout_->Start(duration, MakeTimeoutId(id_, generation_));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TimerManager::HandleTimeout(TimeoutID timeout_id) {
|
void TimerManager::HandleTimeout(TimeoutID timeout_id) {
|
||||||
uint32_t timer_id = *timeout_id >> 32;
|
TimerID timer_id(*timeout_id >> 32);
|
||||||
uint32_t generation = *timeout_id;
|
TimerGeneration generation(*timeout_id);
|
||||||
auto it = timers_.find(timer_id);
|
auto it = timers_.find(timer_id);
|
||||||
if (it != timers_.end()) {
|
if (it != timers_.end()) {
|
||||||
it->second->Trigger(generation);
|
it->second->Trigger(generation);
|
||||||
@ -109,7 +125,12 @@ void TimerManager::HandleTimeout(TimeoutID timeout_id) {
|
|||||||
std::unique_ptr<Timer> TimerManager::CreateTimer(absl::string_view name,
|
std::unique_ptr<Timer> TimerManager::CreateTimer(absl::string_view name,
|
||||||
Timer::OnExpired on_expired,
|
Timer::OnExpired on_expired,
|
||||||
const TimerOptions& options) {
|
const TimerOptions& options) {
|
||||||
uint32_t id = ++next_id_;
|
next_id_ = TimerID(*next_id_ + 1);
|
||||||
|
TimerID id = next_id_;
|
||||||
|
// This would overflow after 4 billion timers created, which in SCTP would be
|
||||||
|
// after 800 million reconnections on a single socket. Ensure this will never
|
||||||
|
// happen.
|
||||||
|
RTC_CHECK_NE(*id, std::numeric_limits<uint32_t>::max());
|
||||||
auto timer = absl::WrapUnique(new Timer(
|
auto timer = absl::WrapUnique(new Timer(
|
||||||
id, name, std::move(on_expired), [this, id]() { timers_.erase(id); },
|
id, name, std::move(on_expired), [this, id]() { timers_.erase(id); },
|
||||||
create_timeout_(), options));
|
create_timeout_(), options));
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
|
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
@ -20,10 +21,14 @@
|
|||||||
|
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
|
#include "net/dcsctp/public/strong_alias.h"
|
||||||
#include "net/dcsctp/public/timeout.h"
|
#include "net/dcsctp/public/timeout.h"
|
||||||
|
|
||||||
namespace dcsctp {
|
namespace dcsctp {
|
||||||
|
|
||||||
|
using TimerID = StrongAlias<class TimerIDTag, uint32_t>;
|
||||||
|
using TimerGeneration = StrongAlias<class TimerGenerationTag, uint32_t>;
|
||||||
|
|
||||||
enum class TimerBackoffAlgorithm {
|
enum class TimerBackoffAlgorithm {
|
||||||
// The base duration will be used for any restart.
|
// The base duration will be used for any restart.
|
||||||
kFixed,
|
kFixed,
|
||||||
@ -68,6 +73,9 @@ struct TimerOptions {
|
|||||||
// backoff algorithm).
|
// backoff algorithm).
|
||||||
class Timer {
|
class Timer {
|
||||||
public:
|
public:
|
||||||
|
// The maximum timer duration - one day.
|
||||||
|
static constexpr DurationMs kMaxTimerDuration = DurationMs(24 * 3600 * 1000);
|
||||||
|
|
||||||
// When expired, the timer handler can optionally return a new duration which
|
// When expired, the timer handler can optionally return a new duration which
|
||||||
// will be set as `duration` and used as base duration when the timer is
|
// will be set as `duration` and used as base duration when the timer is
|
||||||
// restarted and as input to the backoff algorithm.
|
// restarted and as input to the backoff algorithm.
|
||||||
@ -89,7 +97,9 @@ class Timer {
|
|||||||
|
|
||||||
// Sets the base duration. The actual timer duration may be larger depending
|
// Sets the base duration. The actual timer duration may be larger depending
|
||||||
// on the backoff algorithm.
|
// on the backoff algorithm.
|
||||||
void set_duration(DurationMs duration) { duration_ = duration; }
|
void set_duration(DurationMs duration) {
|
||||||
|
duration_ = std::min(duration, kMaxTimerDuration);
|
||||||
|
}
|
||||||
|
|
||||||
// Retrieves the base duration. The actual timer duration may be larger
|
// Retrieves the base duration. The actual timer duration may be larger
|
||||||
// depending on the backoff algorithm.
|
// depending on the backoff algorithm.
|
||||||
@ -110,7 +120,7 @@ class Timer {
|
|||||||
private:
|
private:
|
||||||
friend class TimerManager;
|
friend class TimerManager;
|
||||||
using UnregisterHandler = std::function<void()>;
|
using UnregisterHandler = std::function<void()>;
|
||||||
Timer(uint32_t id,
|
Timer(TimerID id,
|
||||||
absl::string_view name,
|
absl::string_view name,
|
||||||
OnExpired on_expired,
|
OnExpired on_expired,
|
||||||
UnregisterHandler unregister,
|
UnregisterHandler unregister,
|
||||||
@ -122,9 +132,9 @@ class Timer {
|
|||||||
// duration as decided by the backoff algorithm, unless the
|
// duration as decided by the backoff algorithm, unless the
|
||||||
// `TimerOptions::max_restarts` has been reached and then it will be stopped
|
// `TimerOptions::max_restarts` has been reached and then it will be stopped
|
||||||
// and `is_running()` will return false.
|
// and `is_running()` will return false.
|
||||||
void Trigger(uint32_t generation);
|
void Trigger(TimerGeneration generation);
|
||||||
|
|
||||||
const uint32_t id_;
|
const TimerID id_;
|
||||||
const std::string name_;
|
const std::string name_;
|
||||||
const TimerOptions options_;
|
const TimerOptions options_;
|
||||||
const OnExpired on_expired_;
|
const OnExpired on_expired_;
|
||||||
@ -133,8 +143,16 @@ class Timer {
|
|||||||
|
|
||||||
DurationMs duration_;
|
DurationMs duration_;
|
||||||
|
|
||||||
// Increased on each start, and is matched on Trigger, to avoid races.
|
// Increased on each start, and is matched on Trigger, to avoid races. And by
|
||||||
uint32_t generation_ = 0;
|
// race, meaning that a timeout - which may be evaluated/expired on a
|
||||||
|
// different thread while this thread has stopped that timer already. Note
|
||||||
|
// that the entire socket is not thread-safe, so `TimerManager::HandleTimeout`
|
||||||
|
// is never executed concurrently with any timer starting/stopping.
|
||||||
|
//
|
||||||
|
// This will wrap around after 4 billion timer restarts, and if it wraps
|
||||||
|
// around, it would just trigger _this_ timer in advance (but it's hard to
|
||||||
|
// restart it 4 billion times within its duration).
|
||||||
|
TimerGeneration generation_ = TimerGeneration(0);
|
||||||
bool is_running_ = false;
|
bool is_running_ = false;
|
||||||
// Incremented each time time has expired and reset when stopped or restarted.
|
// Incremented each time time has expired and reset when stopped or restarted.
|
||||||
int expiration_count_ = 0;
|
int expiration_count_ = 0;
|
||||||
@ -158,8 +176,8 @@ class TimerManager {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
const std::function<std::unique_ptr<Timeout>()> create_timeout_;
|
const std::function<std::unique_ptr<Timeout>()> create_timeout_;
|
||||||
std::unordered_map<int, Timer*> timers_;
|
std::unordered_map<TimerID, Timer*, TimerID::Hasher> timers_;
|
||||||
uint32_t next_id_ = 0;
|
TimerID next_id_ = TimerID(0);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace dcsctp
|
} // namespace dcsctp
|
||||||
|
@ -310,5 +310,41 @@ TEST_F(TimerTest, ReturningNewDurationWhenExpired) {
|
|||||||
AdvanceTimeAndRunTimers(DurationMs(1000));
|
AdvanceTimeAndRunTimers(DurationMs(1000));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TimerTest, TimersHaveMaximumDuration) {
|
||||||
|
std::unique_ptr<Timer> t1 = manager_.CreateTimer(
|
||||||
|
"t1", on_expired_.AsStdFunction(),
|
||||||
|
TimerOptions(DurationMs(1000), TimerBackoffAlgorithm::kExponential));
|
||||||
|
|
||||||
|
t1->set_duration(DurationMs(2 * *Timer::kMaxTimerDuration));
|
||||||
|
EXPECT_EQ(t1->duration(), Timer::kMaxTimerDuration);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TimerTest, TimersHaveMaximumBackoffDuration) {
|
||||||
|
std::unique_ptr<Timer> t1 = manager_.CreateTimer(
|
||||||
|
"t1", on_expired_.AsStdFunction(),
|
||||||
|
TimerOptions(DurationMs(1000), TimerBackoffAlgorithm::kExponential));
|
||||||
|
|
||||||
|
t1->Start();
|
||||||
|
|
||||||
|
int max_exponent = static_cast<int>(log2(*Timer::kMaxTimerDuration / 1000));
|
||||||
|
for (int i = 0; i < max_exponent; ++i) {
|
||||||
|
EXPECT_CALL(on_expired_, Call).Times(1);
|
||||||
|
AdvanceTimeAndRunTimers(DurationMs(1000 * (1 << i)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reached the maximum duration.
|
||||||
|
EXPECT_CALL(on_expired_, Call).Times(1);
|
||||||
|
AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration);
|
||||||
|
|
||||||
|
EXPECT_CALL(on_expired_, Call).Times(1);
|
||||||
|
AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration);
|
||||||
|
|
||||||
|
EXPECT_CALL(on_expired_, Call).Times(1);
|
||||||
|
AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration);
|
||||||
|
|
||||||
|
EXPECT_CALL(on_expired_, Call).Times(1);
|
||||||
|
AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace dcsctp
|
} // namespace dcsctp
|
||||||
|
Reference in New Issue
Block a user