diff --git a/p2p/base/dtlstransport.cc b/p2p/base/dtlstransport.cc index a576a1bebc..d1c3fec77c 100644 --- a/p2p/base/dtlstransport.cc +++ b/p2p/base/dtlstransport.cc @@ -309,15 +309,6 @@ bool DtlsTransport::SetRemoteFingerprint(const std::string& digest_alg, return true; } -std::unique_ptr DtlsTransport::GetRemoteSSLCertificate() - const { - if (!dtls_) { - return nullptr; - } - - return dtls_->GetPeerCertificate(); -} - std::unique_ptr DtlsTransport::GetRemoteSSLCertChain() const { if (!dtls_) { diff --git a/p2p/base/dtlstransport.h b/p2p/base/dtlstransport.h index da0e7d48e2..a4b1c6cdec 100644 --- a/p2p/base/dtlstransport.h +++ b/p2p/base/dtlstransport.h @@ -143,14 +143,9 @@ class DtlsTransport : public DtlsTransportInternal { // Find out which DTLS cipher was negotiated bool GetSslCipherSuite(int* cipher) override; - // Once DTLS has been established, this method retrieves the certificate in - // use by the remote peer, for use in external identity verification. - // TODO(zhihuang): Remove all the SSLCertificate versions of these methods, - // and replace them with the SSLCertChain versions. Implement the - // PeerConnection::GetRemoteSSLCertificate using the SSLCertChain version. - std::unique_ptr GetRemoteSSLCertificate() const override; - - // Version of the above method that returns the full certificate chain. + // Once DTLS has been established, this method retrieves the certificate + // chain in use by the remote peer, for use in external identity + // verification. std::unique_ptr GetRemoteSSLCertChain() const override; // Once DTLS has established (i.e., this ice_transport is writable), this diff --git a/p2p/base/dtlstransport_unittest.cc b/p2p/base/dtlstransport_unittest.cc index d77f77c954..c66e02ba90 100644 --- a/p2p/base/dtlstransport_unittest.cc +++ b/p2p/base/dtlstransport_unittest.cc @@ -515,8 +515,8 @@ TEST_F(DtlsTransportTest, TestCertificatesBeforeConnect) { auto certificate2 = client2_.dtls_transport()->GetLocalCertificate(); ASSERT_NE(certificate1->ssl_certificate().ToPEMString(), certificate2->ssl_certificate().ToPEMString()); - ASSERT_FALSE(client1_.dtls_transport()->GetRemoteSSLCertificate()); - ASSERT_FALSE(client2_.dtls_transport()->GetRemoteSSLCertificate()); + ASSERT_FALSE(client1_.dtls_transport()->GetRemoteSSLCertChain()); + ASSERT_FALSE(client2_.dtls_transport()->GetRemoteSSLCertChain()); } // Test Certificates state after connection. @@ -531,15 +531,17 @@ TEST_F(DtlsTransportTest, TestCertificatesAfterConnect) { certificate2->ssl_certificate().ToPEMString()); // Each side's remote certificate is the other side's local certificate. - std::unique_ptr remote_cert1 = - client1_.dtls_transport()->GetRemoteSSLCertificate(); + std::unique_ptr remote_cert1 = + client1_.dtls_transport()->GetRemoteSSLCertChain(); ASSERT_TRUE(remote_cert1); - ASSERT_EQ(remote_cert1->ToPEMString(), + ASSERT_EQ(1u, remote_cert1->GetSize()); + ASSERT_EQ(remote_cert1->Get(0).ToPEMString(), certificate2->ssl_certificate().ToPEMString()); - std::unique_ptr remote_cert2 = - client2_.dtls_transport()->GetRemoteSSLCertificate(); + std::unique_ptr remote_cert2 = + client2_.dtls_transport()->GetRemoteSSLCertChain(); ASSERT_TRUE(remote_cert2); - ASSERT_EQ(remote_cert2->ToPEMString(), + ASSERT_EQ(1u, remote_cert2->GetSize()); + ASSERT_EQ(remote_cert2->Get(0).ToPEMString(), certificate1->ssl_certificate().ToPEMString()); } diff --git a/p2p/base/dtlstransportinternal.h b/p2p/base/dtlstransportinternal.h index e100305808..58d36fa08d 100644 --- a/p2p/base/dtlstransportinternal.h +++ b/p2p/base/dtlstransportinternal.h @@ -78,10 +78,6 @@ class DtlsTransportInternal : public rtc::PacketTransportInternal { virtual bool SetLocalCertificate( const rtc::scoped_refptr& certificate) = 0; - // Gets a copy of the remote side's SSL certificate. - virtual std::unique_ptr GetRemoteSSLCertificate() - const = 0; - // Gets a copy of the remote side's SSL certificate chain. virtual std::unique_ptr GetRemoteSSLCertChain() const = 0; diff --git a/p2p/base/fakedtlstransport.h b/p2p/base/fakedtlstransport.h index 770f7c90ef..a1bea1359f 100644 --- a/p2p/base/fakedtlstransport.h +++ b/p2p/base/fakedtlstransport.h @@ -178,14 +178,9 @@ class FakeDtlsTransport : public DtlsTransportInternal { rtc::scoped_refptr GetLocalCertificate() const override { return local_cert_; } - std::unique_ptr GetRemoteSSLCertificate() - const override { - return remote_cert_ ? std::unique_ptr( - remote_cert_->GetReference()) - : nullptr; - } std::unique_ptr GetRemoteSSLCertChain() const override { - return nullptr; + return remote_cert_ ? rtc::MakeUnique(remote_cert_) + : nullptr; } bool ExportKeyingMaterial(const std::string& label, const uint8_t* context, diff --git a/pc/jseptransportcontroller.cc b/pc/jseptransportcontroller.cc index aac55c4c25..5235791792 100644 --- a/pc/jseptransportcontroller.cc +++ b/pc/jseptransportcontroller.cc @@ -243,12 +243,12 @@ JsepTransportController::GetLocalCertificate( return t->GetLocalCertificate(); } -std::unique_ptr -JsepTransportController::GetRemoteSSLCertificate( +std::unique_ptr +JsepTransportController::GetRemoteSSLCertChain( const std::string& transport_name) const { if (!network_thread_->IsCurrent()) { - return network_thread_->Invoke>( - RTC_FROM_HERE, [&] { return GetRemoteSSLCertificate(transport_name); }); + return network_thread_->Invoke>( + RTC_FROM_HERE, [&] { return GetRemoteSSLCertChain(transport_name); }); } // Get the certificate from the RTP channel's DTLS handshake. Should be @@ -259,7 +259,7 @@ JsepTransportController::GetRemoteSSLCertificate( return nullptr; } - return dtls->GetRemoteSSLCertificate(); + return dtls->GetRemoteSSLCertChain(); } void JsepTransportController::MaybeStartGathering() { diff --git a/pc/jseptransportcontroller.h b/pc/jseptransportcontroller.h index 9755f4c9d7..fffdd33f33 100644 --- a/pc/jseptransportcontroller.h +++ b/pc/jseptransportcontroller.h @@ -122,9 +122,9 @@ class JsepTransportController : public sigslot::has_slots<>, const rtc::scoped_refptr& certificate); rtc::scoped_refptr GetLocalCertificate( const std::string& mid) const; - // Caller owns returned certificate. This method mainly exists for stats - // reporting. - std::unique_ptr GetRemoteSSLCertificate( + // Caller owns returned certificate chain. This method mainly exists for + // stats reporting. + std::unique_ptr GetRemoteSSLCertChain( const std::string& mid) const; // Get negotiated role, if one has been negotiated. rtc::Optional GetDtlsRole(const std::string& mid) const; diff --git a/pc/jseptransportcontroller_unittest.cc b/pc/jseptransportcontroller_unittest.cc index 887c4beacf..78f4986914 100644 --- a/pc/jseptransportcontroller_unittest.cc +++ b/pc/jseptransportcontroller_unittest.cc @@ -462,7 +462,7 @@ TEST_F(JsepTransportControllerTest, SetAndGetLocalCertificate) { EXPECT_FALSE(transport_controller_->SetLocalCertificate(certificate2)); } -TEST_F(JsepTransportControllerTest, GetRemoteSSLCertificate) { +TEST_F(JsepTransportControllerTest, GetRemoteSSLCertChain) { CreateJsepTransportController(JsepTransportController::Config()); auto description = CreateSessionDescriptionWithBundleGroup(); EXPECT_TRUE(transport_controller_ @@ -473,14 +473,15 @@ TEST_F(JsepTransportControllerTest, GetRemoteSSLCertificate) { auto fake_audio_dtls = static_cast( transport_controller_->GetDtlsTransport(kAudioMid1)); fake_audio_dtls->SetRemoteSSLCertificate(&fake_certificate); - std::unique_ptr returned_certificate = - transport_controller_->GetRemoteSSLCertificate(kAudioMid1); - EXPECT_TRUE(returned_certificate); + std::unique_ptr returned_cert_chain = + transport_controller_->GetRemoteSSLCertChain(kAudioMid1); + ASSERT_TRUE(returned_cert_chain); + ASSERT_EQ(1u, returned_cert_chain->GetSize()); EXPECT_EQ(fake_certificate.ToPEMString(), - returned_certificate->ToPEMString()); + returned_cert_chain->Get(0).ToPEMString()); // Should fail if called for a nonexistant transport. - EXPECT_FALSE(transport_controller_->GetRemoteSSLCertificate(kAudioMid2)); + EXPECT_FALSE(transport_controller_->GetRemoteSSLCertChain(kAudioMid2)); } TEST_F(JsepTransportControllerTest, GetDtlsRole) { diff --git a/pc/peerconnection.cc b/pc/peerconnection.cc index 211c3d6c1c..0681acbb4a 100644 --- a/pc/peerconnection.cc +++ b/pc/peerconnection.cc @@ -2945,12 +2945,11 @@ void PeerConnection::SetAudioRecording(bool recording) { std::unique_ptr PeerConnection::GetRemoteAudioSSLCertificate() { - auto audio_transceiver = GetFirstAudioTransceiver(); - if (!audio_transceiver || !audio_transceiver->internal()->channel()) { + std::unique_ptr chain = GetRemoteAudioSSLCertChain(); + if (!chain || !chain->GetSize()) { return nullptr; } - return GetRemoteSSLCertificate( - audio_transceiver->internal()->channel()->transport_name()); + return chain->Get(0).GetUniqueReference(); } std::unique_ptr @@ -5007,9 +5006,9 @@ bool PeerConnection::GetLocalCertificate( certificate); } -std::unique_ptr PeerConnection::GetRemoteSSLCertificate( +std::unique_ptr PeerConnection::GetRemoteSSLCertChain( const std::string& transport_name) { - return transport_controller_->GetRemoteSSLCertificate(transport_name); + return transport_controller_->GetRemoteSSLCertChain(transport_name); } cricket::DataChannelType PeerConnection::data_channel_type() const { diff --git a/pc/peerconnection.h b/pc/peerconnection.h index 7171deec9b..2f86939081 100644 --- a/pc/peerconnection.h +++ b/pc/peerconnection.h @@ -242,7 +242,7 @@ class PeerConnection : public PeerConnectionInternal, bool GetLocalCertificate( const std::string& transport_name, rtc::scoped_refptr* certificate) override; - std::unique_ptr GetRemoteSSLCertificate( + std::unique_ptr GetRemoteSSLCertChain( const std::string& transport_name) override; bool IceRestartPending(const std::string& content_name) const override; bool NeedsIceRestart(const std::string& content_name) const override; diff --git a/pc/peerconnectioninternal.h b/pc/peerconnectioninternal.h index 8cf4989976..f0267a700d 100644 --- a/pc/peerconnectioninternal.h +++ b/pc/peerconnectioninternal.h @@ -71,7 +71,7 @@ class PeerConnectionInternal : public PeerConnectionInterface { virtual bool GetLocalCertificate( const std::string& transport_name, rtc::scoped_refptr* certificate) = 0; - virtual std::unique_ptr GetRemoteSSLCertificate( + virtual std::unique_ptr GetRemoteSSLCertChain( const std::string& transport_name) = 0; // Returns true if there was an ICE restart initiated by the remote offer. diff --git a/pc/rtcstatscollector.cc b/pc/rtcstatscollector.cc index dd8f69e785..80e67c5eb4 100644 --- a/pc/rtcstatscollector.cc +++ b/pc/rtcstatscollector.cc @@ -1234,13 +1234,13 @@ RTCStatsCollector::PrepareTransportCertificateStats_n( rtc::scoped_refptr local_certificate; if (pc_->GetLocalCertificate(transport_name, &local_certificate)) { certificate_stats_pair.local = - local_certificate->ssl_certificate().GetStats(); + local_certificate->ssl_cert_chain().GetStats(); } - std::unique_ptr remote_certificate = - pc_->GetRemoteSSLCertificate(transport_name); - if (remote_certificate) { - certificate_stats_pair.remote = remote_certificate->GetStats(); + std::unique_ptr remote_cert_chain = + pc_->GetRemoteSSLCertChain(transport_name); + if (remote_cert_chain) { + certificate_stats_pair.remote = remote_cert_chain->GetStats(); } transport_cert_stats.insert( diff --git a/pc/rtcstatscollector_unittest.cc b/pc/rtcstatscollector_unittest.cc index f95914b71c..3456f2520f 100644 --- a/pc/rtcstatscollector_unittest.cc +++ b/pc/rtcstatscollector_unittest.cc @@ -133,7 +133,7 @@ std::unique_ptr CreateFakeCertificateAndInfoFromDers( } info->certificate = rtc::RTCCertificate::Create(std::unique_ptr( - new rtc::FakeSSLIdentity(rtc::FakeSSLCertificate(info->pems)))); + new rtc::FakeSSLIdentity(info->pems))); // Strip header/footer and newline characters of PEM strings. for (size_t i = 0; i < info->pems.size(); ++i) { rtc::replace_substrs("-----BEGIN CERTIFICATE-----", 27, @@ -143,21 +143,14 @@ std::unique_ptr CreateFakeCertificateAndInfoFromDers( rtc::replace_substrs("\n", 1, "", 0, &info->pems[i]); } - // Fingerprint of leaf certificate. - std::unique_ptr fp( - rtc::SSLFingerprint::Create("sha-1", - &info->certificate->ssl_certificate())); - EXPECT_TRUE(fp); - info->fingerprints.push_back(fp->GetRfc4572Fingerprint()); - // Fingerprints of the rest of the chain. - std::unique_ptr chain = - info->certificate->ssl_certificate().GetChain(); - if (chain) { - for (size_t i = 0; i < chain->GetSize(); i++) { - fp.reset(rtc::SSLFingerprint::Create("sha-1", &chain->Get(i))); - EXPECT_TRUE(fp); - info->fingerprints.push_back(fp->GetRfc4572Fingerprint()); - } + // Fingerprints for the whole certificate chain, starting with leaf + // certificate. + const rtc::SSLCertChain& chain = info->certificate->ssl_cert_chain(); + std::unique_ptr fp; + for (size_t i = 0; i < chain.GetSize(); i++) { + fp.reset(rtc::SSLFingerprint::Create("sha-1", &chain.Get(i))); + EXPECT_TRUE(fp); + info->fingerprints.push_back(fp->GetRfc4572Fingerprint()); } EXPECT_EQ(info->ders.size(), info->fingerprints.size()); return info; @@ -579,9 +572,9 @@ TEST_F(RTCStatsCollectorTest, CollectRTCCertificateStatsSingle) { std::unique_ptr remote_certinfo = CreateFakeCertificateAndInfoFromDers( std::vector({ "(remote) single certificate" })); - pc_->SetRemoteCertificate( + pc_->SetRemoteCertChain( kTransportName, - remote_certinfo->certificate->ssl_certificate().GetUniqueReference()); + remote_certinfo->certificate->ssl_cert_chain().UniqueCopy()); rtc::scoped_refptr report = stats_->GetStatsReport(); @@ -693,9 +686,9 @@ TEST_F(RTCStatsCollectorTest, CollectRTCCertificateStatsMultiple) { std::unique_ptr audio_remote_certinfo = CreateFakeCertificateAndInfoFromDers( std::vector({ "(remote) audio" })); - pc_->SetRemoteCertificate( - kAudioTransport, audio_remote_certinfo->certificate->ssl_certificate() - .GetUniqueReference()); + pc_->SetRemoteCertChain( + kAudioTransport, + audio_remote_certinfo->certificate->ssl_cert_chain().UniqueCopy()); pc_->AddVideoChannel("video", kVideoTransport); std::unique_ptr video_local_certinfo = @@ -705,9 +698,9 @@ TEST_F(RTCStatsCollectorTest, CollectRTCCertificateStatsMultiple) { std::unique_ptr video_remote_certinfo = CreateFakeCertificateAndInfoFromDers( std::vector({ "(remote) video" })); - pc_->SetRemoteCertificate( - kVideoTransport, video_remote_certinfo->certificate->ssl_certificate() - .GetUniqueReference()); + pc_->SetRemoteCertChain( + kVideoTransport, + video_remote_certinfo->certificate->ssl_cert_chain().UniqueCopy()); rtc::scoped_refptr report = stats_->GetStatsReport(); ExpectReportContainsCertificateInfo(report, *audio_local_certinfo); @@ -730,9 +723,9 @@ TEST_F(RTCStatsCollectorTest, CollectRTCCertificateStatsChain) { CreateFakeCertificateAndInfoFromDers({"(remote) this", "(remote) is", "(remote) another", "(remote) chain"}); - pc_->SetRemoteCertificate( + pc_->SetRemoteCertChain( kTransportName, - remote_certinfo->certificate->ssl_certificate().GetUniqueReference()); + remote_certinfo->certificate->ssl_cert_chain().UniqueCopy()); rtc::scoped_refptr report = stats_->GetStatsReport(); ExpectReportContainsCertificateInfo(report, *local_certinfo); @@ -1805,9 +1798,9 @@ TEST_F(RTCStatsCollectorTest, CollectRTCTransportStats) { std::unique_ptr remote_certinfo = CreateFakeCertificateAndInfoFromDers( {"(remote) local", "(remote) chain"}); - pc_->SetRemoteCertificate( + pc_->SetRemoteCertChain( kTransportName, - remote_certinfo->certificate->ssl_certificate().GetUniqueReference()); + remote_certinfo->certificate->ssl_cert_chain().UniqueCopy()); report = stats_->GetFreshStatsReport(); diff --git a/pc/statscollector.cc b/pc/statscollector.cc index e7fb95b76a..53d429c1ff 100644 --- a/pc/statscollector.cc +++ b/pc/statscollector.cc @@ -623,14 +623,12 @@ bool StatsCollector::IsValidTrack(const std::string& track_id) { } StatsReport* StatsCollector::AddCertificateReports( - const rtc::SSLCertificate* cert) { + std::unique_ptr cert_stats) { RTC_DCHECK(pc_->signaling_thread()->IsCurrent()); - RTC_DCHECK(cert != NULL); - std::unique_ptr first_stats = cert->GetStats(); StatsReport* first_report = nullptr; StatsReport* prev_report = nullptr; - for (rtc::SSLCertificateStats* stats = first_stats.get(); stats; + for (rtc::SSLCertificateStats* stats = cert_stats.get(); stats; stats = stats->issuer.get()) { StatsReport::Id id(StatsReport::NewTypedId( StatsReport::kStatsReportTypeCertificate, stats->fingerprint)); @@ -786,15 +784,16 @@ void StatsCollector::ExtractSessionInfo() { StatsReport::Id local_cert_report_id, remote_cert_report_id; rtc::scoped_refptr certificate; if (pc_->GetLocalCertificate(transport_name, &certificate)) { - StatsReport* r = AddCertificateReports(&(certificate->ssl_certificate())); + StatsReport* r = + AddCertificateReports(certificate->ssl_cert_chain().GetStats()); if (r) local_cert_report_id = r->id(); } - std::unique_ptr cert = - pc_->GetRemoteSSLCertificate(transport_name); - if (cert) { - StatsReport* r = AddCertificateReports(cert.get()); + std::unique_ptr remote_cert_chain = + pc_->GetRemoteSSLCertChain(transport_name); + if (remote_cert_chain) { + StatsReport* r = AddCertificateReports(remote_cert_chain->GetStats()); if (r) remote_cert_report_id = r->id(); } diff --git a/pc/statscollector.h b/pc/statscollector.h index abd65a97f9..1ed650d2a0 100644 --- a/pc/statscollector.h +++ b/pc/statscollector.h @@ -15,6 +15,7 @@ #define PC_STATSCOLLECTOR_H_ #include +#include #include #include #include @@ -101,8 +102,9 @@ class StatsCollector { bool local); // Adds a report for this certificate and every certificate in its chain, and - // returns the leaf certificate's report (|cert|'s report). - StatsReport* AddCertificateReports(const rtc::SSLCertificate* cert); + // returns the leaf certificate's report (|cert_stats|'s report). + StatsReport* AddCertificateReports( + std::unique_ptr cert_stats); StatsReport* AddConnectionInfoReport(const std::string& content_name, int component, int connection_id, diff --git a/pc/statscollector_unittest.cc b/pc/statscollector_unittest.cc index 46e7bd24ca..081c0b3429 100644 --- a/pc/statscollector_unittest.cc +++ b/pc/statscollector_unittest.cc @@ -636,11 +636,10 @@ class StatsCollectorTest : public testing::Test { } } - void TestCertificateReports( - const rtc::FakeSSLCertificate& local_cert, - const std::vector& local_ders, - std::unique_ptr remote_cert, - const std::vector& remote_ders) { + void TestCertificateReports(const rtc::FakeSSLIdentity& local_identity, + const std::vector& local_ders, + const rtc::FakeSSLIdentity& remote_identity, + const std::vector& remote_ders) { const std::string kTransportName = "transport"; auto pc = CreatePeerConnection(); @@ -656,12 +655,13 @@ class StatsCollectorTest : public testing::Test { internal::TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA; pc->SetTransportStats(kTransportName, channel_stats); - // Fake certificate to report + // Fake certificate to report. rtc::scoped_refptr local_certificate( rtc::RTCCertificate::Create( - rtc::MakeUnique(local_cert))); + std::unique_ptr(local_identity.GetReference()))); pc->SetLocalCertificate(kTransportName, local_certificate); - pc->SetRemoteCertificate(kTransportName, std::move(remote_cert)); + pc->SetRemoteCertChain(kTransportName, + remote_identity.cert_chain().UniqueCopy()); stats->UpdateStats(PeerConnectionInterface::kStatsOutputLevelStandard); @@ -1329,7 +1329,7 @@ TEST_F(StatsCollectorTest, ChainedCertificateReportsCreated) { local_ders[2] = "some"; local_ders[3] = "der"; local_ders[4] = "values"; - rtc::FakeSSLCertificate local_cert(DersToPems(local_ders)); + rtc::FakeSSLIdentity local_identity(DersToPems(local_ders)); // Build remote certificate chain std::vector remote_ders(4); @@ -1337,10 +1337,9 @@ TEST_F(StatsCollectorTest, ChainedCertificateReportsCreated) { remote_ders[1] = "non-"; remote_ders[2] = "intersecting"; remote_ders[3] = "set"; - std::unique_ptr remote_cert( - new rtc::FakeSSLCertificate(DersToPems(remote_ders))); + rtc::FakeSSLIdentity remote_identity(DersToPems(remote_ders)); - TestCertificateReports(local_cert, local_ders, std::move(remote_cert), + TestCertificateReports(local_identity, local_ders, remote_identity, remote_ders); } @@ -1349,15 +1348,14 @@ TEST_F(StatsCollectorTest, ChainedCertificateReportsCreated) { TEST_F(StatsCollectorTest, ChainlessCertificateReportsCreated) { // Build local certificate. std::string local_der = "This is the local der."; - rtc::FakeSSLCertificate local_cert(DerToPem(local_der)); + rtc::FakeSSLIdentity local_identity(DerToPem(local_der)); // Build remote certificate. std::string remote_der = "This is somebody else's der."; - std::unique_ptr remote_cert( - new rtc::FakeSSLCertificate(DerToPem(remote_der))); + rtc::FakeSSLIdentity remote_identity(DerToPem(remote_der)); - TestCertificateReports(local_cert, std::vector(1, local_der), - std::move(remote_cert), + TestCertificateReports(local_identity, std::vector(1, local_der), + remote_identity, std::vector(1, remote_der)); } @@ -1405,16 +1403,16 @@ TEST_F(StatsCollectorTest, NoTransport) { TEST_F(StatsCollectorTest, UnsupportedDigestIgnored) { // Build a local certificate. std::string local_der = "This is the local der."; - rtc::FakeSSLCertificate local_cert(DerToPem(local_der)); + rtc::FakeSSLIdentity local_identity(DerToPem(local_der)); // Build a remote certificate with an unsupported digest algorithm. std::string remote_der = "This is somebody else's der."; - std::unique_ptr remote_cert( - new rtc::FakeSSLCertificate(DerToPem(remote_der))); - remote_cert->set_digest_algorithm("foobar"); + rtc::FakeSSLCertificate remote_cert(DerToPem(remote_der)); + remote_cert.set_digest_algorithm("foobar"); + rtc::FakeSSLIdentity remote_identity(remote_cert); - TestCertificateReports(local_cert, std::vector(1, local_der), - std::move(remote_cert), std::vector()); + TestCertificateReports(local_identity, std::vector(1, local_der), + remote_identity, std::vector()); } // This test verifies that the audio/video related stats which are -1 initially diff --git a/pc/test/fakepeerconnectionbase.h b/pc/test/fakepeerconnectionbase.h index 99964cbe2c..5ca1ba7b95 100644 --- a/pc/test/fakepeerconnectionbase.h +++ b/pc/test/fakepeerconnectionbase.h @@ -292,7 +292,7 @@ class FakePeerConnectionBase : public PeerConnectionInternal { return false; } - std::unique_ptr GetRemoteSSLCertificate( + std::unique_ptr GetRemoteSSLCertChain( const std::string& transport_name) override { return nullptr; } diff --git a/pc/test/fakepeerconnectionforstats.h b/pc/test/fakepeerconnectionforstats.h index c26d144ca4..1c2fb70543 100644 --- a/pc/test/fakepeerconnectionforstats.h +++ b/pc/test/fakepeerconnectionforstats.h @@ -196,9 +196,9 @@ class FakePeerConnectionForStats : public FakePeerConnectionBase { local_certificates_by_transport_[transport_name] = certificate; } - void SetRemoteCertificate(const std::string& transport_name, - std::unique_ptr certificate) { - remote_certificates_by_transport_[transport_name] = std::move(certificate); + void SetRemoteCertChain(const std::string& transport_name, + std::unique_ptr chain) { + remote_cert_chains_by_transport_[transport_name] = std::move(chain); } // PeerConnectionInterface overrides. @@ -313,11 +313,11 @@ class FakePeerConnectionForStats : public FakePeerConnectionBase { } } - std::unique_ptr GetRemoteSSLCertificate( + std::unique_ptr GetRemoteSSLCertChain( const std::string& transport_name) override { - auto it = remote_certificates_by_transport_.find(transport_name); - if (it != remote_certificates_by_transport_.end()) { - return it->second->GetUniqueReference(); + auto it = remote_cert_chains_by_transport_.find(transport_name); + if (it != remote_cert_chains_by_transport_.end()) { + return it->second->UniqueCopy(); } else { return nullptr; } @@ -379,8 +379,8 @@ class FakePeerConnectionForStats : public FakePeerConnectionBase { std::map> local_certificates_by_transport_; - std::map> - remote_certificates_by_transport_; + std::map> + remote_cert_chains_by_transport_; }; } // namespace webrtc diff --git a/pc/transportcontroller.cc b/pc/transportcontroller.cc index 1876c3be78..4e20981da6 100644 --- a/pc/transportcontroller.cc +++ b/pc/transportcontroller.cc @@ -191,17 +191,6 @@ bool TransportController::GetLocalCertificate( this, transport_name, certificate)); } -std::unique_ptr -TransportController::GetRemoteSSLCertificate( - const std::string& transport_name) const { - if (network_thread_->IsCurrent()) { - return GetRemoteSSLCertificate_n(transport_name); - } - return network_thread_->Invoke>( - RTC_FROM_HERE, rtc::Bind(&TransportController::GetRemoteSSLCertificate_n, - this, transport_name)); -} - std::unique_ptr TransportController::GetRemoteSSLCertChain( const std::string& transport_name) const { if (!network_thread_->IsCurrent()) { @@ -209,6 +198,9 @@ std::unique_ptr TransportController::GetRemoteSSLCertChain( RTC_FROM_HERE, [&] { return GetRemoteSSLCertChain(transport_name); }); } + // Get the certificate from the RTP channel's DTLS handshake. Should be + // identical to the RTCP channel's, since they were given the same remote + // fingerprint. const RefCountedChannel* ch = GetChannel_n(transport_name, cricket::ICE_CANDIDATE_COMPONENT_RTP); if (!ch) { @@ -762,21 +754,6 @@ bool TransportController::GetLocalCertificate_n( return t->GetLocalCertificate(certificate); } -std::unique_ptr -TransportController::GetRemoteSSLCertificate_n( - const std::string& transport_name) const { - RTC_DCHECK(network_thread_->IsCurrent()); - - // Get the certificate from the RTP channel's DTLS handshake. Should be - // identical to the RTCP channel's, since they were given the same remote - // fingerprint. - const RefCountedChannel* ch = GetChannel_n(transport_name, 1); - if (!ch) { - return nullptr; - } - return ch->dtls()->GetRemoteSSLCertificate(); -} - bool TransportController::SetLocalTransportDescription_n( const std::string& transport_name, const TransportDescription& tdesc, diff --git a/pc/transportcontroller.h b/pc/transportcontroller.h index bd2be7c765..ea07dfe83f 100644 --- a/pc/transportcontroller.h +++ b/pc/transportcontroller.h @@ -91,10 +91,8 @@ class TransportController : public sigslot::has_slots<>, bool GetLocalCertificate( const std::string& transport_name, rtc::scoped_refptr* certificate) const; - // Caller owns returned certificate. This method mainly exists for stats - // reporting. - std::unique_ptr GetRemoteSSLCertificate( - const std::string& transport_name) const; + // Caller owns returned certificate chain. This method mainly exists for + // stats reporting. std::unique_ptr GetRemoteSSLCertChain( const std::string& transport_name) const; @@ -249,8 +247,6 @@ class TransportController : public sigslot::has_slots<>, bool GetLocalCertificate_n( const std::string& transport_name, rtc::scoped_refptr* certificate) const; - std::unique_ptr GetRemoteSSLCertificate_n( - const std::string& transport_name) const; bool SetLocalTransportDescription_n(const std::string& transport_name, const TransportDescription& tdesc, webrtc::SdpType type, diff --git a/pc/transportcontroller_unittest.cc b/pc/transportcontroller_unittest.cc index 05c30c41b6..4f38c66b2c 100644 --- a/pc/transportcontroller_unittest.cc +++ b/pc/transportcontroller_unittest.cc @@ -348,21 +348,21 @@ TEST_F(TransportControllerTest, TestSetAndGetLocalCertificate) { EXPECT_FALSE(transport_controller_->SetLocalCertificate(certificate2)); } -TEST_F(TransportControllerTest, TestGetRemoteSSLCertificate) { +TEST_F(TransportControllerTest, TestGetRemoteSSLCertChain) { rtc::FakeSSLCertificate fake_certificate("fake_data"); FakeDtlsTransport* transport = CreateFakeDtlsTransport("audio", 1); ASSERT_NE(nullptr, transport); transport->SetRemoteSSLCertificate(&fake_certificate); - std::unique_ptr returned_certificate = - transport_controller_->GetRemoteSSLCertificate("audio"); - EXPECT_TRUE(returned_certificate); + std::unique_ptr returned_cert_chain = + transport_controller_->GetRemoteSSLCertChain("audio"); + EXPECT_TRUE(returned_cert_chain); EXPECT_EQ(fake_certificate.ToPEMString(), - returned_certificate->ToPEMString()); + returned_cert_chain->Get(0).ToPEMString()); // Should fail if called for a nonexistant transport. - EXPECT_FALSE(transport_controller_->GetRemoteSSLCertificate("video")); + EXPECT_FALSE(transport_controller_->GetRemoteSSLCertChain("video")); } TEST_F(TransportControllerTest, TestSetLocalTransportDescription) { diff --git a/rtc_base/fakesslidentity.cc b/rtc_base/fakesslidentity.cc index 296f09c918..825c89bda3 100644 --- a/rtc_base/fakesslidentity.cc +++ b/rtc_base/fakesslidentity.cc @@ -20,19 +20,10 @@ namespace rtc { -FakeSSLCertificate::FakeSSLCertificate(const std::string& data) - : data_(data), digest_algorithm_(DIGEST_SHA_1), expiration_time_(-1) {} - -FakeSSLCertificate::FakeSSLCertificate(const std::vector& certs) - : data_(certs.front()), +FakeSSLCertificate::FakeSSLCertificate(const std::string& pem_string) + : pem_string_(pem_string), digest_algorithm_(DIGEST_SHA_1), - expiration_time_(-1) { - std::vector::const_iterator it; - // Skip certs[0]. - for (it = certs.begin() + 1; it != certs.end(); ++it) { - certs_.push_back(FakeSSLCertificate(*it)); - } -} + expiration_time_(-1) {} FakeSSLCertificate::FakeSSLCertificate(const FakeSSLCertificate&) = default; @@ -43,12 +34,13 @@ FakeSSLCertificate* FakeSSLCertificate::GetReference() const { } std::string FakeSSLCertificate::ToPEMString() const { - return data_; + return pem_string_; } void FakeSSLCertificate::ToDER(Buffer* der_buffer) const { std::string der_string; - RTC_CHECK(SSLIdentity::PemToDer(kPemTypeCertificate, data_, &der_string)); + RTC_CHECK( + SSLIdentity::PemToDer(kPemTypeCertificate, pem_string_, &der_string)); der_buffer->SetData(der_string.c_str(), der_string.size()); } @@ -74,30 +66,40 @@ bool FakeSSLCertificate::ComputeDigest(const std::string& algorithm, unsigned char* digest, size_t size, size_t* length) const { - *length = - rtc::ComputeDigest(algorithm, data_.c_str(), data_.size(), digest, size); + *length = rtc::ComputeDigest(algorithm, pem_string_.c_str(), + pem_string_.size(), digest, size); return (*length != 0); } -std::unique_ptr FakeSSLCertificate::GetChain() const { - if (certs_.empty()) - return nullptr; - std::vector> new_certs(certs_.size()); - std::transform(certs_.begin(), certs_.end(), new_certs.begin(), DupCert); - return MakeUnique(std::move(new_certs)); +FakeSSLIdentity::FakeSSLIdentity(const std::string& pem_string) + : FakeSSLIdentity(FakeSSLCertificate(pem_string)) {} + +FakeSSLIdentity::FakeSSLIdentity(const std::vector& pem_strings) { + std::vector> certs; + for (const std::string& pem_string : pem_strings) { + certs.push_back(MakeUnique(pem_string)); + } + cert_chain_ = MakeUnique(std::move(certs)); } -FakeSSLIdentity::FakeSSLIdentity(const std::string& data) : cert_(data) {} - FakeSSLIdentity::FakeSSLIdentity(const FakeSSLCertificate& cert) - : cert_(cert) {} + : cert_chain_(MakeUnique(&cert)) {} + +FakeSSLIdentity::FakeSSLIdentity(const FakeSSLIdentity& o) + : cert_chain_(o.cert_chain_->UniqueCopy()) {} + +FakeSSLIdentity::~FakeSSLIdentity() = default; FakeSSLIdentity* FakeSSLIdentity::GetReference() const { return new FakeSSLIdentity(*this); } -const FakeSSLCertificate& FakeSSLIdentity::certificate() const { - return cert_; +const SSLCertificate& FakeSSLIdentity::certificate() const { + return cert_chain_->Get(0); +} + +const SSLCertChain& FakeSSLIdentity::cert_chain() const { + return *cert_chain_.get(); } std::string FakeSSLIdentity::PrivateKeyToPEMString() const { diff --git a/rtc_base/fakesslidentity.h b/rtc_base/fakesslidentity.h index 52aaf05558..4494a524ef 100644 --- a/rtc_base/fakesslidentity.h +++ b/rtc_base/fakesslidentity.h @@ -18,13 +18,11 @@ namespace rtc { -class FakeSSLCertificate : public rtc::SSLCertificate { +class FakeSSLCertificate : public SSLCertificate { public: // SHA-1 is the default digest algorithm because it is available in all build // configurations used for unit testing. - explicit FakeSSLCertificate(const std::string& data); - - explicit FakeSSLCertificate(const std::vector& certs); + explicit FakeSSLCertificate(const std::string& pem_string); FakeSSLCertificate(const FakeSSLCertificate&); ~FakeSSLCertificate() override; @@ -39,32 +37,33 @@ class FakeSSLCertificate : public rtc::SSLCertificate { unsigned char* digest, size_t size, size_t* length) const override; - std::unique_ptr GetChain() const override; void SetCertificateExpirationTime(int64_t expiration_time); void set_digest_algorithm(const std::string& algorithm); private: - static std::unique_ptr DupCert(FakeSSLCertificate cert) { - return cert.GetUniqueReference(); - } - static void DeleteCert(SSLCertificate* cert) { delete cert; } - std::string data_; - std::vector certs_; + std::string pem_string_; std::string digest_algorithm_; // Expiration time in seconds relative to epoch, 1970-01-01T00:00:00Z (UTC). int64_t expiration_time_; }; -class FakeSSLIdentity : public rtc::SSLIdentity { +class FakeSSLIdentity : public SSLIdentity { public: - explicit FakeSSLIdentity(const std::string& data); + explicit FakeSSLIdentity(const std::string& pem_string); + // For a certificate chain. + explicit FakeSSLIdentity(const std::vector& pem_strings); explicit FakeSSLIdentity(const FakeSSLCertificate& cert); + explicit FakeSSLIdentity(const FakeSSLIdentity& o); + + ~FakeSSLIdentity() override; + // SSLIdentity implementation. FakeSSLIdentity* GetReference() const override; - const FakeSSLCertificate& certificate() const override; + const SSLCertificate& certificate() const override; + const SSLCertChain& cert_chain() const override; // Not implemented. std::string PrivateKeyToPEMString() const override; // Not implemented. @@ -73,7 +72,7 @@ class FakeSSLIdentity : public rtc::SSLIdentity { virtual bool operator==(const SSLIdentity& other) const; private: - FakeSSLCertificate cert_; + std::unique_ptr cert_chain_; }; } // namespace rtc diff --git a/rtc_base/opensslidentity.cc b/rtc_base/opensslidentity.cc index 69ce5acb8b..9f7c63b06c 100644 --- a/rtc_base/opensslidentity.cc +++ b/rtc_base/opensslidentity.cc @@ -366,10 +366,6 @@ bool OpenSSLCertificate::GetSignatureDigestAlgorithm( return true; } -std::unique_ptr OpenSSLCertificate::GetChain() const { - return nullptr; -} - bool OpenSSLCertificate::ComputeDigest(const std::string& algorithm, unsigned char* digest, size_t size, @@ -590,6 +586,10 @@ const OpenSSLCertificate& OpenSSLIdentity::certificate() const { return *static_cast(&cert_chain_->Get(0)); } +const SSLCertChain& OpenSSLIdentity::cert_chain() const { + return *cert_chain_.get(); +} + OpenSSLIdentity* OpenSSLIdentity::GetReference() const { return new OpenSSLIdentity(WrapUnique(key_pair_->GetReference()), WrapUnique(cert_chain_->Copy())); diff --git a/rtc_base/opensslidentity.h b/rtc_base/opensslidentity.h index a700a1d107..c1dc49fb58 100644 --- a/rtc_base/opensslidentity.h +++ b/rtc_base/opensslidentity.h @@ -92,7 +92,6 @@ class OpenSSLCertificate : public SSLCertificate { size_t* length); bool GetSignatureDigestAlgorithm(std::string* algorithm) const override; - std::unique_ptr GetChain() const override; int64_t CertificateExpirationTime() const override; @@ -118,6 +117,7 @@ class OpenSSLIdentity : public SSLIdentity { ~OpenSSLIdentity() override; const OpenSSLCertificate& certificate() const override; + const SSLCertChain& cert_chain() const override; OpenSSLIdentity* GetReference() const override; // Configure an SSL context object to use our key and certificate. diff --git a/rtc_base/opensslstreamadapter.cc b/rtc_base/opensslstreamadapter.cc index d715e27ecc..c0fb108b9e 100644 --- a/rtc_base/opensslstreamadapter.cc +++ b/rtc_base/opensslstreamadapter.cc @@ -288,13 +288,6 @@ void OpenSSLStreamAdapter::SetServerRole(SSLRole role) { role_ = role; } -std::unique_ptr OpenSSLStreamAdapter::GetPeerCertificate() - const { - return peer_certificate_ ? std::unique_ptr( - peer_certificate_->GetReference()) - : nullptr; -} - bool OpenSSLStreamAdapter::SetPeerCertificateDigest( const std::string& digest_alg, const unsigned char* digest_val, @@ -324,7 +317,7 @@ bool OpenSSLStreamAdapter::SetPeerCertificateDigest( peer_certificate_digest_value_.SetData(digest_val, digest_len); peer_certificate_digest_algorithm_ = digest_alg; - if (!peer_certificate_) { + if (!peer_cert_chain_) { // Normal case, where the digest is set before we obtain the certificate // from the handshake. return true; @@ -831,7 +824,7 @@ int OpenSSLStreamAdapter::ContinueSSL() { RTC_LOG(LS_VERBOSE) << " -- success"; // By this point, OpenSSL should have given us a certificate, or errored // out if one was missing. - RTC_DCHECK(peer_certificate_ || !client_auth_enabled()); + RTC_DCHECK(peer_cert_chain_ || !client_auth_enabled()); state_ = SSL_CONNECTED; if (!waiting_to_verify_peer_certificate()) { @@ -928,7 +921,7 @@ void OpenSSLStreamAdapter::Cleanup(uint8_t alert) { ssl_ctx_ = nullptr; } identity_.reset(); - peer_certificate_.reset(); + peer_cert_chain_.reset(); // Clear the DTLS timer Thread::Current()->Clear(this, MSG_TIMEOUT); @@ -1062,15 +1055,18 @@ SSL_CTX* OpenSSLStreamAdapter::SetupSSLContext() { } bool OpenSSLStreamAdapter::VerifyPeerCertificate() { - if (!has_peer_certificate_digest() || !peer_certificate_) { + if (!has_peer_certificate_digest() || !peer_cert_chain_ || + !peer_cert_chain_->GetSize()) { RTC_LOG(LS_WARNING) << "Missing digest or peer certificate."; return false; } + const OpenSSLCertificate* leaf_cert = + static_cast(&peer_cert_chain_->Get(0)); unsigned char digest[EVP_MAX_MD_SIZE]; size_t digest_length; if (!OpenSSLCertificate::ComputeDigest( - peer_certificate_->x509(), peer_certificate_digest_algorithm_, digest, + leaf_cert->x509(), peer_certificate_digest_algorithm_, digest, sizeof(digest), &digest_length)) { RTC_LOG(LS_WARNING) << "Failed to compute peer cert digest."; return false; @@ -1092,7 +1088,7 @@ bool OpenSSLStreamAdapter::VerifyPeerCertificate() { std::unique_ptr OpenSSLStreamAdapter::GetPeerSSLCertChain() const { - return std::unique_ptr(peer_cert_chain_->Copy()); + return peer_cert_chain_ ? peer_cert_chain_->UniqueCopy() : nullptr; } int OpenSSLStreamAdapter::SSLVerifyCallback(X509_STORE_CTX* store, void* arg) { @@ -1104,9 +1100,6 @@ int OpenSSLStreamAdapter::SSLVerifyCallback(X509_STORE_CTX* store, void* arg) { #if defined(OPENSSL_IS_BORINGSSL) STACK_OF(X509)* chain = SSL_get_peer_full_cert_chain(ssl); - // Creates certificate. - stream->peer_certificate_.reset( - new OpenSSLCertificate(sk_X509_value(chain, 0))); // Creates certificate chain. std::vector> cert_chain; for (X509* cert : chain) { @@ -1116,7 +1109,8 @@ int OpenSSLStreamAdapter::SSLVerifyCallback(X509_STORE_CTX* store, void* arg) { #else // Record the peer's certificate. X509* cert = SSL_get_peer_certificate(ssl); - stream->peer_certificate_.reset(new OpenSSLCertificate(cert)); + stream->peer_cert_chain_.reset( + new SSLCertChain(new OpenSSLCertificate(cert))); X509_free(cert); #endif diff --git a/rtc_base/opensslstreamadapter.h b/rtc_base/opensslstreamadapter.h index b43dcc7e01..97ab557f41 100644 --- a/rtc_base/opensslstreamadapter.h +++ b/rtc_base/opensslstreamadapter.h @@ -69,8 +69,6 @@ class OpenSSLStreamAdapter : public SSLStreamAdapter { size_t digest_len, SSLPeerCertificateDigestError* error = nullptr) override; - std::unique_ptr GetPeerCertificate() const override; - std::unique_ptr GetPeerSSLCertChain() const override; // Goes from state SSL_NONE to either SSL_CONNECTING or SSL_WAIT, depending @@ -197,9 +195,8 @@ class OpenSSLStreamAdapter : public SSLStreamAdapter { // Our key and certificate. std::unique_ptr identity_; - // The certificate that the peer presented. Initially null, until the + // The certificate chain that the peer presented. Initially null, until the // connection is established. - std::unique_ptr peer_certificate_; std::unique_ptr peer_cert_chain_; bool peer_certificate_verified_ = false; // The digest of the certificate that the peer must present. diff --git a/rtc_base/rtccertificate.cc b/rtc_base/rtccertificate.cc index dd6f40a6af..2887895816 100644 --- a/rtc_base/rtccertificate.cc +++ b/rtc_base/rtccertificate.cc @@ -46,6 +46,10 @@ const SSLCertificate& RTCCertificate::ssl_certificate() const { return identity_->certificate(); } +const SSLCertChain& RTCCertificate::ssl_cert_chain() const { + return identity_->cert_chain(); +} + RTCCertificatePEM RTCCertificate::ToPEM() const { return RTCCertificatePEM(identity_->PrivateKeyToPEMString(), ssl_certificate().ToPEMString()); diff --git a/rtc_base/rtccertificate.h b/rtc_base/rtccertificate.h index 47f0e0c998..f13caba4cf 100644 --- a/rtc_base/rtccertificate.h +++ b/rtc_base/rtccertificate.h @@ -58,6 +58,7 @@ class RTCCertificate : public RefCountInterface { // relative to epoch, 1970-01-01T00:00:00Z. bool HasExpired(uint64_t now) const; const SSLCertificate& ssl_certificate() const; + const SSLCertChain& ssl_cert_chain() const; // TODO(hbos): If possible, remove once RTCCertificate and its // ssl_certificate() is used in all relevant places. Should not pass around diff --git a/rtc_base/sslidentity.cc b/rtc_base/sslidentity.cc index e035d9ec1b..1514e52be1 100644 --- a/rtc_base/sslidentity.cc +++ b/rtc_base/sslidentity.cc @@ -43,30 +43,6 @@ SSLCertificateStats::~SSLCertificateStats() { } std::unique_ptr SSLCertificate::GetStats() const { - // We have a certificate and optionally a chain of certificates. This forms a - // linked list, starting with |this|, then the first element of |chain| and - // ending with the last element of |chain|. The "issuer" of a certificate is - // the next certificate in the chain. Stats are produced for each certificate - // in the list. Here, the "issuer" is the issuer's stats. - std::unique_ptr chain = GetChain(); - std::unique_ptr issuer; - if (chain) { - // The loop runs in reverse so that the |issuer| is known before the - // |cert|'s stats. - for (ptrdiff_t i = chain->GetSize() - 1; i >= 0; --i) { - const SSLCertificate* cert = &chain->Get(i); - issuer = cert->GetStats(std::move(issuer)); - } - } - return GetStats(std::move(issuer)); -} - -std::unique_ptr SSLCertificate::GetUniqueReference() const { - return WrapUnique(GetReference()); -} - -std::unique_ptr SSLCertificate::GetStats( - std::unique_ptr issuer) const { // TODO(bemasc): Move this computation to a helper class that caches these // values to reduce CPU use in |StatsCollector::GetStats|. This will require // adding a fast |SSLCertificate::Equals| to detect certificate changes. @@ -89,11 +65,13 @@ std::unique_ptr SSLCertificate::GetStats( std::string der_base64; Base64::EncodeFromArray(der_buffer.data(), der_buffer.size(), &der_base64); - return std::unique_ptr(new SSLCertificateStats( - std::move(fingerprint), - std::move(digest_algorithm), - std::move(der_base64), - std::move(issuer))); + return rtc::MakeUnique(std::move(fingerprint), + std::move(digest_algorithm), + std::move(der_base64), nullptr); +} + +std::unique_ptr SSLCertificate::GetUniqueReference() const { + return WrapUnique(GetReference()); } KeyParams::KeyParams(KeyType key_type) { @@ -228,6 +206,28 @@ SSLCertChain* SSLCertChain::Copy() const { return new SSLCertChain(std::move(new_certs)); } +std::unique_ptr SSLCertChain::UniqueCopy() const { + return WrapUnique(Copy()); +} + +std::unique_ptr SSLCertChain::GetStats() const { + // We have a linked list of certificates, starting with the first element of + // |certs_| and ending with the last element of |certs_|. The "issuer" of a + // certificate is the next certificate in the chain. Stats are produced for + // each certificate in the list. Here, the "issuer" is the issuer's stats. + std::unique_ptr issuer; + // The loop runs in reverse so that the |issuer| is known before the + // certificate issued by |issuer|. + for (ptrdiff_t i = certs_.size() - 1; i >= 0; --i) { + std::unique_ptr new_stats = certs_[i]->GetStats(); + if (new_stats) { + new_stats->issuer = std::move(issuer); + } + issuer = std::move(new_stats); + } + return issuer; +} + // static SSLCertificate* SSLCertificate::FromPEMString(const std::string& pem_string) { return OpenSSLCertificate::FromPEMString(pem_string); diff --git a/rtc_base/sslidentity.h b/rtc_base/sslidentity.h index 952e2ab8c2..d14610b0f9 100644 --- a/rtc_base/sslidentity.h +++ b/rtc_base/sslidentity.h @@ -67,10 +67,6 @@ class SSLCertificate { std::unique_ptr GetUniqueReference() const; - // Returns null. This is deprecated. Please use - // SSLStreamAdapter::GetPeerSSLCertChain - virtual std::unique_ptr GetChain() const = 0; - // Returns a PEM encoded string representation of the certificate. virtual std::string ToPEMString() const = 0; @@ -91,14 +87,10 @@ class SSLCertificate { // or -1 if an expiration time could not be retrieved. virtual int64_t CertificateExpirationTime() const = 0; - // Gets information (fingerprint, etc.) about this certificate and its chain - // (if it has a certificate chain). This is used for certificate stats, see + // Gets information (fingerprint, etc.) about this certificate. This is used + // for certificate stats, see // https://w3c.github.io/webrtc-stats/#certificatestats-dict*. std::unique_ptr GetStats() const; - - private: - std::unique_ptr GetStats( - std::unique_ptr issuer) const; }; // SSLCertChain is a simple wrapper for a vector of SSLCertificates. It serves @@ -122,6 +114,13 @@ class SSLCertChain { // Returns a new SSLCertChain object instance wrapping the same underlying // certificate chain. Caller is responsible for freeing the returned object. SSLCertChain* Copy() const; + // Same as above, but returning a unique_ptr for convenience. + std::unique_ptr UniqueCopy() const; + + // Gets information (fingerprint, etc.) about this certificate chain. This is + // used for certificate stats, see + // https://w3c.github.io/webrtc-stats/#certificatestats-dict*. + std::unique_ptr GetStats() const; private: std::vector> certs_; @@ -241,8 +240,10 @@ class SSLIdentity { // TODO(hbos,torbjorng): Rename to a less confusing name. virtual SSLIdentity* GetReference() const = 0; - // Returns a temporary reference to the certificate. + // Returns a temporary reference to the end-entity (leaf) certificate. virtual const SSLCertificate& certificate() const = 0; + // Returns a temporary reference to the entire certificate chain. + virtual const SSLCertChain& cert_chain() const = 0; virtual std::string PrivateKeyToPEMString() const = 0; virtual std::string PublicKeyToPEMString() const = 0; diff --git a/rtc_base/sslidentity_unittest.cc b/rtc_base/sslidentity_unittest.cc index c26d8d73f9..e1dbe05858 100644 --- a/rtc_base/sslidentity_unittest.cc +++ b/rtc_base/sslidentity_unittest.cc @@ -175,8 +175,7 @@ IdentityAndInfo CreateFakeIdentityAndInfoFromDers( reinterpret_cast(der.c_str()), der.length())); } - info.identity.reset( - new rtc::FakeSSLIdentity(rtc::FakeSSLCertificate(info.pems))); + info.identity.reset(new rtc::FakeSSLIdentity(info.pems)); // Strip header/footer and newline characters of PEM strings. for (size_t i = 0; i < info.pems.size(); ++i) { rtc::replace_substrs("-----BEGIN CERTIFICATE-----", 27, @@ -186,20 +185,14 @@ IdentityAndInfo CreateFakeIdentityAndInfoFromDers( rtc::replace_substrs("\n", 1, "", 0, &info.pems[i]); } - // Fingerprint of leaf certificate. - std::unique_ptr fp( - rtc::SSLFingerprint::Create("sha-1", &info.identity->certificate())); - EXPECT_TRUE(fp); - info.fingerprints.push_back(fp->GetRfc4572Fingerprint()); - // Fingerprints of the rest of the chain. - std::unique_ptr chain = - info.identity->certificate().GetChain(); - if (chain) { - for (size_t i = 0; i < chain->GetSize(); i++) { - fp.reset(rtc::SSLFingerprint::Create("sha-1", &chain->Get(i))); - EXPECT_TRUE(fp); - info.fingerprints.push_back(fp->GetRfc4572Fingerprint()); - } + // Fingerprints for the whole certificate chain, starting with leaf + // certificate. + const rtc::SSLCertChain& chain = info.identity->cert_chain(); + std::unique_ptr fp; + for (size_t i = 0; i < chain.GetSize(); i++) { + fp.reset(rtc::SSLFingerprint::Create("sha-1", &chain.Get(i))); + EXPECT_TRUE(fp); + info.fingerprints.push_back(fp->GetRfc4572Fingerprint()); } EXPECT_EQ(info.ders.size(), info.fingerprints.size()); return info; @@ -477,7 +470,7 @@ TEST_F(SSLIdentityTest, SSLCertificateGetStatsWithChain) { EXPECT_EQ(info.fingerprints.size(), info.ders.size()); std::unique_ptr first_stats = - info.identity->certificate().GetStats(); + info.identity->cert_chain().GetStats(); rtc::SSLCertificateStats* cert_stats = first_stats.get(); for (size_t i = 0; i < info.ders.size(); ++i) { EXPECT_EQ(cert_stats->fingerprint, info.fingerprints[i]); diff --git a/rtc_base/sslstreamadapter.h b/rtc_base/sslstreamadapter.h index 560dd591c0..c04fb346c5 100644 --- a/rtc_base/sslstreamadapter.h +++ b/rtc_base/sslstreamadapter.h @@ -201,11 +201,7 @@ class SSLStreamAdapter : public StreamAdapterInterface { size_t digest_len, SSLPeerCertificateDigestError* error = nullptr) = 0; - // Retrieves the peer's X.509 certificate, if a connection has been - // established. - virtual std::unique_ptr GetPeerCertificate() const = 0; - - // Retrieves the peer's certificate chain including leaf, if a + // Retrieves the peer's certificate chain including leaf certificate, if a // connection has been established. virtual std::unique_ptr GetPeerSSLCertChain() const = 0; diff --git a/rtc_base/sslstreamadapter_unittest.cc b/rtc_base/sslstreamadapter_unittest.cc index 41ff09b62f..ce96274600 100644 --- a/rtc_base/sslstreamadapter_unittest.cc +++ b/rtc_base/sslstreamadapter_unittest.cc @@ -597,10 +597,13 @@ class SSLStreamAdapterTestBase : public testing::Test, } std::unique_ptr GetPeerCertificate(bool client) { + std::unique_ptr chain; if (client) - return client_ssl_->GetPeerCertificate(); + chain = client_ssl_->GetPeerSSLCertChain(); else - return server_ssl_->GetPeerCertificate(); + chain = server_ssl_->GetPeerSSLCertChain(); + return (chain && chain->GetSize()) ? chain->Get(0).GetUniqueReference() + : nullptr; } bool GetSslCipherSuite(bool client, int* retval) { @@ -971,11 +974,10 @@ TEST_P(SSLStreamAdapterTestTLS, GetPeerCertChainWithOneCertificate) { TestHandshake(); std::unique_ptr cert_chain = client_ssl_->GetPeerSSLCertChain(); - std::unique_ptr certificate = - client_ssl_->GetPeerCertificate(); ASSERT_NE(nullptr, cert_chain); EXPECT_EQ(1u, cert_chain->GetSize()); - EXPECT_EQ(cert_chain->Get(0).ToPEMString(), certificate->ToPEMString()); + EXPECT_EQ(cert_chain->Get(0).ToPEMString(), + server_identity_->certificate().ToPEMString()); } TEST_F(SSLStreamAdapterTestDTLSCertChain, TwoCertHandshake) { @@ -1388,9 +1390,6 @@ TEST_F(SSLStreamAdapterTestDTLSFromPEMStrings, TestDTLSGetPeerCertificate) { std::string client_peer_string = client_peer_cert->ToPEMString(); ASSERT_NE(kCERT_PEM, client_peer_string); - // It must not have a chain, because the test certs are self-signed. - ASSERT_FALSE(client_peer_cert->GetChain()); - // The server should have a peer certificate after the handshake. std::unique_ptr server_peer_cert = GetPeerCertificate(false); @@ -1398,9 +1397,6 @@ TEST_F(SSLStreamAdapterTestDTLSFromPEMStrings, TestDTLSGetPeerCertificate) { // It's kCERT_PEM ASSERT_EQ(kCERT_PEM, server_peer_cert->ToPEMString()); - - // It must not have a chain, because the test certs are self-signed. - ASSERT_FALSE(server_peer_cert->GetChain()); } // Test getting the used DTLS ciphers.