Refactor StunRequest and StunRequestManager classes.

* Make StunRequest::manager_ a reference, inject ref at ctor time.
* Make other member variables private.
* Mark methods that are only used for testing with "ForTest"
* Add RTC_GUARDED_BY for member variables and thread checks.
* Remove/reduce 'friend'-ness between classes.
* Use std::unique_ptr for owned and passed message pointers.
* Rename `requests_` to `request_manager_` (type: StunRequestManager)

Bug: webrtc:13892
Change-Id: I3a5d511b3c2645bb6813352d39e9fefe422dd1de
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/258620
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Commit-Queue: Tomas Gunnarsson <tommi@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#36529}
This commit is contained in:
Tommi
2022-04-12 09:17:57 +02:00
committed by WebRTC LUCI CQ
parent 4baaf99c5e
commit 86aa03e238
11 changed files with 187 additions and 144 deletions

View File

@ -163,8 +163,10 @@ constexpr int kSupportGoogPingVersionResponseIndex =
namespace cricket {
// A ConnectionRequest is a STUN binding used to determine writability.
ConnectionRequest::ConnectionRequest(Connection* connection)
: StunRequest(new IceMessage()), connection_(connection) {}
ConnectionRequest::ConnectionRequest(StunRequestManager& manager,
Connection* connection)
: StunRequest(manager, std::make_unique<IceMessage>()),
connection_(connection) {}
void ConnectionRequest::Prepare(StunMessage* request) {
RTC_DCHECK_RUN_ON(connection_->network_thread_);
@ -276,7 +278,7 @@ void ConnectionRequest::OnSent() {
connection_->OnConnectionRequestSent(this);
// Each request is sent only once. After a single delay , the request will
// time out.
timeout_ = true;
set_timed_out();
}
int ConnectionRequest::resend_delay() {
@ -986,7 +988,7 @@ int64_t Connection::last_ping_sent() const {
void Connection::Ping(int64_t now) {
RTC_DCHECK_RUN_ON(network_thread_);
last_ping_sent_ = now;
ConnectionRequest* req = new ConnectionRequest(this);
ConnectionRequest* req = new ConnectionRequest(requests_, this);
// If not using renomination, we use "1" to mean "nominated" and "0" to mean
// "not nominated". If using renomination, values greater than 1 are used for
// re-nominated pairs.

View File

@ -57,7 +57,7 @@ struct CandidatePair final : public CandidatePairInterface {
// A ConnectionRequest is a simple STUN ping used to determine writability.
class ConnectionRequest : public StunRequest {
public:
explicit ConnectionRequest(Connection* connection);
ConnectionRequest(StunRequestManager& manager, Connection* connection);
void Prepare(StunMessage* request) override;
void OnResponse(StunMessage* response) override;
void OnErrorResponse(StunMessage* response) override;

View File

@ -40,7 +40,10 @@ class StunBindingRequest : public StunRequest {
StunBindingRequest(UDPPort* port,
const rtc::SocketAddress& addr,
int64_t start_time)
: port_(port), server_addr_(addr), start_time_(start_time) {}
: StunRequest(port->request_manager()),
port_(port),
server_addr_(addr),
start_time_(start_time) {}
const rtc::SocketAddress& server_addr() const { return server_addr_; }
@ -63,7 +66,7 @@ class StunBindingRequest : public StunRequest {
// The keep-alive requests will be stopped after its lifetime has passed.
if (WithinLifetime(rtc::TimeMillis())) {
port_->requests_.SendDelayed(
port_->request_manager_.SendDelayed(
new StunBindingRequest(port_, server_addr_, start_time_),
port_->stun_keepalive_delay());
}
@ -88,7 +91,7 @@ class StunBindingRequest : public StunRequest {
int64_t now = rtc::TimeMillis();
if (WithinLifetime(now) &&
rtc::TimeDiff(now, start_time_) < RETRY_TIMEOUT) {
port_->requests_.SendDelayed(
port_->request_manager_.SendDelayed(
new StunBindingRequest(port_, server_addr_, start_time_),
port_->stun_keepalive_delay());
}
@ -166,7 +169,7 @@ UDPPort::UDPPort(rtc::Thread* thread,
username,
password,
field_trials),
requests_(thread),
request_manager_(thread),
socket_(socket),
error_(0),
ready_(false),
@ -192,7 +195,7 @@ UDPPort::UDPPort(rtc::Thread* thread,
username,
password,
field_trials),
requests_(thread),
request_manager_(thread),
socket_(nullptr),
error_(0),
ready_(false),
@ -215,7 +218,7 @@ bool UDPPort::Init() {
socket_->SignalSentPacket.connect(this, &UDPPort::OnSentPacket);
socket_->SignalReadyToSend.connect(this, &UDPPort::OnReadyToSend);
socket_->SignalAddressReady.connect(this, &UDPPort::OnLocalAddressReady);
requests_.SignalSendPacket.connect(this, &UDPPort::OnSendPacket);
request_manager_.SignalSendPacket.connect(this, &UDPPort::OnSendPacket);
return true;
}
@ -225,7 +228,7 @@ UDPPort::~UDPPort() {
}
void UDPPort::PrepareAddress() {
RTC_DCHECK(requests_.empty());
RTC_DCHECK(request_manager_.empty());
if (socket_->GetState() == rtc::AsyncPacketSocket::STATE_BOUND) {
OnLocalAddressReady(socket_, socket_->GetLocalAddress());
}
@ -390,7 +393,7 @@ void UDPPort::OnReadPacket(rtc::AsyncPacketSocket* socket,
// will eat it because it might be a response to a retransmitted packet, and
// we already cleared the request when we got the first response.
if (server_addresses_.find(remote_addr) != server_addresses_.end()) {
requests_.CheckResponse(data, size);
request_manager_.CheckResponse(data, size);
return;
}
@ -413,7 +416,7 @@ void UDPPort::OnReadyToSend(rtc::AsyncPacketSocket* socket) {
void UDPPort::SendStunBindingRequests() {
// We will keep pinging the stun server to make sure our NAT pin-hole stays
// open until the deadline (specified in SendStunBindingRequest).
RTC_DCHECK(requests_.empty());
RTC_DCHECK(request_manager_.empty());
for (ServerAddresses::const_iterator it = server_addresses_.begin();
it != server_addresses_.end(); ++it) {
@ -463,7 +466,7 @@ void UDPPort::SendStunBindingRequest(const rtc::SocketAddress& stun_addr) {
} else if (socket_->GetState() == rtc::AsyncPacketSocket::STATE_BOUND) {
// Check if `server_addr_` is compatible with the port's ip.
if (IsCompatibleAddress(stun_addr)) {
requests_.Send(
request_manager_.Send(
new StunBindingRequest(this, stun_addr, rtc::TimeMillis()));
} else {
// Since we can't send stun messages to the server, we should mark this

View File

@ -114,10 +114,12 @@ class UDPPort : public Port {
stun_keepalive_lifetime_ = lifetime;
}
// Returns true if there is a pending request with type `msg_type`.
bool HasPendingRequest(int msg_type) {
return requests_.HasRequest(msg_type);
bool HasPendingRequestForTest(int msg_type) {
return request_manager_.HasRequestForTest(msg_type);
}
StunRequestManager& request_manager() { return request_manager_; }
protected:
UDPPort(rtc::Thread* thread,
rtc::PacketSocketFactory* factory,
@ -244,7 +246,7 @@ class UDPPort : public Port {
ServerAddresses server_addresses_;
ServerAddresses bind_request_succeeded_servers_;
ServerAddresses bind_request_failed_servers_;
StunRequestManager requests_;
StunRequestManager request_manager_;
rtc::AsyncPacketSocket* socket_;
int error_;
int send_error_count_ = 0;

View File

@ -393,7 +393,7 @@ TEST_F(StunPortTest, TestStunBindingRequestShortLifetime) {
PrepareAddress();
EXPECT_TRUE_SIMULATED_WAIT(done(), kTimeoutMs, fake_clock);
EXPECT_TRUE_SIMULATED_WAIT(
!port()->HasPendingRequest(cricket::STUN_BINDING_REQUEST), 2000,
!port()->HasPendingRequestForTest(cricket::STUN_BINDING_REQUEST), 2000,
fake_clock);
}
@ -404,7 +404,7 @@ TEST_F(StunPortTest, TestStunBindingRequestLongLifetime) {
PrepareAddress();
EXPECT_TRUE_SIMULATED_WAIT(done(), kTimeoutMs, fake_clock);
EXPECT_TRUE_SIMULATED_WAIT(
port()->HasPendingRequest(cricket::STUN_BINDING_REQUEST), 1000,
port()->HasPendingRequestForTest(cricket::STUN_BINDING_REQUEST), 1000,
fake_clock);
}

View File

@ -12,6 +12,7 @@
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "rtc_base/checks.h"
@ -56,7 +57,8 @@ void StunRequestManager::Send(StunRequest* request) {
}
void StunRequestManager::SendDelayed(StunRequest* request, int delay) {
request->set_manager(this);
RTC_DCHECK_RUN_ON(thread_);
RTC_DCHECK_EQ(this, request->manager());
RTC_DCHECK(requests_.find(request->id()) == requests_.end());
request->Construct();
requests_[request->id()] = request;
@ -67,7 +69,8 @@ void StunRequestManager::SendDelayed(StunRequest* request, int delay) {
}
}
void StunRequestManager::Flush(int msg_type) {
void StunRequestManager::FlushForTest(int msg_type) {
RTC_DCHECK_RUN_ON(thread_);
for (const auto& kv : requests_) {
StunRequest* request = kv.second;
if (msg_type == kAllRequests || msg_type == request->type()) {
@ -77,7 +80,8 @@ void StunRequestManager::Flush(int msg_type) {
}
}
bool StunRequestManager::HasRequest(int msg_type) {
bool StunRequestManager::HasRequestForTest(int msg_type) {
RTC_DCHECK_RUN_ON(thread_);
for (const auto& kv : requests_) {
StunRequest* request = kv.second;
if (msg_type == kAllRequests || msg_type == request->type()) {
@ -88,6 +92,7 @@ bool StunRequestManager::HasRequest(int msg_type) {
}
void StunRequestManager::Remove(StunRequest* request) {
RTC_DCHECK_RUN_ON(thread_);
RTC_DCHECK(request->manager() == this);
RequestMap::iterator iter = requests_.find(request->id());
if (iter != requests_.end()) {
@ -98,6 +103,7 @@ void StunRequestManager::Remove(StunRequest* request) {
}
void StunRequestManager::Clear() {
RTC_DCHECK_RUN_ON(thread_);
std::vector<StunRequest*> requests;
for (RequestMap::iterator i = requests_.begin(); i != requests_.end(); ++i)
requests.push_back(i->second);
@ -110,6 +116,7 @@ void StunRequestManager::Clear() {
}
bool StunRequestManager::CheckResponse(StunMessage* msg) {
RTC_DCHECK_RUN_ON(thread_);
RequestMap::iterator iter = requests_.find(msg->transaction_id());
if (iter == requests_.end()) {
// TODO(pthatcher): Log unknown responses without being too spammy
@ -156,7 +163,13 @@ bool StunRequestManager::CheckResponse(StunMessage* msg) {
return true;
}
bool StunRequestManager::empty() const {
RTC_DCHECK_RUN_ON(thread_);
return requests_.empty();
}
bool StunRequestManager::CheckResponse(const char* data, size_t size) {
RTC_DCHECK_RUN_ON(thread_);
// Check the appropriate bytes of the stream to see if they match the
// transaction ID of a response we are expecting.
@ -186,32 +199,33 @@ bool StunRequestManager::CheckResponse(const char* data, size_t size) {
return CheckResponse(response.get());
}
StunRequest::StunRequest()
: count_(0),
timeout_(false),
manager_(0),
StunRequest::StunRequest(StunRequestManager& manager)
: manager_(manager),
msg_(new StunMessage()),
tstamp_(0) {
tstamp_(0),
count_(0),
timeout_(false) {
msg_->SetTransactionID(rtc::CreateRandomString(kStunTransactionIdLength));
}
StunRequest::StunRequest(StunMessage* request)
: count_(0), timeout_(false), manager_(0), msg_(request), tstamp_(0) {
StunRequest::StunRequest(StunRequestManager& manager,
std::unique_ptr<StunMessage> request)
: manager_(manager),
msg_(std::move(request)),
tstamp_(0),
count_(0),
timeout_(false) {
msg_->SetTransactionID(rtc::CreateRandomString(kStunTransactionIdLength));
}
StunRequest::~StunRequest() {
RTC_DCHECK(manager_ != NULL);
if (manager_) {
manager_->Remove(this);
manager_->thread_->Clear(this);
}
delete msg_;
manager_.Remove(this);
manager_.network_thread()->Clear(this);
}
void StunRequest::Construct() {
if (msg_->type() == 0) {
Prepare(msg_);
Prepare(msg_.get());
RTC_DCHECK(msg_->type() != 0);
}
}
@ -222,24 +236,16 @@ int StunRequest::type() {
}
const StunMessage* StunRequest::msg() const {
return msg_;
}
StunMessage* StunRequest::mutable_msg() {
return msg_;
return msg_.get();
}
int StunRequest::Elapsed() const {
RTC_DCHECK_RUN_ON(network_thread());
return static_cast<int>(rtc::TimeMillis() - tstamp_);
}
void StunRequest::set_manager(StunRequestManager* manager) {
RTC_DCHECK(!manager_);
manager_ = manager;
}
void StunRequest::OnMessage(rtc::Message* pmsg) {
RTC_DCHECK(manager_ != NULL);
RTC_DCHECK_RUN_ON(network_thread());
RTC_DCHECK(pmsg->message_id == MSG_STUN_SEND);
if (timeout_) {
@ -252,24 +258,26 @@ void StunRequest::OnMessage(rtc::Message* pmsg) {
rtc::ByteBufferWriter buf;
msg_->Write(&buf);
manager_->SignalSendPacket(buf.Data(), buf.Length(), this);
manager_.SignalSendPacket(buf.Data(), buf.Length(), this);
OnSent();
manager_->thread_->PostDelayed(RTC_FROM_HERE, resend_delay(), this,
MSG_STUN_SEND, NULL);
manager_.network_thread()->PostDelayed(RTC_FROM_HERE, resend_delay(), this,
MSG_STUN_SEND, NULL);
}
void StunRequest::OnSent() {
RTC_DCHECK_RUN_ON(network_thread());
count_ += 1;
int retransmissions = (count_ - 1);
if (retransmissions >= STUN_MAX_RETRANSMISSIONS) {
timeout_ = true;
}
RTC_LOG(LS_VERBOSE) << "Sent STUN request " << count_
<< "; resend delay = " << resend_delay();
RTC_DLOG(LS_VERBOSE) << "Sent STUN request " << count_
<< "; resend delay = " << resend_delay();
}
int StunRequest::resend_delay() {
RTC_DCHECK_RUN_ON(network_thread());
if (count_ == 0) {
return 0;
}
@ -278,4 +286,9 @@ int StunRequest::resend_delay() {
return std::min(rto, STUN_MAX_RTO);
}
void StunRequest::set_timed_out() {
RTC_DCHECK_RUN_ON(network_thread());
timeout_ = true;
}
} // namespace cricket

View File

@ -15,6 +15,7 @@
#include <stdint.h>
#include <map>
#include <memory>
#include <string>
#include "api/transport/stun.h"
@ -47,11 +48,11 @@ class StunRequestManager {
// If `msg_type` is kAllRequests, sends all pending requests right away.
// Otherwise, sends those that have a matching type right away.
// Only for testing.
void Flush(int msg_type);
void FlushForTest(int msg_type);
// Returns true if at least one request with `msg_type` is scheduled for
// transmission. For testing only.
bool HasRequest(int msg_type);
bool HasRequestForTest(int msg_type);
// Removes a stun request that was added previously. This will happen
// automatically when a request succeeds, fails, or times out.
@ -65,7 +66,10 @@ class StunRequestManager {
bool CheckResponse(StunMessage* msg);
bool CheckResponse(const char* data, size_t size);
bool empty() { return requests_.empty(); }
bool empty() const;
// TODO(tommi): Use TaskQueueBase* instead of rtc::Thread.
rtc::Thread* network_thread() const { return thread_; }
// Raised when there are bytes to be sent.
sigslot::signal3<const void*, size_t, StunRequest*> SignalSendPacket;
@ -74,27 +78,26 @@ class StunRequestManager {
typedef std::map<std::string, StunRequest*> RequestMap;
rtc::Thread* const thread_;
RequestMap requests_;
friend class StunRequest;
RequestMap requests_ RTC_GUARDED_BY(thread_);
};
// Represents an individual request to be sent. The STUN message can either be
// constructed beforehand or built on demand.
class StunRequest : public rtc::MessageHandler {
public:
StunRequest();
explicit StunRequest(StunMessage* request);
explicit StunRequest(StunRequestManager& manager);
StunRequest(StunRequestManager& manager,
std::unique_ptr<StunMessage> request);
~StunRequest() override;
// Causes our wrapped StunMessage to be Prepared
void Construct();
// The manager handling this request (if it has been scheduled for sending).
StunRequestManager* manager() { return manager_; }
StunRequestManager* manager() { return &manager_; }
// Returns the transaction ID of this request.
const std::string& id() { return msg_->transaction_id(); }
const std::string& id() const { return msg_->transaction_id(); }
// Returns the reduced transaction ID of this request.
uint32_t reduced_transaction_id() const {
@ -107,15 +110,11 @@ class StunRequest : public rtc::MessageHandler {
// Returns a const pointer to `msg_`.
const StunMessage* msg() const;
// Returns a mutable pointer to `msg_`.
StunMessage* mutable_msg();
// Time elapsed since last send (in ms)
int Elapsed() const;
protected:
int count_;
bool timeout_;
friend class StunRequestManager;
// Fills in a request object to be sent. Note that request's transaction ID
// will already be set and cannot be changed.
@ -130,17 +129,21 @@ class StunRequest : public rtc::MessageHandler {
// Returns the next delay for resends.
virtual int resend_delay();
private:
void set_manager(StunRequestManager* manager);
webrtc::TaskQueueBase* network_thread() const {
return manager_.network_thread();
}
void set_timed_out();
private:
// Handles messages for sending and timeout.
void OnMessage(rtc::Message* pmsg) override;
StunRequestManager* manager_;
StunMessage* msg_;
int64_t tstamp_;
friend class StunRequestManager;
StunRequestManager& manager_;
const std::unique_ptr<StunMessage> msg_;
int64_t tstamp_ RTC_GUARDED_BY(network_thread());
int count_ RTC_GUARDED_BY(network_thread());
bool timeout_ RTC_GUARDED_BY(network_thread());
};
} // namespace cricket

View File

@ -10,6 +10,7 @@
#include "p2p/base/stun_request.h"
#include <utility>
#include <vector>
#include "rtc_base/fake_clock.h"
@ -19,6 +20,24 @@
#include "test/gtest.h"
namespace cricket {
namespace {
std::unique_ptr<StunMessage> CreateStunMessage(
StunMessageType type,
const StunMessage* req = nullptr) {
std::unique_ptr<StunMessage> msg = std::make_unique<StunMessage>();
msg->SetType(type);
if (req) {
msg->SetTransactionID(req->transaction_id());
}
return msg;
}
int TotalDelay(int sends) {
std::vector<int> delays = {0, 250, 750, 1750, 3750,
7750, 15750, 23750, 31750, 39750};
return delays[sends];
}
} // namespace
class StunRequestTest : public ::testing::Test, public sigslot::has_slots<> {
public:
@ -47,21 +66,6 @@ class StunRequestTest : public ::testing::Test, public sigslot::has_slots<> {
void OnTimeout() { timeout_ = true; }
protected:
static StunMessage* CreateStunMessage(StunMessageType type,
StunMessage* req) {
StunMessage* msg = new StunMessage();
msg->SetType(type);
if (req) {
msg->SetTransactionID(req->transaction_id());
}
return msg;
}
static int TotalDelay(int sends) {
std::vector<int> delays = {0, 250, 750, 1750, 3750,
7750, 15750, 23750, 31750, 39750};
return delays[sends];
}
StunRequestManager manager_;
int request_count_;
StunMessage* response_;
@ -73,9 +77,20 @@ class StunRequestTest : public ::testing::Test, public sigslot::has_slots<> {
// Forwards results to the test class.
class StunRequestThunker : public StunRequest {
public:
StunRequestThunker(StunMessage* msg, StunRequestTest* test)
: StunRequest(msg), test_(test) {}
explicit StunRequestThunker(StunRequestTest* test) : test_(test) {}
StunRequestThunker(StunRequestManager& manager,
StunMessageType message_type,
StunRequestTest* test)
: StunRequest(manager, CreateStunMessage(message_type)), test_(test) {
Construct(); // Triggers a call to `Prepare()` which sets the type.
}
StunRequestThunker(StunRequestManager& manager, StunRequestTest* test)
: StunRequest(manager), test_(test) {
Construct(); // Triggers a call to `Prepare()` which sets the type.
}
std::unique_ptr<StunMessage> CreateResponseMessage(StunMessageType type) {
return CreateStunMessage(type, msg());
}
private:
virtual void OnResponse(StunMessage* res) { test_->OnResponse(res); }
@ -93,127 +108,124 @@ class StunRequestThunker : public StunRequest {
// Test handling of a normal binding response.
TEST_F(StunRequestTest, TestSuccess) {
StunMessage* req = CreateStunMessage(STUN_BINDING_REQUEST, NULL);
auto* request = new StunRequestThunker(manager_, STUN_BINDING_REQUEST, this);
std::unique_ptr<StunMessage> res =
request->CreateResponseMessage(STUN_BINDING_RESPONSE);
manager_.Send(request);
EXPECT_TRUE(manager_.CheckResponse(res.get()));
manager_.Send(new StunRequestThunker(req, this));
StunMessage* res = CreateStunMessage(STUN_BINDING_RESPONSE, req);
EXPECT_TRUE(manager_.CheckResponse(res));
EXPECT_TRUE(response_ == res);
EXPECT_TRUE(response_ == res.get());
EXPECT_TRUE(success_);
EXPECT_FALSE(failure_);
EXPECT_FALSE(timeout_);
delete res;
}
// Test handling of an error binding response.
TEST_F(StunRequestTest, TestError) {
StunMessage* req = CreateStunMessage(STUN_BINDING_REQUEST, NULL);
auto* request = new StunRequestThunker(manager_, STUN_BINDING_REQUEST, this);
std::unique_ptr<StunMessage> res =
request->CreateResponseMessage(STUN_BINDING_ERROR_RESPONSE);
manager_.Send(request);
EXPECT_TRUE(manager_.CheckResponse(res.get()));
manager_.Send(new StunRequestThunker(req, this));
StunMessage* res = CreateStunMessage(STUN_BINDING_ERROR_RESPONSE, req);
EXPECT_TRUE(manager_.CheckResponse(res));
EXPECT_TRUE(response_ == res);
EXPECT_TRUE(response_ == res.get());
EXPECT_FALSE(success_);
EXPECT_TRUE(failure_);
EXPECT_FALSE(timeout_);
delete res;
}
// Test handling of a binding response with the wrong transaction id.
TEST_F(StunRequestTest, TestUnexpected) {
StunMessage* req = CreateStunMessage(STUN_BINDING_REQUEST, NULL);
auto* request = new StunRequestThunker(manager_, STUN_BINDING_REQUEST, this);
std::unique_ptr<StunMessage> res = CreateStunMessage(STUN_BINDING_RESPONSE);
manager_.Send(new StunRequestThunker(req, this));
StunMessage* res = CreateStunMessage(STUN_BINDING_RESPONSE, NULL);
EXPECT_FALSE(manager_.CheckResponse(res));
manager_.Send(request);
EXPECT_FALSE(manager_.CheckResponse(res.get()));
EXPECT_TRUE(response_ == NULL);
EXPECT_FALSE(success_);
EXPECT_FALSE(failure_);
EXPECT_FALSE(timeout_);
delete res;
}
// Test that requests are sent at the right times.
TEST_F(StunRequestTest, TestBackoff) {
rtc::ScopedFakeClock fake_clock;
StunMessage* req = CreateStunMessage(STUN_BINDING_REQUEST, NULL);
auto* request = new StunRequestThunker(manager_, STUN_BINDING_REQUEST, this);
std::unique_ptr<StunMessage> res =
request->CreateResponseMessage(STUN_BINDING_RESPONSE);
int64_t start = rtc::TimeMillis();
manager_.Send(new StunRequestThunker(req, this));
StunMessage* res = CreateStunMessage(STUN_BINDING_RESPONSE, req);
manager_.Send(request);
for (int i = 0; i < 9; ++i) {
EXPECT_TRUE_SIMULATED_WAIT(request_count_ != i, STUN_TOTAL_TIMEOUT,
fake_clock);
int64_t elapsed = rtc::TimeMillis() - start;
RTC_LOG(LS_INFO) << "STUN request #" << (i + 1) << " sent at " << elapsed
<< " ms";
RTC_DLOG(LS_INFO) << "STUN request #" << (i + 1) << " sent at " << elapsed
<< " ms";
EXPECT_EQ(TotalDelay(i), elapsed);
}
EXPECT_TRUE(manager_.CheckResponse(res));
EXPECT_TRUE(manager_.CheckResponse(res.get()));
EXPECT_TRUE(response_ == res);
EXPECT_TRUE(response_ == res.get());
EXPECT_TRUE(success_);
EXPECT_FALSE(failure_);
EXPECT_FALSE(timeout_);
delete res;
}
// Test that we timeout properly if no response is received.
TEST_F(StunRequestTest, TestTimeout) {
rtc::ScopedFakeClock fake_clock;
StunMessage* req = CreateStunMessage(STUN_BINDING_REQUEST, NULL);
StunMessage* res = CreateStunMessage(STUN_BINDING_RESPONSE, req);
auto* request = new StunRequestThunker(manager_, STUN_BINDING_REQUEST, this);
std::unique_ptr<StunMessage> res =
request->CreateResponseMessage(STUN_BINDING_RESPONSE);
manager_.Send(new StunRequestThunker(req, this));
manager_.Send(request);
SIMULATED_WAIT(false, cricket::STUN_TOTAL_TIMEOUT, fake_clock);
EXPECT_FALSE(manager_.CheckResponse(res));
EXPECT_FALSE(manager_.CheckResponse(res.get()));
EXPECT_TRUE(response_ == NULL);
EXPECT_FALSE(success_);
EXPECT_FALSE(failure_);
EXPECT_TRUE(timeout_);
delete res;
}
// Regression test for specific crash where we receive a response with the
// same id as a request that doesn't have an underlying StunMessage yet.
TEST_F(StunRequestTest, TestNoEmptyRequest) {
StunRequestThunker* request = new StunRequestThunker(this);
StunRequestThunker* request = new StunRequestThunker(manager_, this);
manager_.SendDelayed(request, 100);
StunMessage dummy_req;
dummy_req.SetTransactionID(request->id());
StunMessage* res = CreateStunMessage(STUN_BINDING_RESPONSE, &dummy_req);
std::unique_ptr<StunMessage> res =
CreateStunMessage(STUN_BINDING_RESPONSE, &dummy_req);
EXPECT_TRUE(manager_.CheckResponse(res));
EXPECT_TRUE(manager_.CheckResponse(res.get()));
EXPECT_TRUE(response_ == res);
EXPECT_TRUE(response_ == res.get());
EXPECT_TRUE(success_);
EXPECT_FALSE(failure_);
EXPECT_FALSE(timeout_);
delete res;
}
// If the response contains an attribute in the "comprehension required" range
// which is not recognized, the transaction should be considered a failure and
// the response should be ignored.
TEST_F(StunRequestTest, TestUnrecognizedComprehensionRequiredAttribute) {
StunMessage* req = CreateStunMessage(STUN_BINDING_REQUEST, NULL);
auto* request = new StunRequestThunker(manager_, STUN_BINDING_REQUEST, this);
std::unique_ptr<StunMessage> res =
request->CreateResponseMessage(STUN_BINDING_ERROR_RESPONSE);
manager_.Send(new StunRequestThunker(req, this));
StunMessage* res = CreateStunMessage(STUN_BINDING_ERROR_RESPONSE, req);
manager_.Send(request);
res->AddAttribute(StunAttribute::CreateUInt32(0x7777));
EXPECT_FALSE(manager_.CheckResponse(res));
EXPECT_FALSE(manager_.CheckResponse(res.get()));
EXPECT_EQ(nullptr, response_);
EXPECT_FALSE(success_);
EXPECT_FALSE(failure_);
EXPECT_FALSE(timeout_);
delete res;
}
} // namespace cricket

View File

@ -1361,7 +1361,8 @@ void TurnPort::MaybeAddTurnLoggingId(StunMessage* msg) {
}
TurnAllocateRequest::TurnAllocateRequest(TurnPort* port)
: StunRequest(new TurnMessage()), port_(port) {}
: StunRequest(port->request_manager(), std::make_unique<TurnMessage>()),
port_(port) {}
void TurnAllocateRequest::Prepare(StunMessage* request) {
// Create the request as indicated in RFC 5766, Section 6.1.
@ -1549,7 +1550,9 @@ void TurnAllocateRequest::OnTryAlternate(StunMessage* response, int code) {
}
TurnRefreshRequest::TurnRefreshRequest(TurnPort* port)
: StunRequest(new TurnMessage()), port_(port), lifetime_(-1) {}
: StunRequest(port->request_manager(), std::make_unique<TurnMessage>()),
port_(port),
lifetime_(-1) {}
void TurnRefreshRequest::Prepare(StunMessage* request) {
// Create the request as indicated in RFC 5766, Section 7.1.
@ -1630,7 +1633,7 @@ TurnCreatePermissionRequest::TurnCreatePermissionRequest(
TurnEntry* entry,
const rtc::SocketAddress& ext_addr,
const std::string& remote_ufrag)
: StunRequest(new TurnMessage()),
: StunRequest(port->request_manager(), std::make_unique<TurnMessage>()),
port_(port),
entry_(entry),
ext_addr_(ext_addr),
@ -1703,7 +1706,7 @@ TurnChannelBindRequest::TurnChannelBindRequest(
TurnEntry* entry,
int channel_id,
const rtc::SocketAddress& ext_addr)
: StunRequest(new TurnMessage()),
: StunRequest(port->request_manager(), std::make_unique<TurnMessage>()),
port_(port),
entry_(entry),
channel_id_(channel_id),

View File

@ -171,6 +171,7 @@ class TurnPort : public Port {
void OnAllocateMismatch();
rtc::AsyncPacketSocket* socket() const { return socket_; }
StunRequestManager& request_manager() { return request_manager_; }
// Signal with resolved server address.
// Parameters are port, server address and resolved server address.
@ -188,7 +189,11 @@ class TurnPort : public Port {
sigslot::signal2<TurnPort*, int> SignalTurnRefreshResult;
sigslot::signal3<TurnPort*, const rtc::SocketAddress&, int>
SignalCreatePermissionResult;
void FlushRequests(int msg_type) { request_manager_.Flush(msg_type); }
void FlushRequestsForTest(int msg_type) {
request_manager_.FlushForTest(msg_type);
}
bool HasRequests() { return !request_manager_.empty(); }
void set_credentials(const RelayCredentials& credentials) {
credentials_ = credentials;

View File

@ -1241,10 +1241,10 @@ TEST_F(TurnPortTest, TestRefreshRequestGetsErrorResponse) {
// This sends out the first RefreshRequest with correct credentials.
// When this succeeds, it will schedule a new RefreshRequest with the bad
// credential.
turn_port_->FlushRequests(TURN_REFRESH_REQUEST);
turn_port_->FlushRequestsForTest(TURN_REFRESH_REQUEST);
EXPECT_TRUE_SIMULATED_WAIT(turn_refresh_success_, kSimulatedRtt, fake_clock_);
// Flush it again, it will receive a bad response.
turn_port_->FlushRequests(TURN_REFRESH_REQUEST);
turn_port_->FlushRequestsForTest(TURN_REFRESH_REQUEST);
EXPECT_TRUE_SIMULATED_WAIT(!turn_refresh_success_, kSimulatedRtt,
fake_clock_);
EXPECT_FALSE(turn_port_->connected());
@ -1458,11 +1458,11 @@ TEST_F(TurnPortTest, TestRefreshCreatePermissionRequest) {
// another request with bad_ufrag and bad_pwd.
RelayCredentials bad_credentials("bad_user", "bad_pwd");
turn_port_->set_credentials(bad_credentials);
turn_port_->FlushRequests(kAllRequests);
turn_port_->FlushRequestsForTest(kAllRequests);
EXPECT_TRUE_SIMULATED_WAIT(turn_create_permission_success_, kSimulatedRtt,
fake_clock_);
// Flush the requests again; the create-permission-request will fail.
turn_port_->FlushRequests(kAllRequests);
turn_port_->FlushRequestsForTest(kAllRequests);
EXPECT_TRUE_SIMULATED_WAIT(!turn_create_permission_success_, kSimulatedRtt,
fake_clock_);
EXPECT_TRUE(CheckConnectionFailedAndPruned(conn));