Replacing SSLIdentity* with scoped_refptr<RTCCertificate> in TransportChannel layer.

BUG=webrtc:4927
R=tommi@webrtc.org, torbjorng@webrtc.org

Review URL: https://codereview.webrtc.org/1304043008 .

Cr-Commit-Position: refs/heads/master@{#9885}
This commit is contained in:
Henrik Boström
2015-09-08 12:11:54 +02:00
parent 8006f07592
commit f3ecdb981c
14 changed files with 199 additions and 216 deletions

View File

@ -714,7 +714,7 @@ void StatsCollector::ExtractSessionInfo() {
transport = session_->GetTransport(transport_iter.second.content_name); transport = session_->GetTransport(transport_iter.second.content_name);
rtc::scoped_ptr<rtc::SSLCertificate> cert; rtc::scoped_ptr<rtc::SSLCertificate> cert;
if (transport && transport->GetRemoteCertificate(cert.accept())) { if (transport && transport->GetRemoteSSLCertificate(cert.accept())) {
StatsReport* r = AddCertificateReports(cert.get()); StatsReport* r = AddCertificateReports(cert.get());
if (r) if (r)
remote_cert_report_id = r->id(); remote_cert_report_id = r->id();

View File

@ -673,7 +673,7 @@ class StatsCollectorTest : public testing::Test {
static_cast<cricket::FakeTransportChannel*>( static_cast<cricket::FakeTransportChannel*>(
transport->CreateChannel(channel_stats.component)); transport->CreateChannel(channel_stats.component));
EXPECT_FALSE(channel == NULL); EXPECT_FALSE(channel == NULL);
channel->SetRemoteCertificate(remote_cert_copy.get()); channel->SetRemoteSSLCertificate(remote_cert_copy.get());
// Configure MockWebRtcSession // Configure MockWebRtcSession
EXPECT_CALL(session_, GetTransport(transport_stats.content_name)) EXPECT_CALL(session_, GetTransport(transport_stats.content_name))

View File

@ -92,9 +92,7 @@ class DtlsTransport : public Base {
certificate_ = nullptr; certificate_ = nullptr;
} }
// TODO(hbos): SetLocalCertificate if (!channel->SetLocalCertificate(certificate_)) {
if (!channel->SetLocalIdentity(
certificate_ ? certificate_->identity() : nullptr)) {
return BadTransportDescription("Failed to set local identity.", return BadTransportDescription("Failed to set local identity.",
error_desc); error_desc);
} }

View File

@ -95,7 +95,6 @@ DtlsTransportChannelWrapper::DtlsTransportChannelWrapper(
channel_(channel), channel_(channel),
downward_(NULL), downward_(NULL),
dtls_state_(STATE_NONE), dtls_state_(STATE_NONE),
local_identity_(NULL),
ssl_role_(rtc::SSL_CLIENT), ssl_role_(rtc::SSL_CLIENT),
ssl_max_version_(rtc::SSL_PROTOCOL_DTLS_10) { ssl_max_version_(rtc::SSL_PROTOCOL_DTLS_10) {
channel_->SignalReadableState.connect(this, channel_->SignalReadableState.connect(this,
@ -133,10 +132,10 @@ void DtlsTransportChannelWrapper::Connect() {
channel_->Connect(); channel_->Connect();
} }
bool DtlsTransportChannelWrapper::SetLocalIdentity( bool DtlsTransportChannelWrapper::SetLocalCertificate(
rtc::SSLIdentity* identity) { const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) {
if (dtls_state_ != STATE_NONE) { if (dtls_state_ != STATE_NONE) {
if (identity == local_identity_) { if (certificate == local_certificate_) {
// This may happen during renegotiation. // This may happen during renegotiation.
LOG_J(LS_INFO, this) << "Ignoring identical DTLS identity"; LOG_J(LS_INFO, this) << "Ignoring identical DTLS identity";
return true; return true;
@ -146,8 +145,8 @@ bool DtlsTransportChannelWrapper::SetLocalIdentity(
} }
} }
if (identity) { if (certificate) {
local_identity_ = identity; local_certificate_ = certificate;
dtls_state_ = STATE_OFFERED; dtls_state_ = STATE_OFFERED;
} else { } else {
LOG_J(LS_INFO, this) << "NULL DTLS identity supplied. Not doing DTLS"; LOG_J(LS_INFO, this) << "NULL DTLS identity supplied. Not doing DTLS";
@ -156,13 +155,9 @@ bool DtlsTransportChannelWrapper::SetLocalIdentity(
return true; return true;
} }
bool DtlsTransportChannelWrapper::GetLocalIdentity( rtc::scoped_refptr<rtc::RTCCertificate>
rtc::SSLIdentity** identity) const { DtlsTransportChannelWrapper::GetLocalCertificate() const {
if (!local_identity_) return local_certificate_;
return false;
*identity = local_identity_->GetReference();
return true;
} }
bool DtlsTransportChannelWrapper::SetSslMaxProtocolVersion( bool DtlsTransportChannelWrapper::SetSslMaxProtocolVersion(
@ -245,7 +240,7 @@ bool DtlsTransportChannelWrapper::SetRemoteFingerprint(
return true; return true;
} }
bool DtlsTransportChannelWrapper::GetRemoteCertificate( bool DtlsTransportChannelWrapper::GetRemoteSSLCertificate(
rtc::SSLCertificate** cert) const { rtc::SSLCertificate** cert) const {
if (!dtls_) if (!dtls_)
return false; return false;
@ -265,7 +260,7 @@ bool DtlsTransportChannelWrapper::SetupDtls() {
downward_ = downward; downward_ = downward;
dtls_->SetIdentity(local_identity_->GetReference()); dtls_->SetIdentity(local_certificate_->identity()->GetReference());
dtls_->SetMode(rtc::SSL_MODE_DTLS); dtls_->SetMode(rtc::SSL_MODE_DTLS);
dtls_->SetMaxProtocolVersion(ssl_max_version_); dtls_->SetMaxProtocolVersion(ssl_max_version_);
dtls_->SetServerRole(ssl_role_); dtls_->SetServerRole(ssl_role_);

View File

@ -33,12 +33,12 @@ class StreamInterfaceChannel : public rtc::StreamInterface {
bool OnPacketReceived(const char* data, size_t size); bool OnPacketReceived(const char* data, size_t size);
// Implementations of StreamInterface // Implementations of StreamInterface
virtual rtc::StreamState GetState() const { return state_; } rtc::StreamState GetState() const override { return state_; }
virtual void Close() { state_ = rtc::SS_CLOSED; } void Close() override { state_ = rtc::SS_CLOSED; }
virtual rtc::StreamResult Read(void* buffer, size_t buffer_len, rtc::StreamResult Read(void* buffer, size_t buffer_len,
size_t* read, int* error); size_t* read, int* error) override;
virtual rtc::StreamResult Write(const void* data, size_t data_len, rtc::StreamResult Write(const void* data, size_t data_len,
size_t* written, int* error); size_t* written, int* error) override;
private: private:
TransportChannel* channel_; // owned by DtlsTransportChannelWrapper TransportChannel* channel_; // owned by DtlsTransportChannelWrapper
@ -91,41 +91,42 @@ class DtlsTransportChannelWrapper : public TransportChannelImpl {
// channel -- the TransportChannel we are wrapping // channel -- the TransportChannel we are wrapping
DtlsTransportChannelWrapper(Transport* transport, DtlsTransportChannelWrapper(Transport* transport,
TransportChannelImpl* channel); TransportChannelImpl* channel);
virtual ~DtlsTransportChannelWrapper(); ~DtlsTransportChannelWrapper() override;
virtual void SetIceRole(IceRole role) { void SetIceRole(IceRole role) override {
channel_->SetIceRole(role); channel_->SetIceRole(role);
} }
virtual IceRole GetIceRole() const { IceRole GetIceRole() const override {
return channel_->GetIceRole(); return channel_->GetIceRole();
} }
virtual bool SetLocalIdentity(rtc::SSLIdentity *identity); bool SetLocalCertificate(
virtual bool GetLocalIdentity(rtc::SSLIdentity** identity) const; const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) override;
rtc::scoped_refptr<rtc::RTCCertificate> GetLocalCertificate() const override;
virtual bool SetRemoteFingerprint(const std::string& digest_alg, bool SetRemoteFingerprint(const std::string& digest_alg,
const uint8* digest, const uint8* digest,
size_t digest_len); size_t digest_len) override;
virtual bool IsDtlsActive() const { return dtls_state_ != STATE_NONE; } bool IsDtlsActive() const override { return dtls_state_ != STATE_NONE; }
// Called to send a packet (via DTLS, if turned on). // Called to send a packet (via DTLS, if turned on).
virtual int SendPacket(const char* data, size_t size, int SendPacket(const char* data, size_t size,
const rtc::PacketOptions& options, const rtc::PacketOptions& options,
int flags); int flags) override;
// TransportChannel calls that we forward to the wrapped transport. // TransportChannel calls that we forward to the wrapped transport.
virtual int SetOption(rtc::Socket::Option opt, int value) { int SetOption(rtc::Socket::Option opt, int value) override {
return channel_->SetOption(opt, value); return channel_->SetOption(opt, value);
} }
virtual bool GetOption(rtc::Socket::Option opt, int* value) { bool GetOption(rtc::Socket::Option opt, int* value) override {
return channel_->GetOption(opt, value); return channel_->GetOption(opt, value);
} }
virtual int GetError() { int GetError() override {
return channel_->GetError(); return channel_->GetError();
} }
virtual bool GetStats(ConnectionInfos* infos) { bool GetStats(ConnectionInfos* infos) override {
return channel_->GetStats(infos); return channel_->GetStats(infos);
} }
virtual const std::string SessionId() const { const std::string SessionId() const override {
return channel_->SessionId(); return channel_->SessionId();
} }
@ -134,31 +135,31 @@ class DtlsTransportChannelWrapper : public TransportChannelImpl {
// Set up the ciphers to use for DTLS-SRTP. If this method is not called // Set up the ciphers to use for DTLS-SRTP. If this method is not called
// before DTLS starts, or |ciphers| is empty, SRTP keys won't be negotiated. // before DTLS starts, or |ciphers| is empty, SRTP keys won't be negotiated.
// This method should be called before SetupDtls. // This method should be called before SetupDtls.
virtual bool SetSrtpCiphers(const std::vector<std::string>& ciphers); bool SetSrtpCiphers(const std::vector<std::string>& ciphers) override;
// Find out which DTLS-SRTP cipher was negotiated // Find out which DTLS-SRTP cipher was negotiated
virtual bool GetSrtpCipher(std::string* cipher); bool GetSrtpCipher(std::string* cipher) override;
virtual bool GetSslRole(rtc::SSLRole* role) const; bool GetSslRole(rtc::SSLRole* role) const override;
virtual bool SetSslRole(rtc::SSLRole role); bool SetSslRole(rtc::SSLRole role) override;
// Find out which DTLS cipher was negotiated // Find out which DTLS cipher was negotiated
virtual bool GetSslCipher(std::string* cipher); bool GetSslCipher(std::string* cipher) override;
// Once DTLS has been established, this method retrieves the certificate in // Once DTLS has been established, this method retrieves the certificate in
// use by the remote peer, for use in external identity verification. // use by the remote peer, for use in external identity verification.
virtual bool GetRemoteCertificate(rtc::SSLCertificate** cert) const; bool GetRemoteSSLCertificate(rtc::SSLCertificate** cert) const override;
// Once DTLS has established (i.e., this channel is writable), this method // Once DTLS has established (i.e., this channel is writable), this method
// extracts the keys negotiated during the DTLS handshake, for use in external // extracts the keys negotiated during the DTLS handshake, for use in external
// encryption. DTLS-SRTP uses this to extract the needed SRTP keys. // encryption. DTLS-SRTP uses this to extract the needed SRTP keys.
// See the SSLStreamAdapter documentation for info on the specific parameters. // See the SSLStreamAdapter documentation for info on the specific parameters.
virtual bool ExportKeyingMaterial(const std::string& label, bool ExportKeyingMaterial(const std::string& label,
const uint8* context, const uint8* context,
size_t context_len, size_t context_len,
bool use_context, bool use_context,
uint8* result, uint8* result,
size_t result_len) { size_t result_len) override {
return (dtls_.get()) ? dtls_->ExportKeyingMaterial(label, context, return (dtls_.get()) ? dtls_->ExportKeyingMaterial(label, context,
context_len, context_len,
use_context, use_context,
@ -167,34 +168,34 @@ class DtlsTransportChannelWrapper : public TransportChannelImpl {
} }
// TransportChannelImpl calls. // TransportChannelImpl calls.
virtual Transport* GetTransport() { Transport* GetTransport() override {
return transport_; return transport_;
} }
virtual TransportChannelState GetState() const { TransportChannelState GetState() const override {
return channel_->GetState(); return channel_->GetState();
} }
virtual void SetIceTiebreaker(uint64 tiebreaker) { void SetIceTiebreaker(uint64 tiebreaker) override {
channel_->SetIceTiebreaker(tiebreaker); channel_->SetIceTiebreaker(tiebreaker);
} }
virtual void SetIceCredentials(const std::string& ice_ufrag, void SetIceCredentials(const std::string& ice_ufrag,
const std::string& ice_pwd) { const std::string& ice_pwd) override {
channel_->SetIceCredentials(ice_ufrag, ice_pwd); channel_->SetIceCredentials(ice_ufrag, ice_pwd);
} }
virtual void SetRemoteIceCredentials(const std::string& ice_ufrag, void SetRemoteIceCredentials(const std::string& ice_ufrag,
const std::string& ice_pwd) { const std::string& ice_pwd) override {
channel_->SetRemoteIceCredentials(ice_ufrag, ice_pwd); channel_->SetRemoteIceCredentials(ice_ufrag, ice_pwd);
} }
virtual void SetRemoteIceMode(IceMode mode) { void SetRemoteIceMode(IceMode mode) override {
channel_->SetRemoteIceMode(mode); channel_->SetRemoteIceMode(mode);
} }
virtual void Connect(); void Connect() override;
virtual void OnSignalingReady() { void OnSignalingReady() override {
channel_->OnSignalingReady(); channel_->OnSignalingReady();
} }
virtual void OnCandidate(const Candidate& candidate) { void OnCandidate(const Candidate& candidate) override {
channel_->OnCandidate(candidate); channel_->OnCandidate(candidate);
} }
@ -230,7 +231,7 @@ class DtlsTransportChannelWrapper : public TransportChannelImpl {
StreamInterfaceChannel* downward_; // Wrapper for channel_, owned by dtls_. StreamInterfaceChannel* downward_; // Wrapper for channel_, owned by dtls_.
std::vector<std::string> srtp_ciphers_; // SRTP ciphers to use with DTLS. std::vector<std::string> srtp_ciphers_; // SRTP ciphers to use with DTLS.
State dtls_state_; State dtls_state_;
rtc::SSLIdentity* local_identity_; rtc::scoped_refptr<rtc::RTCCertificate> local_certificate_;
rtc::SSLRole ssl_role_; rtc::SSLRole ssl_role_;
rtc::SSLProtocolVersion ssl_max_version_; rtc::SSLProtocolVersion ssl_max_version_;
rtc::Buffer remote_fingerprint_value_; rtc::Buffer remote_fingerprint_value_;

View File

@ -842,10 +842,10 @@ TEST_F(DtlsTransportChannelTest, TestCertificatesBeforeConnect) {
ASSERT_NE(certificate1->ssl_certificate().ToPEMString(), ASSERT_NE(certificate1->ssl_certificate().ToPEMString(),
certificate2->ssl_certificate().ToPEMString()); certificate2->ssl_certificate().ToPEMString());
ASSERT_FALSE( ASSERT_FALSE(
client1_.transport()->GetRemoteCertificate(remote_cert1.accept())); client1_.transport()->GetRemoteSSLCertificate(remote_cert1.accept()));
ASSERT_FALSE(remote_cert1 != NULL); ASSERT_FALSE(remote_cert1 != NULL);
ASSERT_FALSE( ASSERT_FALSE(
client2_.transport()->GetRemoteCertificate(remote_cert2.accept())); client2_.transport()->GetRemoteSSLCertificate(remote_cert2.accept()));
ASSERT_FALSE(remote_cert2 != NULL); ASSERT_FALSE(remote_cert2 != NULL);
} }
@ -868,11 +868,11 @@ TEST_F(DtlsTransportChannelTest, TestCertificatesAfterConnect) {
// Each side's remote certificate is the other side's local certificate. // Each side's remote certificate is the other side's local certificate.
ASSERT_TRUE( ASSERT_TRUE(
client1_.transport()->GetRemoteCertificate(remote_cert1.accept())); client1_.transport()->GetRemoteSSLCertificate(remote_cert1.accept()));
ASSERT_EQ(remote_cert1->ToPEMString(), ASSERT_EQ(remote_cert1->ToPEMString(),
certificate2->ssl_certificate().ToPEMString()); certificate2->ssl_certificate().ToPEMString());
ASSERT_TRUE( ASSERT_TRUE(
client2_.transport()->GetRemoteCertificate(remote_cert2.accept())); client2_.transport()->GetRemoteSSLCertificate(remote_cert2.accept()));
ASSERT_EQ(remote_cert2->ToPEMString(), ASSERT_EQ(remote_cert2->ToPEMString(),
certificate1->ssl_certificate().ToPEMString()); certificate1->ssl_certificate().ToPEMString());
} }

View File

@ -47,15 +47,14 @@ class FakeTransportChannel : public TransportChannelImpl,
int component) int component)
: TransportChannelImpl(content_name, component), : TransportChannelImpl(content_name, component),
transport_(transport), transport_(transport),
dest_(NULL), dest_(nullptr),
state_(STATE_INIT), state_(STATE_INIT),
async_(false), async_(false),
identity_(NULL),
do_dtls_(false), do_dtls_(false),
role_(ICEROLE_UNKNOWN), role_(ICEROLE_UNKNOWN),
tiebreaker_(0), tiebreaker_(0),
remote_ice_mode_(ICEMODE_FULL), remote_ice_mode_(ICEMODE_FULL),
dtls_fingerprint_("", NULL, 0), dtls_fingerprint_("", nullptr, 0),
ssl_role_(rtc::SSL_CLIENT), ssl_role_(rtc::SSL_CLIENT),
connection_count_(0) { connection_count_(0) {
} }
@ -77,11 +76,11 @@ class FakeTransportChannel : public TransportChannelImpl,
async_ = async; async_ = async;
} }
virtual Transport* GetTransport() { Transport* GetTransport() override {
return transport_; return transport_;
} }
virtual TransportChannelState GetState() const { TransportChannelState GetState() const override {
if (connection_count_ == 0) { if (connection_count_ == 0) {
return TransportChannelState::STATE_FAILED; return TransportChannelState::STATE_FAILED;
} }
@ -93,36 +92,38 @@ class FakeTransportChannel : public TransportChannelImpl,
return TransportChannelState::STATE_FAILED; return TransportChannelState::STATE_FAILED;
} }
virtual void SetIceRole(IceRole role) { role_ = role; } void SetIceRole(IceRole role) override { role_ = role; }
virtual IceRole GetIceRole() const { return role_; } IceRole GetIceRole() const override { return role_; }
virtual void SetIceTiebreaker(uint64 tiebreaker) { tiebreaker_ = tiebreaker; } void SetIceTiebreaker(uint64 tiebreaker) override {
virtual void SetIceCredentials(const std::string& ice_ufrag, tiebreaker_ = tiebreaker;
const std::string& ice_pwd) { }
void SetIceCredentials(const std::string& ice_ufrag,
const std::string& ice_pwd) override {
ice_ufrag_ = ice_ufrag; ice_ufrag_ = ice_ufrag;
ice_pwd_ = ice_pwd; ice_pwd_ = ice_pwd;
} }
virtual void SetRemoteIceCredentials(const std::string& ice_ufrag, void SetRemoteIceCredentials(const std::string& ice_ufrag,
const std::string& ice_pwd) { const std::string& ice_pwd) override {
remote_ice_ufrag_ = ice_ufrag; remote_ice_ufrag_ = ice_ufrag;
remote_ice_pwd_ = ice_pwd; remote_ice_pwd_ = ice_pwd;
} }
virtual void SetRemoteIceMode(IceMode mode) { remote_ice_mode_ = mode; } void SetRemoteIceMode(IceMode mode) override { remote_ice_mode_ = mode; }
virtual bool SetRemoteFingerprint(const std::string& alg, const uint8* digest, bool SetRemoteFingerprint(const std::string& alg, const uint8* digest,
size_t digest_len) { size_t digest_len) override {
dtls_fingerprint_ = rtc::SSLFingerprint(alg, digest, digest_len); dtls_fingerprint_ = rtc::SSLFingerprint(alg, digest, digest_len);
return true; return true;
} }
virtual bool SetSslRole(rtc::SSLRole role) { bool SetSslRole(rtc::SSLRole role) override {
ssl_role_ = role; ssl_role_ = role;
return true; return true;
} }
virtual bool GetSslRole(rtc::SSLRole* role) const { bool GetSslRole(rtc::SSLRole* role) const override {
*role = ssl_role_; *role = ssl_role_;
return true; return true;
} }
virtual void Connect() { void Connect() override {
if (state_ == STATE_INIT) { if (state_ == STATE_INIT) {
state_ = STATE_CONNECTING; state_ = STATE_CONNECTING;
} }
@ -147,7 +148,7 @@ class FakeTransportChannel : public TransportChannelImpl,
// This simulates the delivery of candidates. // This simulates the delivery of candidates.
dest_ = dest; dest_ = dest;
dest_->dest_ = this; dest_->dest_ = this;
if (identity_ && dest_->identity_) { if (certificate_ && dest_->certificate_) {
do_dtls_ = true; do_dtls_ = true;
dest_->do_dtls_ = true; dest_->do_dtls_ = true;
NegotiateSrtpCiphers(); NegotiateSrtpCiphers();
@ -177,8 +178,8 @@ class FakeTransportChannel : public TransportChannelImpl,
void SetReceivingTimeout(int timeout) override {} void SetReceivingTimeout(int timeout) override {}
virtual int SendPacket(const char* data, size_t len, int SendPacket(const char* data, size_t len,
const rtc::PacketOptions& options, int flags) { const rtc::PacketOptions& options, int flags) override {
if (state_ != STATE_CONNECTED) { if (state_ != STATE_CONNECTED) {
return -1; return -1;
} }
@ -195,22 +196,22 @@ class FakeTransportChannel : public TransportChannelImpl,
} }
return static_cast<int>(len); return static_cast<int>(len);
} }
virtual int SetOption(rtc::Socket::Option opt, int value) { int SetOption(rtc::Socket::Option opt, int value) override {
return true; return true;
} }
virtual bool GetOption(rtc::Socket::Option opt, int* value) { bool GetOption(rtc::Socket::Option opt, int* value) override {
return true; return true;
} }
virtual int GetError() { int GetError() override {
return 0; return 0;
} }
virtual void OnSignalingReady() { void OnSignalingReady() override {
} }
virtual void OnCandidate(const Candidate& candidate) { void OnCandidate(const Candidate& candidate) override {
} }
virtual void OnMessage(rtc::Message* msg) { void OnMessage(rtc::Message* msg) override {
PacketMessageData* data = static_cast<PacketMessageData*>( PacketMessageData* data = static_cast<PacketMessageData*>(
msg->pdata); msg->pdata);
dest_->SignalReadPacket(dest_, data->packet.data<char>(), dest_->SignalReadPacket(dest_, data->packet.data<char>(),
@ -218,26 +219,26 @@ class FakeTransportChannel : public TransportChannelImpl,
delete data; delete data;
} }
bool SetLocalIdentity(rtc::SSLIdentity* identity) { bool SetLocalCertificate(
identity_ = identity; const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) override {
certificate_ = certificate;
return true; return true;
} }
void SetRemoteSSLCertificate(rtc::FakeSSLCertificate* cert) {
void SetRemoteCertificate(rtc::FakeSSLCertificate* cert) {
remote_cert_ = cert; remote_cert_ = cert;
} }
virtual bool IsDtlsActive() const { bool IsDtlsActive() const override {
return do_dtls_; return do_dtls_;
} }
virtual bool SetSrtpCiphers(const std::vector<std::string>& ciphers) { bool SetSrtpCiphers(const std::vector<std::string>& ciphers) override {
srtp_ciphers_ = ciphers; srtp_ciphers_ = ciphers;
return true; return true;
} }
virtual bool GetSrtpCipher(std::string* cipher) { bool GetSrtpCipher(std::string* cipher) override {
if (!chosen_srtp_cipher_.empty()) { if (!chosen_srtp_cipher_.empty()) {
*cipher = chosen_srtp_cipher_; *cipher = chosen_srtp_cipher_;
return true; return true;
@ -245,19 +246,16 @@ class FakeTransportChannel : public TransportChannelImpl,
return false; return false;
} }
virtual bool GetSslCipher(std::string* cipher) { bool GetSslCipher(std::string* cipher) override {
return false; return false;
} }
virtual bool GetLocalIdentity(rtc::SSLIdentity** identity) const { rtc::scoped_refptr<rtc::RTCCertificate>
if (!identity_) GetLocalCertificate() const override {
return false; return certificate_;
*identity = identity_->GetReference();
return true;
} }
virtual bool GetRemoteCertificate(rtc::SSLCertificate** cert) const { bool GetRemoteSSLCertificate(rtc::SSLCertificate** cert) const override {
if (!remote_cert_) if (!remote_cert_)
return false; return false;
@ -265,12 +263,12 @@ class FakeTransportChannel : public TransportChannelImpl,
return true; return true;
} }
virtual bool ExportKeyingMaterial(const std::string& label, bool ExportKeyingMaterial(const std::string& label,
const uint8* context, const uint8* context,
size_t context_len, size_t context_len,
bool use_context, bool use_context,
uint8* result, uint8* result,
size_t result_len) { size_t result_len) override {
if (!chosen_srtp_cipher_.empty()) { if (!chosen_srtp_cipher_.empty()) {
memset(result, 0xff, result_len); memset(result, 0xff, result_len);
return true; return true;
@ -307,7 +305,7 @@ class FakeTransportChannel : public TransportChannelImpl,
FakeTransportChannel* dest_; FakeTransportChannel* dest_;
State state_; State state_;
bool async_; bool async_;
rtc::SSLIdentity* identity_; rtc::scoped_refptr<rtc::RTCCertificate> certificate_;
rtc::FakeSSLCertificate* remote_cert_; rtc::FakeSSLCertificate* remote_cert_;
bool do_dtls_; bool do_dtls_;
std::vector<std::string> srtp_ciphers_; std::vector<std::string> srtp_ciphers_;
@ -350,9 +348,7 @@ class FakeTransport : public Transport {
dest_ = dest; dest_ = dest;
for (ChannelMap::iterator it = channels_.begin(); it != channels_.end(); for (ChannelMap::iterator it = channels_.begin(); it != channels_.end();
++it) { ++it) {
// TODO(hbos): SetLocalCertificate it->second->SetLocalCertificate(certificate_);
it->second->SetLocalIdentity(
certificate_ ? certificate_->identity() : nullptr);
SetChannelDestination(it->first, it->second); SetChannelDestination(it->first, it->second);
} }
} }
@ -373,7 +369,7 @@ class FakeTransport : public Transport {
using Transport::remote_description; using Transport::remote_description;
protected: protected:
virtual TransportChannelImpl* CreateTransportChannel(int component) { TransportChannelImpl* CreateTransportChannel(int component) override {
if (channels_.find(component) != channels_.end()) { if (channels_.find(component) != channels_.end()) {
return NULL; return NULL;
} }
@ -384,7 +380,7 @@ class FakeTransport : public Transport {
channels_[component] = channel; channels_[component] = channel;
return channel; return channel;
} }
virtual void DestroyTransportChannel(TransportChannelImpl* channel) { void DestroyTransportChannel(TransportChannelImpl* channel) override {
channels_.erase(channel->component()); channels_.erase(channel->component());
delete channel; delete channel;
} }
@ -411,11 +407,8 @@ class FakeTransport : public Transport {
FakeTransportChannel* dest_channel = NULL; FakeTransportChannel* dest_channel = NULL;
if (dest_) { if (dest_) {
dest_channel = dest_->GetFakeChannel(component); dest_channel = dest_->GetFakeChannel(component);
if (dest_channel) { if (dest_channel)
// TODO(hbos): SetLocalCertificate dest_channel->SetLocalCertificate(dest_->certificate_);
dest_channel->SetLocalIdentity(
dest_->certificate_ ? dest_->certificate_->identity() : nullptr);
}
} }
channel->SetDestination(dest_channel); channel->SetDestination(dest_channel);
} }
@ -467,9 +460,8 @@ class FakeSession : public BaseSession {
} }
} }
virtual TransportChannel* CreateChannel( TransportChannel* CreateChannel(const std::string& content_name,
const std::string& content_name, int component) override {
int component) {
if (fail_create_channel_) { if (fail_create_channel_) {
return NULL; return NULL;
} }
@ -493,7 +485,7 @@ class FakeSession : public BaseSession {
} }
protected: protected:
virtual Transport* CreateTransport(const std::string& content_name) { Transport* CreateTransport(const std::string& content_name) override {
return new FakeTransport(signaling_thread(), worker_thread(), content_name); return new FakeTransport(signaling_thread(), worker_thread(), content_name);
} }

View File

@ -55,33 +55,33 @@ class P2PTransportChannel : public TransportChannelImpl,
int component, int component,
P2PTransport* transport, P2PTransport* transport,
PortAllocator *allocator); PortAllocator *allocator);
virtual ~P2PTransportChannel(); ~P2PTransportChannel() override;
// From TransportChannelImpl: // From TransportChannelImpl:
virtual Transport* GetTransport() { return transport_; } Transport* GetTransport() override { return transport_; }
virtual TransportChannelState GetState() const; TransportChannelState GetState() const override;
virtual void SetIceRole(IceRole role); void SetIceRole(IceRole role) override;
virtual IceRole GetIceRole() const { return ice_role_; } IceRole GetIceRole() const override { return ice_role_; }
virtual void SetIceTiebreaker(uint64 tiebreaker); void SetIceTiebreaker(uint64 tiebreaker) override;
virtual void SetIceCredentials(const std::string& ice_ufrag, void SetIceCredentials(const std::string& ice_ufrag,
const std::string& ice_pwd); const std::string& ice_pwd) override;
virtual void SetRemoteIceCredentials(const std::string& ice_ufrag, void SetRemoteIceCredentials(const std::string& ice_ufrag,
const std::string& ice_pwd); const std::string& ice_pwd) override;
virtual void SetRemoteIceMode(IceMode mode); void SetRemoteIceMode(IceMode mode) override;
virtual void Connect(); void Connect() override;
virtual void OnSignalingReady(); void OnSignalingReady() override;
virtual void OnCandidate(const Candidate& candidate); void OnCandidate(const Candidate& candidate) override;
// Sets the receiving timeout in milliseconds. // Sets the receiving timeout in milliseconds.
// This also sets the check_receiving_delay proportionally. // This also sets the check_receiving_delay proportionally.
virtual void SetReceivingTimeout(int receiving_timeout_ms); void SetReceivingTimeout(int receiving_timeout_ms) override;
// From TransportChannel: // From TransportChannel:
virtual int SendPacket(const char *data, size_t len, int SendPacket(const char *data, size_t len,
const rtc::PacketOptions& options, int flags); const rtc::PacketOptions& options, int flags) override;
virtual int SetOption(rtc::Socket::Option opt, int value); int SetOption(rtc::Socket::Option opt, int value) override;
virtual bool GetOption(rtc::Socket::Option opt, int* value); bool GetOption(rtc::Socket::Option opt, int* value) override;
virtual int GetError() { return error_; } int GetError() override { return error_; }
virtual bool GetStats(std::vector<ConnectionInfo>* stats); bool GetStats(std::vector<ConnectionInfo>* stats) override;
const Connection* best_connection() const { return best_connection_; } const Connection* best_connection() const { return best_connection_; }
void set_incoming_only(bool value) { incoming_only_ = value; } void set_incoming_only(bool value) { incoming_only_ = value; }
@ -93,61 +93,60 @@ class P2PTransportChannel : public TransportChannelImpl,
IceMode remote_ice_mode() const { return remote_ice_mode_; } IceMode remote_ice_mode() const { return remote_ice_mode_; }
// DTLS methods. // DTLS methods.
virtual bool IsDtlsActive() const { return false; } bool IsDtlsActive() const override { return false; }
// Default implementation. // Default implementation.
virtual bool GetSslRole(rtc::SSLRole* role) const { bool GetSslRole(rtc::SSLRole* role) const override {
return false; return false;
} }
virtual bool SetSslRole(rtc::SSLRole role) { bool SetSslRole(rtc::SSLRole role) override {
return false; return false;
} }
// Set up the ciphers to use for DTLS-SRTP. // Set up the ciphers to use for DTLS-SRTP.
virtual bool SetSrtpCiphers(const std::vector<std::string>& ciphers) { bool SetSrtpCiphers(const std::vector<std::string>& ciphers) override {
return false; return false;
} }
// Find out which DTLS-SRTP cipher was negotiated. // Find out which DTLS-SRTP cipher was negotiated.
virtual bool GetSrtpCipher(std::string* cipher) { bool GetSrtpCipher(std::string* cipher) override {
return false; return false;
} }
// Find out which DTLS cipher was negotiated. // Find out which DTLS cipher was negotiated.
virtual bool GetSslCipher(std::string* cipher) { bool GetSslCipher(std::string* cipher) override {
return false; return false;
} }
// Returns false because the channel is not encrypted by default. // Returns null because the channel is not encrypted by default.
virtual bool GetLocalIdentity(rtc::SSLIdentity** identity) const { rtc::scoped_refptr<rtc::RTCCertificate> GetLocalCertificate() const override {
return false; return nullptr;
} }
virtual bool GetRemoteCertificate(rtc::SSLCertificate** cert) const { bool GetRemoteSSLCertificate(rtc::SSLCertificate** cert) const override {
return false; return false;
} }
// Allows key material to be extracted for external encryption. // Allows key material to be extracted for external encryption.
virtual bool ExportKeyingMaterial( bool ExportKeyingMaterial(const std::string& label,
const std::string& label, const uint8* context,
const uint8* context, size_t context_len,
size_t context_len, bool use_context,
bool use_context, uint8* result,
uint8* result, size_t result_len) override {
size_t result_len) {
return false; return false;
} }
virtual bool SetLocalIdentity(rtc::SSLIdentity* identity) { bool SetLocalCertificate(
const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) override {
return false; return false;
} }
// Set DTLS Remote fingerprint. Must be after local identity set. // Set DTLS Remote fingerprint. Must be after local identity set.
virtual bool SetRemoteFingerprint( bool SetRemoteFingerprint(const std::string& digest_alg,
const std::string& digest_alg, const uint8* digest,
const uint8* digest, size_t digest_len) override {
size_t digest_len) {
return false; return false;
} }
@ -213,7 +212,7 @@ class P2PTransportChannel : public TransportChannelImpl,
void OnNominated(Connection* conn); void OnNominated(Connection* conn);
virtual void OnMessage(rtc::Message *pmsg); void OnMessage(rtc::Message *pmsg) override;
void OnSort(); void OnSort();
void OnPing(); void OnPing();

View File

@ -140,20 +140,20 @@ bool Transport::GetCertificate(
Bind(&Transport::GetCertificate_w, this, certificate)); Bind(&Transport::GetCertificate_w, this, certificate));
} }
bool Transport::GetRemoteCertificate(rtc::SSLCertificate** cert) { bool Transport::GetRemoteSSLCertificate(rtc::SSLCertificate** cert) {
// Channels can be deleted on the worker thread, so for safety the remote // Channels can be deleted on the worker thread, so for safety the remote
// certificate is acquired on the worker thread. // certificate is acquired on the worker thread.
return worker_thread_->Invoke<bool>( return worker_thread_->Invoke<bool>(
Bind(&Transport::GetRemoteCertificate_w, this, cert)); Bind(&Transport::GetRemoteSSLCertificate_w, this, cert));
} }
bool Transport::GetRemoteCertificate_w(rtc::SSLCertificate** cert) { bool Transport::GetRemoteSSLCertificate_w(rtc::SSLCertificate** cert) {
ASSERT(worker_thread()->IsCurrent()); ASSERT(worker_thread()->IsCurrent());
if (channels_.empty()) if (channels_.empty())
return false; return false;
ChannelMap::iterator iter = channels_.begin(); ChannelMap::iterator iter = channels_.begin();
return iter->second->GetRemoteCertificate(cert); return iter->second->GetRemoteSSLCertificate(cert);
} }
void Transport::SetChannelReceivingTimeout(int timeout_ms) { void Transport::SetChannelReceivingTimeout(int timeout_ms) {

View File

@ -208,7 +208,7 @@ class Transport : public rtc::MessageHandler,
bool GetCertificate(rtc::scoped_refptr<rtc::RTCCertificate>* certificate); bool GetCertificate(rtc::scoped_refptr<rtc::RTCCertificate>* certificate);
// Get a copy of the remote certificate in use by the specified channel. // Get a copy of the remote certificate in use by the specified channel.
bool GetRemoteCertificate(rtc::SSLCertificate** cert); bool GetRemoteSSLCertificate(rtc::SSLCertificate** cert);
// Create, destroy, and lookup the channels of this type by their components. // Create, destroy, and lookup the channels of this type by their components.
TransportChannelImpl* CreateChannel(int component); TransportChannelImpl* CreateChannel(int component);
@ -437,7 +437,7 @@ class Transport : public rtc::MessageHandler,
ContentAction action, ContentAction action,
std::string* error_desc); std::string* error_desc);
bool GetStats_w(TransportStats* infos); bool GetStats_w(TransportStats* infos);
bool GetRemoteCertificate_w(rtc::SSLCertificate** cert); bool GetRemoteSSLCertificate_w(rtc::SSLCertificate** cert);
void SetChannelReceivingTimeout_w(int timeout_ms); void SetChannelReceivingTimeout_w(int timeout_ms);

View File

@ -108,11 +108,12 @@ class TransportChannel : public sigslot::has_slots<> {
// Finds out which DTLS cipher was negotiated. // Finds out which DTLS cipher was negotiated.
virtual bool GetSslCipher(std::string* cipher) = 0; virtual bool GetSslCipher(std::string* cipher) = 0;
// Gets a copy of the local SSL identity, owned by the caller. // Gets the local RTCCertificate used for DTLS.
virtual bool GetLocalIdentity(rtc::SSLIdentity** identity) const = 0; virtual rtc::scoped_refptr<rtc::RTCCertificate>
GetLocalCertificate() const = 0;
// Gets a copy of the remote side's SSL certificate, owned by the caller. // Gets a copy of the remote side's SSL certificate, owned by the caller.
virtual bool GetRemoteCertificate(rtc::SSLCertificate** cert) const = 0; virtual bool GetRemoteSSLCertificate(rtc::SSLCertificate** cert) const = 0;
// Allows key material to be extracted for external encryption. // Allows key material to be extracted for external encryption.
virtual bool ExportKeyingMaterial(const std::string& label, virtual bool ExportKeyingMaterial(const std::string& label,

View File

@ -82,11 +82,8 @@ class TransportChannelImpl : public TransportChannel {
virtual void OnCandidate(const Candidate& candidate) = 0; virtual void OnCandidate(const Candidate& candidate) = 0;
// DTLS methods // DTLS methods
// Set DTLS local identity. The identity object is not copied, but the caller virtual bool SetLocalCertificate(
// retains ownership and must delete it after this TransportChannelImpl is const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) = 0;
// destroyed.
// TODO(bemasc): Fix the ownership semantics of this method.
virtual bool SetLocalIdentity(rtc::SSLIdentity* identity) = 0;
// Set DTLS Remote fingerprint. Must be after local identity set. // Set DTLS Remote fingerprint. Must be after local identity set.
virtual bool SetRemoteFingerprint(const std::string& digest_alg, virtual bool SetRemoteFingerprint(const std::string& digest_alg,

View File

@ -189,22 +189,22 @@ bool TransportChannelProxy::GetSslCipher(std::string* cipher) {
return impl_->GetSslCipher(cipher); return impl_->GetSslCipher(cipher);
} }
bool TransportChannelProxy::GetLocalIdentity( rtc::scoped_refptr<rtc::RTCCertificate>
rtc::SSLIdentity** identity) const { TransportChannelProxy::GetLocalCertificate() const {
ASSERT(rtc::Thread::Current() == worker_thread_); ASSERT(rtc::Thread::Current() == worker_thread_);
if (!impl_) { if (!impl_) {
return false; return nullptr;
} }
return impl_->GetLocalIdentity(identity); return impl_->GetLocalCertificate();
} }
bool TransportChannelProxy::GetRemoteCertificate( bool TransportChannelProxy::GetRemoteSSLCertificate(
rtc::SSLCertificate** cert) const { rtc::SSLCertificate** cert) const {
ASSERT(rtc::Thread::Current() == worker_thread_); ASSERT(rtc::Thread::Current() == worker_thread_);
if (!impl_) { if (!impl_) {
return false; return false;
} }
return impl_->GetRemoteCertificate(cert); return impl_->GetRemoteSSLCertificate(cert);
} }
bool TransportChannelProxy::ExportKeyingMaterial(const std::string& label, bool TransportChannelProxy::ExportKeyingMaterial(const std::string& label,

View File

@ -35,39 +35,39 @@ class TransportChannelProxy : public TransportChannel,
public: public:
TransportChannelProxy(const std::string& content_name, TransportChannelProxy(const std::string& content_name,
int component); int component);
virtual ~TransportChannelProxy(); ~TransportChannelProxy() override;
TransportChannelImpl* impl() { return impl_; } TransportChannelImpl* impl() { return impl_; }
virtual TransportChannelState GetState() const; TransportChannelState GetState() const override;
// Sets the implementation to which we will proxy. // Sets the implementation to which we will proxy.
void SetImplementation(TransportChannelImpl* impl); void SetImplementation(TransportChannelImpl* impl);
// Implementation of the TransportChannel interface. These simply forward to // Implementation of the TransportChannel interface. These simply forward to
// the implementation. // the implementation.
virtual int SendPacket(const char* data, size_t len, int SendPacket(const char* data, size_t len,
const rtc::PacketOptions& options, const rtc::PacketOptions& options,
int flags); int flags) override;
virtual int SetOption(rtc::Socket::Option opt, int value); int SetOption(rtc::Socket::Option opt, int value) override;
virtual bool GetOption(rtc::Socket::Option opt, int* value); bool GetOption(rtc::Socket::Option opt, int* value) override;
virtual int GetError(); int GetError() override;
virtual IceRole GetIceRole() const; virtual IceRole GetIceRole() const;
virtual bool GetStats(ConnectionInfos* infos); bool GetStats(ConnectionInfos* infos) override;
virtual bool IsDtlsActive() const; bool IsDtlsActive() const override;
virtual bool GetSslRole(rtc::SSLRole* role) const; bool GetSslRole(rtc::SSLRole* role) const override;
virtual bool SetSslRole(rtc::SSLRole role); virtual bool SetSslRole(rtc::SSLRole role);
virtual bool SetSrtpCiphers(const std::vector<std::string>& ciphers); bool SetSrtpCiphers(const std::vector<std::string>& ciphers) override;
virtual bool GetSrtpCipher(std::string* cipher); bool GetSrtpCipher(std::string* cipher) override;
virtual bool GetSslCipher(std::string* cipher); bool GetSslCipher(std::string* cipher) override;
virtual bool GetLocalIdentity(rtc::SSLIdentity** identity) const; rtc::scoped_refptr<rtc::RTCCertificate> GetLocalCertificate() const override;
virtual bool GetRemoteCertificate(rtc::SSLCertificate** cert) const; bool GetRemoteSSLCertificate(rtc::SSLCertificate** cert) const override;
virtual bool ExportKeyingMaterial(const std::string& label, bool ExportKeyingMaterial(const std::string& label,
const uint8* context, const uint8* context,
size_t context_len, size_t context_len,
bool use_context, bool use_context,
uint8* result, uint8* result,
size_t result_len); size_t result_len) override;
private: private:
// Catch signals from the implementation channel. These just forward to the // Catch signals from the implementation channel. These just forward to the