Support DNS resolution matching a specified IP family.

The input SocketAddress for STUN host lookup is constructed with just
the hostname, so the family is AF_UNSPEC. So added an overload with a
target family to distinguish this from the family of the input addr.

Bug: webrtc:14319, webrtc:14131
Change-Id: Ia5ac5aa2e894e0c4dfb4417e3e8a76a6cec3ea71
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/270624
Reviewed-by: Tomas Gunnarsson <tommi@webrtc.org>
Commit-Queue: Sameer Vijaykar <samvi@google.com>
Reviewed-by: Jonas Oreland <jonaso@webrtc.org>
Reviewed-by: Jonas Oreland <jonaso@google.com>
Cr-Commit-Position: refs/heads/main@{#37750}
This commit is contained in:
Sameer Vijaykar
2022-08-11 11:52:57 +02:00
committed by WebRTC LUCI CQ
parent 12053ec64a
commit b787e26369
11 changed files with 73 additions and 8 deletions

View File

@ -345,6 +345,7 @@ rtc_source_set("async_dns_resolver") {
visibility = [ "*" ] visibility = [ "*" ]
sources = [ "async_dns_resolver.h" ] sources = [ "async_dns_resolver.h" ]
deps = [ deps = [
"../rtc_base:checks",
"../rtc_base:socket_address", "../rtc_base:socket_address",
"../rtc_base/system:rtc_export", "../rtc_base/system:rtc_export",
] ]

View File

@ -14,6 +14,7 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "rtc_base/checks.h"
#include "rtc_base/socket_address.h" #include "rtc_base/socket_address.h"
#include "rtc_base/system/rtc_export.h" #include "rtc_base/system/rtc_export.h"
@ -63,6 +64,10 @@ class RTC_EXPORT AsyncDnsResolverInterface {
// Start address resolution of the hostname in `addr`. // Start address resolution of the hostname in `addr`.
virtual void Start(const rtc::SocketAddress& addr, virtual void Start(const rtc::SocketAddress& addr,
std::function<void()> callback) = 0; std::function<void()> callback) = 0;
// Start address resolution of the hostname in `addr` matching `family`.
virtual void Start(const rtc::SocketAddress& addr,
int family,
std::function<void()> callback) = 0;
virtual const AsyncDnsResolverResult& result() const = 0; virtual const AsyncDnsResolverResult& result() const = 0;
}; };
@ -79,6 +84,14 @@ class AsyncDnsResolverFactoryInterface {
virtual std::unique_ptr<webrtc::AsyncDnsResolverInterface> CreateAndResolve( virtual std::unique_ptr<webrtc::AsyncDnsResolverInterface> CreateAndResolve(
const rtc::SocketAddress& addr, const rtc::SocketAddress& addr,
std::function<void()> callback) = 0; std::function<void()> callback) = 0;
// Creates an AsyncDnsResolver and starts resolving the name to an address
// matching the specified family. The callback will be called when resolution
// is finished. The callback will be called on the sequence that the caller
// runs on.
virtual std::unique_ptr<webrtc::AsyncDnsResolverInterface> CreateAndResolve(
const rtc::SocketAddress& addr,
int family,
std::function<void()> callback) = 0;
// Creates an AsyncDnsResolver and does not start it. // Creates an AsyncDnsResolver and does not start it.
// For backwards compatibility, will be deprecated and removed. // For backwards compatibility, will be deprecated and removed.
// One has to do a separate Start() call on the // One has to do a separate Start() call on the

View File

@ -34,6 +34,10 @@ class MockAsyncDnsResolver : public AsyncDnsResolverInterface {
Start, Start,
(const rtc::SocketAddress&, std::function<void()>), (const rtc::SocketAddress&, std::function<void()>),
(override)); (override));
MOCK_METHOD(void,
Start,
(const rtc::SocketAddress&, int family, std::function<void()>),
(override));
MOCK_METHOD(AsyncDnsResolverResult&, result, (), (const, override)); MOCK_METHOD(AsyncDnsResolverResult&, result, (), (const, override));
}; };
@ -43,6 +47,10 @@ class MockAsyncDnsResolverFactory : public AsyncDnsResolverFactoryInterface {
CreateAndResolve, CreateAndResolve,
(const rtc::SocketAddress&, std::function<void()>), (const rtc::SocketAddress&, std::function<void()>),
(override)); (override));
MOCK_METHOD(std::unique_ptr<webrtc::AsyncDnsResolverInterface>,
CreateAndResolve,
(const rtc::SocketAddress&, int, std::function<void()>),
(override));
MOCK_METHOD(std::unique_ptr<webrtc::AsyncDnsResolverInterface>, MOCK_METHOD(std::unique_ptr<webrtc::AsyncDnsResolverInterface>,
Create, Create,
(), (),

View File

@ -13,6 +13,7 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <utility>
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "api/async_dns_resolver.h" #include "api/async_dns_resolver.h"
@ -68,14 +69,18 @@ class RTC_EXPORT WrappingAsyncDnsResolver : public AsyncDnsResolverInterface,
void Start(const rtc::SocketAddress& addr, void Start(const rtc::SocketAddress& addr,
std::function<void()> callback) override { std::function<void()> callback) override {
RTC_DCHECK_RUN_ON(&sequence_checker_); RTC_DCHECK_RUN_ON(&sequence_checker_);
RTC_DCHECK_EQ(State::kNotStarted, state_); PrepareToResolve(std::move(callback));
state_ = State::kStarted;
callback_ = callback;
wrapped_->SignalDone.connect(this,
&WrappingAsyncDnsResolver::OnResolveResult);
wrapped_->Start(addr); wrapped_->Start(addr);
} }
void Start(const rtc::SocketAddress& addr,
int family,
std::function<void()> callback) override {
RTC_DCHECK_RUN_ON(&sequence_checker_);
PrepareToResolve(std::move(callback));
wrapped_->Start(addr, family);
}
const AsyncDnsResolverResult& result() const override { const AsyncDnsResolverResult& result() const override {
RTC_DCHECK_RUN_ON(&sequence_checker_); RTC_DCHECK_RUN_ON(&sequence_checker_);
RTC_DCHECK_EQ(State::kResolved, state_); RTC_DCHECK_EQ(State::kResolved, state_);
@ -92,6 +97,15 @@ class RTC_EXPORT WrappingAsyncDnsResolver : public AsyncDnsResolverInterface,
return wrapped_.get(); return wrapped_.get();
} }
void PrepareToResolve(std::function<void()> callback) {
RTC_DCHECK_RUN_ON(&sequence_checker_);
RTC_DCHECK_EQ(State::kNotStarted, state_);
state_ = State::kStarted;
callback_ = std::move(callback);
wrapped_->SignalDone.connect(this,
&WrappingAsyncDnsResolver::OnResolveResult);
}
void OnResolveResult(rtc::AsyncResolverInterface* ref) { void OnResolveResult(rtc::AsyncResolverInterface* ref) {
RTC_DCHECK_RUN_ON(&sequence_checker_); RTC_DCHECK_RUN_ON(&sequence_checker_);
RTC_DCHECK(state_ == State::kStarted); RTC_DCHECK(state_ == State::kStarted);

View File

@ -36,7 +36,17 @@ WrappingAsyncDnsResolverFactory::CreateAndResolve(
const rtc::SocketAddress& addr, const rtc::SocketAddress& addr,
std::function<void()> callback) { std::function<void()> callback) {
std::unique_ptr<webrtc::AsyncDnsResolverInterface> resolver = Create(); std::unique_ptr<webrtc::AsyncDnsResolverInterface> resolver = Create();
resolver->Start(addr, callback); resolver->Start(addr, std::move(callback));
return resolver;
}
std::unique_ptr<webrtc::AsyncDnsResolverInterface>
WrappingAsyncDnsResolverFactory::CreateAndResolve(
const rtc::SocketAddress& addr,
int family,
std::function<void()> callback) {
std::unique_ptr<webrtc::AsyncDnsResolverInterface> resolver = Create();
resolver->Start(addr, family, std::move(callback));
return resolver; return resolver;
} }

View File

@ -45,6 +45,11 @@ class WrappingAsyncDnsResolverFactory final
const rtc::SocketAddress& addr, const rtc::SocketAddress& addr,
std::function<void()> callback) override; std::function<void()> callback) override;
std::unique_ptr<webrtc::AsyncDnsResolverInterface> CreateAndResolve(
const rtc::SocketAddress& addr,
int family,
std::function<void()> callback) override;
std::unique_ptr<webrtc::AsyncDnsResolverInterface> Create() override; std::unique_ptr<webrtc::AsyncDnsResolverInterface> Create() override;
private: private:

View File

@ -30,6 +30,7 @@ class MockAsyncResolver : public AsyncResolverInterface {
~MockAsyncResolver() = default; ~MockAsyncResolver() = default;
MOCK_METHOD(void, Start, (const rtc::SocketAddress&), (override)); MOCK_METHOD(void, Start, (const rtc::SocketAddress&), (override));
MOCK_METHOD(void, Start, (const rtc::SocketAddress&, int family), (override));
MOCK_METHOD(bool, MOCK_METHOD(bool,
GetResolvedAddress, GetResolvedAddress,
(int family, SocketAddress* addr), (int family, SocketAddress* addr),

View File

@ -845,6 +845,7 @@ rtc_library("async_resolver_interface") {
"async_resolver_interface.h", "async_resolver_interface.h",
] ]
deps = [ deps = [
":checks",
":socket_address", ":socket_address",
"system:rtc_export", "system:rtc_export",
"third_party/sigslot", "third_party/sigslot",

View File

@ -145,14 +145,18 @@ void RunResolution(void* obj) {
} }
void AsyncResolver::Start(const SocketAddress& addr) { void AsyncResolver::Start(const SocketAddress& addr) {
Start(addr, addr.family());
}
void AsyncResolver::Start(const SocketAddress& addr, int family) {
RTC_DCHECK_RUN_ON(&sequence_checker_); RTC_DCHECK_RUN_ON(&sequence_checker_);
RTC_DCHECK(!destroy_called_); RTC_DCHECK(!destroy_called_);
addr_ = addr; addr_ = addr;
auto thread_function = auto thread_function =
[this, addr, caller_task_queue = webrtc::TaskQueueBase::Current(), [this, addr, family, caller_task_queue = webrtc::TaskQueueBase::Current(),
state = state_] { state = state_] {
std::vector<IPAddress> addresses; std::vector<IPAddress> addresses;
int error = ResolveHostname(addr.hostname(), addr.family(), &addresses); int error = ResolveHostname(addr.hostname(), family, &addresses);
webrtc::MutexLock lock(&state->mutex); webrtc::MutexLock lock(&state->mutex);
if (state->status == State::Status::kLive) { if (state->status == State::Status::kLive) {
caller_task_queue->PostTask( caller_task_queue->PostTask(

View File

@ -45,6 +45,7 @@ class RTC_EXPORT AsyncResolver : public AsyncResolverInterface {
~AsyncResolver() override; ~AsyncResolver() override;
void Start(const SocketAddress& addr) override; void Start(const SocketAddress& addr) override;
void Start(const SocketAddress& addr, int family) override;
bool GetResolvedAddress(int family, SocketAddress* addr) const override; bool GetResolvedAddress(int family, SocketAddress* addr) const override;
int GetError() const override; int GetError() const override;
void Destroy(bool wait) override; void Destroy(bool wait) override;

View File

@ -11,6 +11,7 @@
#ifndef RTC_BASE_ASYNC_RESOLVER_INTERFACE_H_ #ifndef RTC_BASE_ASYNC_RESOLVER_INTERFACE_H_
#define RTC_BASE_ASYNC_RESOLVER_INTERFACE_H_ #define RTC_BASE_ASYNC_RESOLVER_INTERFACE_H_
#include "rtc_base/checks.h"
#include "rtc_base/socket_address.h" #include "rtc_base/socket_address.h"
#include "rtc_base/system/rtc_export.h" #include "rtc_base/system/rtc_export.h"
#include "rtc_base/third_party/sigslot/sigslot.h" #include "rtc_base/third_party/sigslot/sigslot.h"
@ -25,6 +26,12 @@ class RTC_EXPORT AsyncResolverInterface {
// Start address resolution of the hostname in `addr`. // Start address resolution of the hostname in `addr`.
virtual void Start(const SocketAddress& addr) = 0; virtual void Start(const SocketAddress& addr) = 0;
// Start address resolution of the hostname in `addr` matching `family`.
virtual void Start(const SocketAddress& addr, int family) {
// TODO(webrtc:14319) make pure virtual when all subclasses have been
// updated.
RTC_DCHECK_NOTREACHED();
}
// Returns true iff the address from `Start` was successfully resolved. // Returns true iff the address from `Start` was successfully resolved.
// If the address was successfully resolved, sets `addr` to a copy of the // If the address was successfully resolved, sets `addr` to a copy of the
// address from `Start` with the IP address set to the top most resolved // address from `Start` with the IP address set to the top most resolved