diff --git a/rtc_base/ssladapter_unittest.cc b/rtc_base/ssladapter_unittest.cc index c15ecfe3e5..1bb979ec90 100644 --- a/rtc_base/ssladapter_unittest.cc +++ b/rtc_base/ssladapter_unittest.cc @@ -10,9 +10,11 @@ #include #include +#include #include "rtc_base/gunit.h" #include "rtc_base/ipaddress.h" +#include "rtc_base/ptr_util.h" #include "rtc_base/socketstream.h" #include "rtc_base/ssladapter.h" #include "rtc_base/sslidentity.h" @@ -20,6 +22,10 @@ #include "rtc_base/stream.h" #include "rtc_base/stringencode.h" #include "rtc_base/virtualsocketserver.h" +#include "test/gmock.h" + +using ::testing::_; +using ::testing::Return; static const int kTimeout = 5000; @@ -39,6 +45,15 @@ static std::string GetSSLProtocolName(const rtc::SSLMode& ssl_mode) { return (ssl_mode == rtc::SSL_MODE_DTLS) ? "DTLS" : "TLS"; } +// Simple mock for the certificate verifier. +class MockCertVerifier : public rtc::SSLCertificateVerifier { + public: + virtual ~MockCertVerifier() = default; + MOCK_METHOD1(Verify, bool(const rtc::SSLCertificate&)); +}; + +// TODO(benwright) - Move to using INSTANTIATE_TEST_CASE_P instead of using +// duplicate test cases for simple parameter changes. class SSLAdapterTestDummyClient : public sigslot::has_slots<> { public: explicit SSLAdapterTestDummyClient(const rtc::SSLMode& ssl_mode) @@ -60,6 +75,14 @@ class SSLAdapterTestDummyClient : public sigslot::has_slots<> { &SSLAdapterTestDummyClient::OnSSLAdapterCloseEvent); } + void SetIgnoreBadCert(bool ignore_bad_cert) { + ssl_adapter_->SetIgnoreBadCert(ignore_bad_cert); + } + + void SetCertVerifier(rtc::SSLCertificateVerifier* ssl_cert_verifier) { + ssl_adapter_->SetCertVerifier(ssl_cert_verifier); + } + void SetAlpnProtocols(const std::vector& protos) { ssl_adapter_->SetAlpnProtocols(protos); } @@ -291,6 +314,14 @@ class SSLAdapterTestBase : public testing::Test, handshake_wait_ = wait; } + void SetIgnoreBadCert(bool ignore_bad_cert) { + client_->SetIgnoreBadCert(ignore_bad_cert); + } + + void SetCertVerifier(rtc::SSLCertificateVerifier* ssl_cert_verifier) { + client_->SetCertVerifier(ssl_cert_verifier); + } + void SetAlpnProtocols(const std::vector& protos) { client_->SetAlpnProtocols(protos); } @@ -299,6 +330,16 @@ class SSLAdapterTestBase : public testing::Test, client_->SetEllipticCurves(curves); } + void SetMockCertVerifier(bool return_value) { + auto mock_verifier = rtc::MakeUnique(); + EXPECT_CALL(*mock_verifier, Verify(_)).WillRepeatedly(Return(return_value)); + cert_verifier_ = + std::unique_ptr(std::move(mock_verifier)); + + SetIgnoreBadCert(false); + SetCertVerifier(cert_verifier_.get()); + } + void TestHandshake(bool expect_success) { int rv; @@ -359,6 +400,7 @@ class SSLAdapterTestBase : public testing::Test, rtc::AutoSocketServerThread thread_; std::unique_ptr server_; std::unique_ptr client_; + std::unique_ptr cert_verifier_; int handshake_wait_; }; @@ -394,9 +436,34 @@ TEST_F(SSLAdapterTestTLS_RSA, TestTLSConnect) { TestHandshake(true); } +// Test that handshake works with a custom verifier that returns true. RSA. +TEST_F(SSLAdapterTestTLS_RSA, TestTLSConnectCustomCertVerifierSucceeds) { + SetMockCertVerifier(/*return_value=*/true); + TestHandshake(/*expect_success=*/true); +} + +// Test that handshake fails with a custom verifier that returns false. RSA. +TEST_F(SSLAdapterTestTLS_RSA, TestTLSConnectCustomCertVerifierFails) { + SetMockCertVerifier(/*return_value=*/false); + TestHandshake(/*expect_success=*/false); +} + // Test that handshake works, using ECDSA TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSConnect) { - TestHandshake(true); + SetMockCertVerifier(/*return_value=*/true); + TestHandshake(/*expect_success=*/true); +} + +// Test that handshake works with a custom verifier that returns true. ECDSA. +TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSConnectCustomCertVerifierSucceeds) { + SetMockCertVerifier(/*return_value=*/true); + TestHandshake(/*expect_success=*/true); +} + +// Test that handshake fails with a custom verifier that returns false. ECDSA. +TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSConnectCustomCertVerifierFails) { + SetMockCertVerifier(/*return_value=*/false); + TestHandshake(/*expect_success=*/false); } // Test transfer between client and server, using RSA @@ -405,6 +472,13 @@ TEST_F(SSLAdapterTestTLS_RSA, TestTLSTransfer) { TestTransfer("Hello, world!"); } +// Test transfer between client and server, using RSA with custom cert verifier. +TEST_F(SSLAdapterTestTLS_RSA, TestTLSTransferCustomCertVerifier) { + SetMockCertVerifier(/*return_value=*/true); + TestHandshake(/*expect_success=*/true); + TestTransfer("Hello, world!"); +} + TEST_F(SSLAdapterTestTLS_RSA, TestTLSTransferWithBlockedSocket) { TestHandshake(true); @@ -452,6 +526,14 @@ TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSTransfer) { TestTransfer("Hello, world!"); } +// Test transfer between client and server, using ECDSA with custom cert +// verifier. +TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSTransferCustomCertVerifier) { + SetMockCertVerifier(/*return_value=*/true); + TestHandshake(/*expect_success=*/true); + TestTransfer("Hello, world!"); +} + // Test transfer using ALPN with protos as h2 and http/1.1 TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSALPN) { std::vector alpn_protos{"h2", "http/1.1"}; @@ -475,19 +557,61 @@ TEST_F(SSLAdapterTestDTLS_RSA, TestDTLSConnect) { TestHandshake(true); } +// Test that handshake works with a custom verifier that returns true. DTLS_RSA. +TEST_F(SSLAdapterTestDTLS_RSA, TestDTLSConnectCustomCertVerifierSucceeds) { + SetMockCertVerifier(/*return_value=*/true); + TestHandshake(/*expect_success=*/true); +} + +// Test that handshake fails with a custom verifier that returns false. +// DTLS_RSA. +TEST_F(SSLAdapterTestDTLS_RSA, TestTLSConnectCustomCertVerifierFails) { + SetMockCertVerifier(/*return_value=*/false); + TestHandshake(/*expect_success=*/false); +} + // Test that handshake works, using ECDSA TEST_F(SSLAdapterTestDTLS_ECDSA, TestDTLSConnect) { TestHandshake(true); } +// Test that handshake works with a custom verifier that returns true. +// DTLS_ECDSA. +TEST_F(SSLAdapterTestDTLS_ECDSA, TestDTLSConnectCustomCertVerifierSucceeds) { + SetMockCertVerifier(/*return_value=*/true); + TestHandshake(/*expect_success=*/true); +} + +// Test that handshake fails with a custom verifier that returns false. +// DTLS_ECDSA. +TEST_F(SSLAdapterTestDTLS_ECDSA, TestTLSConnectCustomCertVerifierFails) { + SetMockCertVerifier(/*return_value=*/false); + TestHandshake(/*expect_success=*/false); +} + // Test transfer between client and server, using RSA TEST_F(SSLAdapterTestDTLS_RSA, TestDTLSTransfer) { TestHandshake(true); TestTransfer("Hello, world!"); } +// Test transfer between client and server, using RSA with custom cert verifier. +TEST_F(SSLAdapterTestDTLS_RSA, TestDTLSTransferCustomCertVerifier) { + SetMockCertVerifier(/*return_value=*/true); + TestHandshake(/*expect_success=*/true); + TestTransfer("Hello, world!"); +} + // Test transfer between client and server, using ECDSA TEST_F(SSLAdapterTestDTLS_ECDSA, TestDTLSTransfer) { TestHandshake(true); TestTransfer("Hello, world!"); } + +// Test transfer between client and server, using ECDSA with custom cert +// verifier. +TEST_F(SSLAdapterTestDTLS_ECDSA, TestDTLSTransferCustomCertVerifier) { + SetMockCertVerifier(/*return_value=*/true); + TestHandshake(/*expect_success=*/true); + TestTransfer("Hello, world!"); +}