diff --git a/pc/channel.cc b/pc/channel.cc index 21e666cf7d..3950c03a10 100644 --- a/pc/channel.cc +++ b/pc/channel.cc @@ -878,13 +878,26 @@ bool BaseChannel::SetupDtlsSrtp_n(bool rtcp) { recv_key = &server_write_key; } + // Use an empty encrypted header extension ID vector if not set. This could + // happen when the DTLS handshake is completed before processing the + // Offer/Answer which contains the encrypted header extension IDs. + std::vector send_extension_ids; + std::vector recv_extension_ids; + if (catched_send_extension_ids_) { + send_extension_ids = *catched_send_extension_ids_; + } + if (catched_recv_extension_ids_) { + recv_extension_ids = *catched_recv_extension_ids_; + } + if (rtcp) { if (!dtls_active()) { RTC_DCHECK(srtp_transport_); ret = srtp_transport_->SetRtcpParams( selected_crypto_suite, &(*send_key)[0], - static_cast(send_key->size()), selected_crypto_suite, - &(*recv_key)[0], static_cast(recv_key->size())); + static_cast(send_key->size()), send_extension_ids, + selected_crypto_suite, &(*recv_key)[0], + static_cast(recv_key->size()), recv_extension_ids); } else { // RTCP doesn't need to call SetRtpParam because it is only used // to make the updated encrypted RTP header extension IDs take effect. @@ -892,10 +905,11 @@ bool BaseChannel::SetupDtlsSrtp_n(bool rtcp) { } } else { RTC_DCHECK(srtp_transport_); - ret = srtp_transport_->SetRtpParams(selected_crypto_suite, &(*send_key)[0], - static_cast(send_key->size()), - selected_crypto_suite, &(*recv_key)[0], - static_cast(recv_key->size())); + ret = srtp_transport_->SetRtpParams( + selected_crypto_suite, &(*send_key)[0], + static_cast(send_key->size()), send_extension_ids, + selected_crypto_suite, &(*recv_key)[0], + static_cast(recv_key->size()), recv_extension_ids); dtls_active_ = ret; } @@ -1043,10 +1057,11 @@ bool BaseChannel::SetSrtp_n(const std::vector& cryptos, if (!srtp_transport_ && !dtls && !cryptos.empty()) { EnableSrtpTransport_n(); } - if (srtp_transport_) { - srtp_transport_->SetEncryptedHeaderExtensionIds(src, - encrypted_extension_ids); - } + + bool encrypted_header_extensions_id_changed = + EncryptedHeaderExtensionIdsChanged(src, encrypted_extension_ids); + CacheEncryptedHeaderExtensionIds(src, encrypted_extension_ids); + switch (action) { case CA_OFFER: // If DTLS is already active on the channel, we could be renegotiating @@ -1078,13 +1093,17 @@ bool BaseChannel::SetSrtp_n(const std::vector& cryptos, if ((action == CA_PRANSWER || action == CA_ANSWER) && !dtls && ret) { if (sdes_negotiator_.send_cipher_suite() && sdes_negotiator_.recv_cipher_suite()) { + RTC_DCHECK(catched_send_extension_ids_); + RTC_DCHECK(catched_recv_extension_ids_); ret = srtp_transport_->SetRtpParams( *(sdes_negotiator_.send_cipher_suite()), sdes_negotiator_.send_key().data(), static_cast(sdes_negotiator_.send_key().size()), + *(catched_send_extension_ids_), *(sdes_negotiator_.recv_cipher_suite()), sdes_negotiator_.recv_key().data(), - static_cast(sdes_negotiator_.recv_key().size())); + static_cast(sdes_negotiator_.recv_key().size()), + *(catched_recv_extension_ids_)); } else { RTC_LOG(LS_INFO) << "No crypto keys are provided for SDES."; if (action == CA_ANSWER && srtp_transport_) { @@ -1096,16 +1115,16 @@ bool BaseChannel::SetSrtp_n(const std::vector& cryptos, } } - // Only update SRTP filter if using DTLS. SDES is handled internally + // Only update SRTP transport if using DTLS. SDES is handled internally // by the SRTP filter. - // TODO(jbauch): Only update if encrypted extension ids have changed. if (ret && dtls_active() && rtp_dtls_transport_ && - rtp_dtls_transport_->dtls_state() == DTLS_TRANSPORT_CONNECTED) { - bool rtcp = false; - ret = SetupDtlsSrtp_n(rtcp); + rtp_dtls_transport_->dtls_state() == DTLS_TRANSPORT_CONNECTED && + encrypted_header_extensions_id_changed) { + ret = SetupDtlsSrtp_n(/*rtcp=*/false); } + if (!ret) { - SafeSetError("Failed to setup SRTP filter.", error_desc); + SafeSetError("Failed to setup SRTP.", error_desc); return false; } return true; @@ -1433,6 +1452,26 @@ void BaseChannel::SignalSentPacket_w(const rtc::SentPacket& sent_packet) { SignalSentPacket(sent_packet); } +void BaseChannel::CacheEncryptedHeaderExtensionIds( + cricket::ContentSource source, + const std::vector& extension_ids) { + source == ContentSource::CS_LOCAL + ? catched_recv_extension_ids_.emplace(extension_ids) + : catched_send_extension_ids_.emplace(extension_ids); +} + +bool BaseChannel::EncryptedHeaderExtensionIdsChanged( + cricket::ContentSource source, + const std::vector& new_extension_ids) { + if (source == ContentSource::CS_LOCAL) { + return !catched_recv_extension_ids_ || + (*catched_recv_extension_ids_) != new_extension_ids; + } else { + return !catched_send_extension_ids_ || + (*catched_send_extension_ids_) != new_extension_ids; + } +} + VoiceChannel::VoiceChannel(rtc::Thread* worker_thread, rtc::Thread* network_thread, rtc::Thread* signaling_thread, diff --git a/pc/channel.h b/pc/channel.h index 5689338230..ec13f07cf3 100644 --- a/pc/channel.h +++ b/pc/channel.h @@ -368,6 +368,18 @@ class BaseChannel // Wraps the existing RtpTransport in an SrtpTransport. void EnableSrtpTransport_n(); + // Cache the encrypted header extension IDs when setting the local/remote + // description and use them later together with other crypto parameters from + // DtlsTransport. + void CacheEncryptedHeaderExtensionIds(cricket::ContentSource source, + const std::vector& extension_ids); + + // Return true if the new header extension IDs are different from the existing + // ones. + bool EncryptedHeaderExtensionIdsChanged( + cricket::ContentSource source, + const std::vector& new_extension_ids); + rtc::Thread* const worker_thread_; rtc::Thread* const network_thread_; rtc::Thread* const signaling_thread_; @@ -410,6 +422,10 @@ class BaseChannel MediaContentDirection local_content_direction_ = MD_INACTIVE; MediaContentDirection remote_content_direction_ = MD_INACTIVE; CandidatePairInterface* selected_candidate_pair_; + + // The cached encrypted header extension IDs. + rtc::Optional> catched_send_extension_ids_; + rtc::Optional> catched_recv_extension_ids_; }; // VoiceChannel is a specialization that adds support for early media, DTMF, diff --git a/pc/srtpsession.cc b/pc/srtpsession.cc index 8fe8dc0f8f..a07848d475 100644 --- a/pc/srtpsession.cc +++ b/pc/srtpsession.cc @@ -32,20 +32,32 @@ SrtpSession::~SrtpSession() { } } -bool SrtpSession::SetSend(int cs, const uint8_t* key, size_t len) { - return SetKey(ssrc_any_outbound, cs, key, len); +bool SrtpSession::SetSend(int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids) { + return SetKey(ssrc_any_outbound, cs, key, len, extension_ids); } -bool SrtpSession::UpdateSend(int cs, const uint8_t* key, size_t len) { - return UpdateKey(ssrc_any_outbound, cs, key, len); +bool SrtpSession::UpdateSend(int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids) { + return UpdateKey(ssrc_any_outbound, cs, key, len, extension_ids); } -bool SrtpSession::SetRecv(int cs, const uint8_t* key, size_t len) { - return SetKey(ssrc_any_inbound, cs, key, len); +bool SrtpSession::SetRecv(int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids) { + return SetKey(ssrc_any_inbound, cs, key, len, extension_ids); } -bool SrtpSession::UpdateRecv(int cs, const uint8_t* key, size_t len) { - return UpdateKey(ssrc_any_inbound, cs, key, len); +bool SrtpSession::UpdateRecv(int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids) { + return UpdateKey(ssrc_any_inbound, cs, key, len, extension_ids); } bool SrtpSession::ProtectRtp(void* p, int in_len, int max_len, int* out_len) { @@ -203,7 +215,11 @@ bool SrtpSession::GetSendStreamPacketIndex(void* p, return true; } -bool SrtpSession::DoSetKey(int type, int cs, const uint8_t* key, size_t len) { +bool SrtpSession::DoSetKey(int type, + int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids) { RTC_DCHECK(thread_checker_.CalledOnValidThread()); srtp_policy_t policy; @@ -262,10 +278,9 @@ bool SrtpSession::DoSetKey(int type, int cs, const uint8_t* key, size_t len) { !rtc::IsGcmCryptoSuite(cs)) { policy.rtp.auth_type = EXTERNAL_HMAC_SHA1; } - if (!encrypted_header_extension_ids_.empty()) { - policy.enc_xtn_hdr = const_cast(&encrypted_header_extension_ids_[0]); - policy.enc_xtn_hdr_count = - static_cast(encrypted_header_extension_ids_.size()); + if (!extension_ids.empty()) { + policy.enc_xtn_hdr = const_cast(&extension_ids[0]); + policy.enc_xtn_hdr_count = static_cast(extension_ids.size()); } policy.next = nullptr; @@ -291,7 +306,11 @@ bool SrtpSession::DoSetKey(int type, int cs, const uint8_t* key, size_t len) { return true; } -bool SrtpSession::SetKey(int type, int cs, const uint8_t* key, size_t len) { +bool SrtpSession::SetKey(int type, + int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids) { RTC_DCHECK(thread_checker_.CalledOnValidThread()); if (session_) { RTC_LOG(LS_ERROR) << "Failed to create SRTP session: " @@ -307,23 +326,21 @@ bool SrtpSession::SetKey(int type, int cs, const uint8_t* key, size_t len) { return false; } - return DoSetKey(type, cs, key, len); + return DoSetKey(type, cs, key, len, extension_ids); } -bool SrtpSession::UpdateKey(int type, int cs, const uint8_t* key, size_t len) { +bool SrtpSession::UpdateKey(int type, + int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids) { RTC_DCHECK(thread_checker_.CalledOnValidThread()); if (!session_) { RTC_LOG(LS_ERROR) << "Failed to update non-existing SRTP session"; return false; } - return DoSetKey(type, cs, key, len); -} - -void SrtpSession::SetEncryptedHeaderExtensionIds( - const std::vector& encrypted_header_extension_ids) { - RTC_DCHECK(thread_checker_.CalledOnValidThread()); - encrypted_header_extension_ids_ = encrypted_header_extension_ids; + return DoSetKey(type, cs, key, len, extension_ids); } int g_libsrtp_usage_count = 0; diff --git a/pc/srtpsession.h b/pc/srtpsession.h index 94702da130..a6e78fab6b 100644 --- a/pc/srtpsession.h +++ b/pc/srtpsession.h @@ -30,16 +30,25 @@ class SrtpSession { // Configures the session for sending data using the specified // cipher-suite and key. Receiving must be done by a separate session. - bool SetSend(int cs, const uint8_t* key, size_t len); - bool UpdateSend(int cs, const uint8_t* key, size_t len); + bool SetSend(int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids); + bool UpdateSend(int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids); // Configures the session for receiving data using the specified // cipher-suite and key. Sending must be done by a separate session. - bool SetRecv(int cs, const uint8_t* key, size_t len); - bool UpdateRecv(int cs, const uint8_t* key, size_t len); - - void SetEncryptedHeaderExtensionIds( - const std::vector& encrypted_header_extension_ids); + bool SetRecv(int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids); + bool UpdateRecv(int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids); // Encrypts/signs an individual RTP/RTCP packet, in-place. // If an HMAC is used, this will increase the packet size. @@ -75,12 +84,21 @@ class SrtpSession { bool IsExternalAuthActive() const; private: - bool DoSetKey(int type, int cs, const uint8_t* key, size_t len); - bool SetKey(int type, int cs, const uint8_t* key, size_t len); - bool UpdateKey(int type, int cs, const uint8_t* key, size_t len); - bool SetEncryptedHeaderExtensionIds( - int type, - const std::vector& encrypted_header_extension_ids); + bool DoSetKey(int type, + int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids); + bool SetKey(int type, + int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids); + bool UpdateKey(int type, + int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids); // Returns send stream current packet index from srtp db. bool GetSendStreamPacketIndex(void* data, int in_len, int64_t* index); @@ -104,7 +122,6 @@ class SrtpSession { int last_send_seq_num_ = -1; bool external_auth_active_ = false; bool external_auth_enabled_ = false; - std::vector encrypted_header_extension_ids_; RTC_DISALLOW_COPY_AND_ASSIGN(SrtpSession); }; diff --git a/pc/srtpsession_unittest.cc b/pc/srtpsession_unittest.cc index b89b3ad55a..dc325739e8 100644 --- a/pc/srtpsession_unittest.cc +++ b/pc/srtpsession_unittest.cc @@ -19,6 +19,8 @@ namespace rtc { +std::vector kEncryptedHeaderExtensionIds; + class SrtpSessionTest : public testing::Test { protected: virtual void SetUp() { @@ -65,28 +67,38 @@ class SrtpSessionTest : public testing::Test { // Test that we can set up the session and keys properly. TEST_F(SrtpSessionTest, TestGoodSetup) { - EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); } // Test that we can't change the keys once set. TEST_F(SrtpSessionTest, TestBadSetup) { - EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_FALSE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey2, kTestKeyLen)); - EXPECT_FALSE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey2, kTestKeyLen)); + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); + EXPECT_FALSE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey2, kTestKeyLen, + kEncryptedHeaderExtensionIds)); + EXPECT_FALSE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey2, kTestKeyLen, + kEncryptedHeaderExtensionIds)); } // Test that we fail keys of the wrong length. TEST_F(SrtpSessionTest, TestKeysTooShort) { - EXPECT_FALSE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, 1)); - EXPECT_FALSE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, 1)); + EXPECT_FALSE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, 1, + kEncryptedHeaderExtensionIds)); + EXPECT_FALSE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, 1, + kEncryptedHeaderExtensionIds)); } // Test that we can encrypt and decrypt RTP/RTCP using AES_CM_128_HMAC_SHA1_80. TEST_F(SrtpSessionTest, TestProtect_AES_CM_128_HMAC_SHA1_80) { - EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); TestProtectRtp(CS_AES_CM_128_HMAC_SHA1_80); TestProtectRtcp(CS_AES_CM_128_HMAC_SHA1_80); TestUnprotectRtp(CS_AES_CM_128_HMAC_SHA1_80); @@ -95,8 +107,10 @@ TEST_F(SrtpSessionTest, TestProtect_AES_CM_128_HMAC_SHA1_80) { // Test that we can encrypt and decrypt RTP/RTCP using AES_CM_128_HMAC_SHA1_32. TEST_F(SrtpSessionTest, TestProtect_AES_CM_128_HMAC_SHA1_32) { - EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); TestProtectRtp(CS_AES_CM_128_HMAC_SHA1_32); TestProtectRtcp(CS_AES_CM_128_HMAC_SHA1_32); TestUnprotectRtp(CS_AES_CM_128_HMAC_SHA1_32); @@ -104,7 +118,8 @@ TEST_F(SrtpSessionTest, TestProtect_AES_CM_128_HMAC_SHA1_32) { } TEST_F(SrtpSessionTest, TestGetSendStreamPacketIndex) { - EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); int64_t index; int out_len = 0; EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), @@ -117,8 +132,10 @@ TEST_F(SrtpSessionTest, TestGetSendStreamPacketIndex) { // Test that we fail to unprotect if someone tampers with the RTP/RTCP paylaods. TEST_F(SrtpSessionTest, TestTamperReject) { int out_len; - EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); TestProtectRtp(CS_AES_CM_128_HMAC_SHA1_80); TestProtectRtcp(CS_AES_CM_128_HMAC_SHA1_80); rtp_packet_[0] = 0x12; @@ -130,8 +147,10 @@ TEST_F(SrtpSessionTest, TestTamperReject) { // Test that we fail to unprotect if the payloads are not authenticated. TEST_F(SrtpSessionTest, TestUnencryptReject) { int out_len; - EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); EXPECT_FALSE(s2_.UnprotectRtp(rtp_packet_, rtp_len_, &out_len)); EXPECT_FALSE(s2_.UnprotectRtcp(rtcp_packet_, rtcp_len_, &out_len)); } @@ -139,7 +158,8 @@ TEST_F(SrtpSessionTest, TestUnencryptReject) { // Test that we fail when using buffers that are too small. TEST_F(SrtpSessionTest, TestBuffersTooSmall) { int out_len; - EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); EXPECT_FALSE(s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_) - 10, &out_len)); EXPECT_FALSE(s1_.ProtectRtcp(rtcp_packet_, rtcp_len_, @@ -153,8 +173,10 @@ TEST_F(SrtpSessionTest, TestReplay) { static const uint16_t replay_window = 1024; int out_len; - EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); // Initial sequence number. SetBE16(reinterpret_cast(rtp_packet_) + 2, seqnum_big); diff --git a/pc/srtptransport.cc b/pc/srtptransport.cc index b71276cac1..1343fd0cac 100644 --- a/pc/srtptransport.cc +++ b/pc/srtptransport.cc @@ -173,9 +173,11 @@ void SrtpTransport::OnPacketReceived(bool rtcp, bool SrtpTransport::SetRtpParams(int send_cs, const uint8_t* send_key, int send_key_len, + const std::vector& send_extension_ids, int recv_cs, const uint8_t* recv_key, - int recv_key_len) { + int recv_key_len, + const std::vector& recv_extension_ids) { // If parameters are being set for the first time, we should create new SRTP // sessions and call "SetSend/SetRecv". Otherwise we should call // "UpdateSend"/"UpdateRecv" on the existing sessions, which will internally @@ -186,21 +188,20 @@ bool SrtpTransport::SetRtpParams(int send_cs, CreateSrtpSessions(); new_sessions = true; } - send_session_->SetEncryptedHeaderExtensionIds( - send_encrypted_header_extension_ids_); bool ret = new_sessions - ? send_session_->SetSend(send_cs, send_key, send_key_len) - : send_session_->UpdateSend(send_cs, send_key, send_key_len); + ? send_session_->SetSend(send_cs, send_key, send_key_len, + send_extension_ids) + : send_session_->UpdateSend(send_cs, send_key, send_key_len, + send_extension_ids); if (!ret) { ResetParams(); return false; } - recv_session_->SetEncryptedHeaderExtensionIds( - recv_encrypted_header_extension_ids_); - ret = new_sessions - ? recv_session_->SetRecv(recv_cs, recv_key, recv_key_len) - : recv_session_->UpdateRecv(recv_cs, recv_key, recv_key_len); + ret = new_sessions ? recv_session_->SetRecv(recv_cs, recv_key, recv_key_len, + recv_extension_ids) + : recv_session_->UpdateRecv( + recv_cs, recv_key, recv_key_len, recv_extension_ids); if (!ret) { ResetParams(); return false; @@ -216,9 +217,11 @@ bool SrtpTransport::SetRtpParams(int send_cs, bool SrtpTransport::SetRtcpParams(int send_cs, const uint8_t* send_key, int send_key_len, + const std::vector& send_extension_ids, int recv_cs, const uint8_t* recv_key, - int recv_key_len) { + int recv_key_len, + const std::vector& recv_extension_ids) { // This can only be called once, but can be safely called after // SetRtpParams if (send_rtcp_session_ || recv_rtcp_session_) { @@ -227,12 +230,14 @@ bool SrtpTransport::SetRtcpParams(int send_cs, } send_rtcp_session_.reset(new cricket::SrtpSession()); - if (!send_rtcp_session_->SetSend(send_cs, send_key, send_key_len)) { + if (!send_rtcp_session_->SetSend(send_cs, send_key, send_key_len, + send_extension_ids)) { return false; } recv_rtcp_session_.reset(new cricket::SrtpSession()); - if (!recv_rtcp_session_->SetRecv(recv_cs, recv_key, recv_key_len)) { + if (!recv_rtcp_session_->SetRecv(recv_cs, recv_key, recv_key_len, + recv_extension_ids)) { return false; } @@ -255,16 +260,6 @@ void SrtpTransport::ResetParams() { RTC_LOG(LS_INFO) << "The params in SRTP transport are reset."; } -void SrtpTransport::SetEncryptedHeaderExtensionIds( - cricket::ContentSource source, - const std::vector& extension_ids) { - if (source == cricket::CS_LOCAL) { - recv_encrypted_header_extension_ids_ = extension_ids; - } else { - send_encrypted_header_extension_ids_ = extension_ids; - } -} - void SrtpTransport::CreateSrtpSessions() { send_session_.reset(new cricket::SrtpSession()); recv_session_.reset(new cricket::SrtpSession()); diff --git a/pc/srtptransport.h b/pc/srtptransport.h index 03c353c530..13abd6b47d 100644 --- a/pc/srtptransport.h +++ b/pc/srtptransport.h @@ -100,9 +100,11 @@ class SrtpTransport : public RtpTransportInternal { bool SetRtpParams(int send_cs, const uint8_t* send_key, int send_key_len, + const std::vector& send_extension_ids, int recv_cs, const uint8_t* recv_key, - int recv_key_len); + int recv_key_len, + const std::vector& recv_extension_ids); // Create new send/recv sessions and set the negotiated crypto keys for RTCP // packet encryption. The keys can either come from SDES negotiation or DTLS @@ -110,18 +112,14 @@ class SrtpTransport : public RtpTransportInternal { bool SetRtcpParams(int send_cs, const uint8_t* send_key, int send_key_len, + const std::vector& send_extension_ids, int recv_cs, const uint8_t* recv_key, - int recv_key_len); + int recv_key_len, + const std::vector& recv_extension_ids); void ResetParams(); - // Set the header extension ids that should be encrypted for the given source. - // This method doesn't immediately update the SRTP session with the new IDs, - // and you need to call SetRtpParams for that to happen. - void SetEncryptedHeaderExtensionIds(cricket::ContentSource source, - const std::vector& extension_ids); - // If external auth is enabled, SRTP will write a dummy auth tag that then // later must get replaced before the packet is sent out. Only supported for // non-GCM cipher suites and can be checked through "IsExternalAuthActive" @@ -187,8 +185,6 @@ class SrtpTransport : public RtpTransportInternal { std::unique_ptr send_rtcp_session_; std::unique_ptr recv_rtcp_session_; - std::vector send_encrypted_header_extension_ids_; - std::vector recv_encrypted_header_extension_ids_; bool external_auth_enabled_ = false; int rtp_abs_sendtime_extn_id_ = -1; diff --git a/pc/srtptransport_unittest.cc b/pc/srtptransport_unittest.cc index 35a792ddb3..3533863852 100644 --- a/pc/srtptransport_unittest.cc +++ b/pc/srtptransport_unittest.cc @@ -220,14 +220,15 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { srtp_transport1_->EnableExternalAuth(); srtp_transport2_->EnableExternalAuth(); } - EXPECT_TRUE( - srtp_transport1_->SetRtpParams(cs, key1, key1_len, cs, key2, key2_len)); - EXPECT_TRUE( - srtp_transport2_->SetRtpParams(cs, key2, key2_len, cs, key1, key1_len)); - EXPECT_TRUE(srtp_transport1_->SetRtcpParams(cs, key1, key1_len, cs, key2, - key2_len)); - EXPECT_TRUE(srtp_transport2_->SetRtcpParams(cs, key2, key2_len, cs, key1, - key1_len)); + std::vector extension_ids; + EXPECT_TRUE(srtp_transport1_->SetRtpParams( + cs, key1, key1_len, extension_ids, cs, key2, key2_len, extension_ids)); + EXPECT_TRUE(srtp_transport2_->SetRtpParams( + cs, key2, key2_len, extension_ids, cs, key1, key1_len, extension_ids)); + EXPECT_TRUE(srtp_transport1_->SetRtcpParams( + cs, key1, key1_len, extension_ids, cs, key2, key2_len, extension_ids)); + EXPECT_TRUE(srtp_transport2_->SetRtcpParams( + cs, key2, key2_len, extension_ids, cs, key1, key1_len, extension_ids)); EXPECT_TRUE(srtp_transport1_->IsActive()); EXPECT_TRUE(srtp_transport2_->IsActive()); if (rtc::IsGcmCryptoSuite(cs)) { @@ -308,18 +309,12 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { encrypted_headers.push_back(4); EXPECT_EQ(key1_len, key2_len); EXPECT_EQ(cs_name, rtc::SrtpCryptoSuiteToName(cs)); - srtp_transport1_->SetEncryptedHeaderExtensionIds(cricket::CS_LOCAL, - encrypted_headers); - srtp_transport1_->SetEncryptedHeaderExtensionIds(cricket::CS_REMOTE, - encrypted_headers); - srtp_transport2_->SetEncryptedHeaderExtensionIds(cricket::CS_LOCAL, - encrypted_headers); - srtp_transport2_->SetEncryptedHeaderExtensionIds(cricket::CS_REMOTE, - encrypted_headers); - EXPECT_TRUE( - srtp_transport1_->SetRtpParams(cs, key1, key1_len, cs, key2, key2_len)); - EXPECT_TRUE( - srtp_transport2_->SetRtpParams(cs, key2, key2_len, cs, key1, key1_len)); + EXPECT_TRUE(srtp_transport1_->SetRtpParams(cs, key1, key1_len, + encrypted_headers, cs, key2, + key2_len, encrypted_headers)); + EXPECT_TRUE(srtp_transport2_->SetRtpParams(cs, key2, key2_len, + encrypted_headers, cs, key1, + key1_len, encrypted_headers)); EXPECT_TRUE(srtp_transport1_->IsActive()); EXPECT_TRUE(srtp_transport2_->IsActive()); EXPECT_FALSE(srtp_transport1_->IsExternalAuthActive()); @@ -409,12 +404,13 @@ INSTANTIATE_TEST_CASE_P(ExternalAuth, // Test directly setting the params with bogus keys. TEST_F(SrtpTransportTest, TestSetParamsKeyTooShort) { + std::vector extension_ids; EXPECT_FALSE(srtp_transport1_->SetRtpParams( - rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1, - rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1)); + rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1, extension_ids, + rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1, extension_ids)); EXPECT_FALSE(srtp_transport1_->SetRtcpParams( - rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1, - rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1)); + rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1, extension_ids, + rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1, extension_ids)); } } // namespace webrtc