Revert of Completed the functionalities of SrtpTransport. (patchset 7 id:320001 of https://codereview.webrtc.org/2997983002/ )

Reason for revert:
This seems to be causing some video freezes. See https://bugs.chromium.org/p/webrtc/issues/detail?id=8251

Original issue's description:
> Completed the functionalities of SrtpTransport.
>
> The SrtpTransport takes the SRTP responsibilities from the BaseChannel
> and SrtpFilter. SrtpTransport is now responsible for setting the crypto
> keys, protecting and unprotecting the packets. SrtpTransport doesn't know
> if the keys are from SDES or DTLS handshake.
>
> BaseChannel is now only responsible setting the offer/answer for SDES
> or extracting the key from DtlsTransport and configuring the
> SrtpTransport.
>
> SrtpFilter is used by BaseChannel as a helper for SDES negotiation.
>
> BUG=webrtc:7013
>
> Review-Url: https://codereview.webrtc.org/2997983002
> Cr-Commit-Position: refs/heads/master@{#19636}
> Committed: e683c6871f

TBR=deadbeef@webrtc.org,pthatcher@google.com,zhihuang@webrtc.org
Not skipping CQ checks because original CL landed more than 1 days ago.
BUG=webrtc:7013

Review-Url: https://codereview.webrtc.org/3018513002
Cr-Commit-Position: refs/heads/master@{#19895}
This commit is contained in:
zhihuang
2017-09-19 01:12:52 -07:00
committed by Commit Bot
parent 2572404789
commit eb23e17798
13 changed files with 905 additions and 1094 deletions

View File

@ -86,8 +86,6 @@ class FakePacketTransport : public PacketTransportInternal {
bool GetOption(Socket::Option opt, int* value) override { return true; } bool GetOption(Socket::Option opt, int* value) override { return true; }
int GetError() override { return 0; } int GetError() override { return 0; }
const CopyOnWriteBuffer* last_sent_packet() { return &last_sent_packet_; }
private: private:
void set_writable(bool writable) { void set_writable(bool writable) {
if (writable_ == writable) { if (writable_ == writable) {
@ -109,14 +107,12 @@ class FakePacketTransport : public PacketTransportInternal {
} }
void SendPacketInternal(const CopyOnWriteBuffer& packet) { void SendPacketInternal(const CopyOnWriteBuffer& packet) {
last_sent_packet_ = packet;
if (dest_) { if (dest_) {
dest_->SignalReadPacket(dest_, packet.data<char>(), packet.size(), dest_->SignalReadPacket(dest_, packet.data<char>(), packet.size(),
CreatePacketTime(0), 0); CreatePacketTime(0), 0);
} }
} }
CopyOnWriteBuffer last_sent_packet_;
AsyncInvoker invoker_; AsyncInvoker invoker_;
std::string debug_name_; std::string debug_name_;
FakePacketTransport* dest_ = nullptr; FakePacketTransport* dest_ = nullptr;

View File

@ -151,22 +151,18 @@ BaseChannel::BaseChannel(rtc::Thread* worker_thread,
signaling_thread_(signaling_thread), signaling_thread_(signaling_thread),
content_name_(content_name), content_name_(content_name),
rtcp_mux_required_(rtcp_mux_required), rtcp_mux_required_(rtcp_mux_required),
rtp_transport_(
srtp_required
? rtc::WrapUnique<webrtc::RtpTransportInternal>(
new webrtc::SrtpTransport(rtcp_mux_required, content_name))
: rtc::MakeUnique<webrtc::RtpTransport>(rtcp_mux_required)),
srtp_required_(srtp_required), srtp_required_(srtp_required),
media_channel_(media_channel), media_channel_(media_channel),
selected_candidate_pair_(nullptr) { selected_candidate_pair_(nullptr) {
RTC_DCHECK(worker_thread_ == rtc::Thread::Current()); RTC_DCHECK(worker_thread_ == rtc::Thread::Current());
if (srtp_required) {
auto transport =
rtc::MakeUnique<webrtc::SrtpTransport>(rtcp_mux_required, content_name);
srtp_transport_ = transport.get();
rtp_transport_ = std::move(transport);
#if defined(ENABLE_EXTERNAL_AUTH) #if defined(ENABLE_EXTERNAL_AUTH)
srtp_transport_->EnableExternalAuth(); srtp_filter_.EnableExternalAuth();
#endif #endif
} else {
rtp_transport_ = rtc::MakeUnique<webrtc::RtpTransport>(rtcp_mux_required);
srtp_transport_ = nullptr;
}
rtp_transport_->SignalReadyToSend.connect( rtp_transport_->SignalReadyToSend.connect(
this, &BaseChannel::OnTransportReadyToSend); this, &BaseChannel::OnTransportReadyToSend);
// TODO(zstein): RtpTransport::SignalPacketReceived will probably be replaced // TODO(zstein): RtpTransport::SignalPacketReceived will probably be replaced
@ -311,17 +307,14 @@ void BaseChannel::SetTransports_n(
return; return;
} }
// When using DTLS-SRTP, we must reset the SrtpTransport every time the // When using DTLS-SRTP, we must reset the SrtpFilter every time the transport
// DtlsTransport changes and wait until the DTLS handshake is complete to set // changes and wait until the DTLS handshake is complete to set the newly
// the newly negotiated parameters. // negotiated parameters.
if (ShouldSetupDtlsSrtp_n()) { if (ShouldSetupDtlsSrtp_n()) {
// Set |writable_| to false such that UpdateWritableState_w can set up // Set |writable_| to false such that UpdateWritableState_w can set up
// DTLS-SRTP when |writable_| becomes true again. // DTLS-SRTP when |writable_| becomes true again.
writable_ = false; writable_ = false;
dtls_active_ = false; srtp_filter_.ResetParams();
if (srtp_transport_) {
srtp_transport_->ResetParams();
}
} }
// If this BaseChannel doesn't require RTCP mux and we haven't fully // If this BaseChannel doesn't require RTCP mux and we haven't fully
@ -377,8 +370,8 @@ void BaseChannel::SetTransport_n(
} }
if (rtcp && new_dtls_transport) { if (rtcp && new_dtls_transport) {
RTC_CHECK(!(ShouldSetupDtlsSrtp_n() && srtp_active())) RTC_CHECK(!(ShouldSetupDtlsSrtp_n() && srtp_filter_.IsActive()))
<< "Setting RTCP for DTLS/SRTP after the DTLS is active " << "Setting RTCP for DTLS/SRTP after SrtpFilter is active "
<< "should never happen."; << "should never happen.";
} }
@ -529,7 +522,8 @@ bool BaseChannel::IsReadyToSendMedia_n() const {
// and we have had some form of connectivity. // and we have had some form of connectivity.
return enabled() && IsReceiveContentDirection(remote_content_direction_) && return enabled() && IsReceiveContentDirection(remote_content_direction_) &&
IsSendContentDirection(local_content_direction_) && IsSendContentDirection(local_content_direction_) &&
was_ever_writable() && (srtp_active() || !ShouldSetupDtlsSrtp_n()); was_ever_writable() &&
(srtp_filter_.IsActive() || !ShouldSetupDtlsSrtp_n());
} }
bool BaseChannel::SendPacket(rtc::CopyOnWriteBuffer* packet, bool BaseChannel::SendPacket(rtc::CopyOnWriteBuffer* packet,
@ -581,16 +575,13 @@ void BaseChannel::OnDtlsState(DtlsTransportInternal* transport,
return; return;
} }
// Reset the SrtpTransport if it's not the CONNECTED state. For the CONNECTED // Reset the srtp filter if it's not the CONNECTED state. For the CONNECTED
// state, setting up DTLS-SRTP context is deferred to ChannelWritable_w to // state, setting up DTLS-SRTP context is deferred to ChannelWritable_w to
// cover other scenarios like the whole transport is writable (not just this // cover other scenarios like the whole transport is writable (not just this
// TransportChannel) or when TransportChannel is attached after DTLS is // TransportChannel) or when TransportChannel is attached after DTLS is
// negotiated. // negotiated.
if (state != DTLS_TRANSPORT_CONNECTED) { if (state != DTLS_TRANSPORT_CONNECTED) {
dtls_active_ = false; srtp_filter_.ResetParams();
if (srtp_transport_) {
srtp_transport_->ResetParams();
}
} }
} }
@ -664,8 +655,73 @@ bool BaseChannel::SendPacket(bool rtcp,
return false; return false;
} }
if (!srtp_active()) { rtc::PacketOptions updated_options;
if (srtp_required_) { updated_options = options;
// Protect if needed.
if (srtp_filter_.IsActive()) {
TRACE_EVENT0("webrtc", "SRTP Encode");
bool res;
uint8_t* data = packet->data();
int len = static_cast<int>(packet->size());
if (!rtcp) {
// If ENABLE_EXTERNAL_AUTH flag is on then packet authentication is not done
// inside libsrtp for a RTP packet. A external HMAC module will be writing
// a fake HMAC value. This is ONLY done for a RTP packet.
// Socket layer will update rtp sendtime extension header if present in
// packet with current time before updating the HMAC.
#if !defined(ENABLE_EXTERNAL_AUTH)
res = srtp_filter_.ProtectRtp(data, len,
static_cast<int>(packet->capacity()), &len);
#else
if (!srtp_filter_.IsExternalAuthActive()) {
res = srtp_filter_.ProtectRtp(
data, len, static_cast<int>(packet->capacity()), &len);
} else {
updated_options.packet_time_params.rtp_sendtime_extension_id =
rtp_abs_sendtime_extn_id_;
res = srtp_filter_.ProtectRtp(
data, len, static_cast<int>(packet->capacity()), &len,
&updated_options.packet_time_params.srtp_packet_index);
// If protection succeeds, let's get auth params from srtp.
if (res) {
uint8_t* auth_key = NULL;
int key_len;
res = srtp_filter_.GetRtpAuthParams(
&auth_key, &key_len,
&updated_options.packet_time_params.srtp_auth_tag_len);
if (res) {
updated_options.packet_time_params.srtp_auth_key.resize(key_len);
updated_options.packet_time_params.srtp_auth_key.assign(
auth_key, auth_key + key_len);
}
}
}
#endif
if (!res) {
int seq_num = -1;
uint32_t ssrc = 0;
GetRtpSeqNum(data, len, &seq_num);
GetRtpSsrc(data, len, &ssrc);
LOG(LS_ERROR) << "Failed to protect " << content_name_
<< " RTP packet: size=" << len << ", seqnum=" << seq_num
<< ", SSRC=" << ssrc;
return false;
}
} else {
res = srtp_filter_.ProtectRtcp(
data, len, static_cast<int>(packet->capacity()), &len);
if (!res) {
int type = -1;
GetRtcpType(data, len, &type);
LOG(LS_ERROR) << "Failed to protect " << content_name_
<< " RTCP packet: size=" << len << ", type=" << type;
return false;
}
}
// Update the length of the packet now that we've added the auth tag.
packet->SetSize(len);
} else if (srtp_required_) {
// The audio/video engines may attempt to send RTCP packets as soon as the // The audio/video engines may attempt to send RTCP packets as soon as the
// streams are created, so don't treat this as an error for RTCP. // streams are created, so don't treat this as an error for RTCP.
// See: https://bugs.chromium.org/p/webrtc/issues/detail?id=6809 // See: https://bugs.chromium.org/p/webrtc/issues/detail?id=6809
@ -679,15 +735,10 @@ bool BaseChannel::SendPacket(bool rtcp,
RTC_NOTREACHED(); RTC_NOTREACHED();
return false; return false;
} }
// Bon voyage. // Bon voyage.
return rtcp ? rtp_transport_->SendRtcpPacket(packet, options, PF_NORMAL) int flags = (secure() && secure_dtls()) ? PF_SRTP_BYPASS : PF_NORMAL;
: rtp_transport_->SendRtpPacket(packet, options, PF_NORMAL); return rtp_transport_->SendPacket(rtcp, packet, updated_options, flags);
}
RTC_DCHECK(srtp_transport_);
RTC_DCHECK(srtp_transport_->IsActive());
// Bon voyage.
return rtcp ? srtp_transport_->SendRtcpPacket(packet, options, PF_SRTP_BYPASS)
: srtp_transport_->SendRtpPacket(packet, options, PF_SRTP_BYPASS);
} }
bool BaseChannel::HandlesPayloadType(int packet_type) const { bool BaseChannel::HandlesPayloadType(int packet_type) const {
@ -702,7 +753,37 @@ void BaseChannel::OnPacketReceived(bool rtcp,
signaling_thread()->Post(RTC_FROM_HERE, this, MSG_FIRSTPACKETRECEIVED); signaling_thread()->Post(RTC_FROM_HERE, this, MSG_FIRSTPACKETRECEIVED);
} }
if (!srtp_active() && srtp_required_) { // Unprotect the packet, if needed.
if (srtp_filter_.IsActive()) {
TRACE_EVENT0("webrtc", "SRTP Decode");
char* data = packet->data<char>();
int len = static_cast<int>(packet->size());
bool res;
if (!rtcp) {
res = srtp_filter_.UnprotectRtp(data, len, &len);
if (!res) {
int seq_num = -1;
uint32_t ssrc = 0;
GetRtpSeqNum(data, len, &seq_num);
GetRtpSsrc(data, len, &ssrc);
LOG(LS_ERROR) << "Failed to unprotect " << content_name_
<< " RTP packet: size=" << len << ", seqnum=" << seq_num
<< ", SSRC=" << ssrc;
return;
}
} else {
res = srtp_filter_.UnprotectRtcp(data, len, &len);
if (!res) {
int type = -1;
GetRtcpType(data, len, &type);
LOG(LS_ERROR) << "Failed to unprotect " << content_name_
<< " RTCP packet: size=" << len << ", type=" << type;
return;
}
}
packet->SetSize(len);
} else if (srtp_required_) {
// Our session description indicates that SRTP is required, but we got a // Our session description indicates that SRTP is required, but we got a
// packet before our SRTP filter is active. This means either that // packet before our SRTP filter is active. This means either that
// a) we got SRTP packets before we received the SDES keys, in which case // a) we got SRTP packets before we received the SDES keys, in which case
@ -878,37 +959,42 @@ bool BaseChannel::SetupDtlsSrtp_n(bool rtcp) {
recv_key = &server_write_key; recv_key = &server_write_key;
} }
if (!srtp_filter_.IsActive()) {
if (rtcp) { if (rtcp) {
if (!dtls_active()) { ret = srtp_filter_.SetRtcpParams(selected_crypto_suite, &(*send_key)[0],
RTC_DCHECK(srtp_transport_);
ret = srtp_transport_->SetRtcpParams(
selected_crypto_suite, &(*send_key)[0],
static_cast<int>(send_key->size()), selected_crypto_suite,
&(*recv_key)[0], static_cast<int>(recv_key->size()));
} else {
// RTCP doesn't need to call SetRtpParam because it is only used
// to make the updated encrypted RTP header extension IDs take effect.
ret = true;
}
} else {
RTC_DCHECK(srtp_transport_);
ret = srtp_transport_->SetRtpParams(selected_crypto_suite, &(*send_key)[0],
static_cast<int>(send_key->size()), static_cast<int>(send_key->size()),
selected_crypto_suite, &(*recv_key)[0], selected_crypto_suite, &(*recv_key)[0],
static_cast<int>(recv_key->size())); static_cast<int>(recv_key->size()));
dtls_active_ = ret; } else {
ret = srtp_filter_.SetRtpParams(selected_crypto_suite, &(*send_key)[0],
static_cast<int>(send_key->size()),
selected_crypto_suite, &(*recv_key)[0],
static_cast<int>(recv_key->size()));
}
} else {
if (rtcp) {
// RTCP doesn't need to be updated because UpdateRtpParams is only used
// to update the set of encrypted RTP header extension IDs.
ret = true;
} else {
ret = srtp_filter_.UpdateRtpParams(selected_crypto_suite, &(*send_key)[0],
static_cast<int>(send_key->size()),
selected_crypto_suite, &(*recv_key)[0],
static_cast<int>(recv_key->size()));
}
} }
if (!ret) { if (!ret) {
LOG(LS_WARNING) << "DTLS-SRTP key installation failed"; LOG(LS_WARNING) << "DTLS-SRTP key installation failed";
} else { } else {
dtls_keyed_ = true;
UpdateTransportOverhead(); UpdateTransportOverhead();
} }
return ret; return ret;
} }
void BaseChannel::MaybeSetupDtlsSrtp_n() { void BaseChannel::MaybeSetupDtlsSrtp_n() {
if (dtls_active()) { if (srtp_filter_.IsActive()) {
return; return;
} }
@ -916,10 +1002,6 @@ void BaseChannel::MaybeSetupDtlsSrtp_n() {
return; return;
} }
if (!srtp_transport_) {
EnableSrtpTransport_n();
}
if (!SetupDtlsSrtp_n(false)) { if (!SetupDtlsSrtp_n(false)) {
SignalDtlsSrtpSetupFailure_n(false); SignalDtlsSrtpSetupFailure_n(false);
return; return;
@ -1003,24 +1085,6 @@ bool BaseChannel::CheckSrtpConfig_n(const std::vector<CryptoParams>& cryptos,
return true; return true;
} }
void BaseChannel::EnableSrtpTransport_n() {
if (srtp_transport_ == nullptr) {
rtp_transport_->SignalReadyToSend.disconnect(this);
rtp_transport_->SignalPacketReceived.disconnect(this);
auto transport = rtc::MakeUnique<webrtc::SrtpTransport>(
std::move(rtp_transport_), content_name_);
srtp_transport_ = transport.get();
rtp_transport_ = std::move(transport);
rtp_transport_->SignalReadyToSend.connect(
this, &BaseChannel::OnTransportReadyToSend);
rtp_transport_->SignalPacketReceived.connect(
this, &BaseChannel::OnPacketReceived);
LOG(LS_INFO) << "Wrapping RtpTransport in SrtpTransport.";
}
}
bool BaseChannel::SetSrtp_n(const std::vector<CryptoParams>& cryptos, bool BaseChannel::SetSrtp_n(const std::vector<CryptoParams>& cryptos,
ContentAction action, ContentAction action,
ContentSource src, ContentSource src,
@ -1037,69 +1101,36 @@ bool BaseChannel::SetSrtp_n(const std::vector<CryptoParams>& cryptos,
if (!ret) { if (!ret) {
return false; return false;
} }
srtp_filter_.SetEncryptedHeaderExtensionIds(src, encrypted_extension_ids);
// If SRTP was not required, but we're setting a description that uses SDES,
// we need to upgrade to an SrtpTransport.
if (!srtp_transport_ && !dtls && !cryptos.empty()) {
EnableSrtpTransport_n();
}
if (srtp_transport_) {
srtp_transport_->SetEncryptedHeaderExtensionIds(src,
encrypted_extension_ids);
}
switch (action) { switch (action) {
case CA_OFFER: case CA_OFFER:
// If DTLS is already active on the channel, we could be renegotiating // If DTLS is already active on the channel, we could be renegotiating
// here. We don't update the srtp filter. // here. We don't update the srtp filter.
if (!dtls) { if (!dtls) {
ret = sdes_negotiator_.SetOffer(cryptos, src); ret = srtp_filter_.SetOffer(cryptos, src);
} }
break; break;
case CA_PRANSWER: case CA_PRANSWER:
// If we're doing DTLS-SRTP, we don't want to update the filter // If we're doing DTLS-SRTP, we don't want to update the filter
// with an answer, because we already have SRTP parameters. // with an answer, because we already have SRTP parameters.
if (!dtls) { if (!dtls) {
ret = sdes_negotiator_.SetProvisionalAnswer(cryptos, src); ret = srtp_filter_.SetProvisionalAnswer(cryptos, src);
} }
break; break;
case CA_ANSWER: case CA_ANSWER:
// If we're doing DTLS-SRTP, we don't want to update the filter // If we're doing DTLS-SRTP, we don't want to update the filter
// with an answer, because we already have SRTP parameters. // with an answer, because we already have SRTP parameters.
if (!dtls) { if (!dtls) {
ret = sdes_negotiator_.SetAnswer(cryptos, src); ret = srtp_filter_.SetAnswer(cryptos, src);
} }
break; break;
default: default:
break; break;
} }
// If setting an SDES answer succeeded, apply the negotiated parameters
// to the SRTP transport.
if ((action == CA_PRANSWER || action == CA_ANSWER) && !dtls && ret) {
if (sdes_negotiator_.send_cipher_suite() &&
sdes_negotiator_.recv_cipher_suite()) {
ret = srtp_transport_->SetRtpParams(
*(sdes_negotiator_.send_cipher_suite()),
sdes_negotiator_.send_key().data(),
static_cast<int>(sdes_negotiator_.send_key().size()),
*(sdes_negotiator_.recv_cipher_suite()),
sdes_negotiator_.recv_key().data(),
static_cast<int>(sdes_negotiator_.recv_key().size()));
} else {
LOG(LS_INFO) << "No crypto keys are provided for SDES.";
if (action == CA_ANSWER && srtp_transport_) {
// Explicitly reset the |srtp_transport_| if no crypto param is
// provided in the answer. No need to call |ResetParams()| for
// |sdes_negotiator_| because it resets the params inside |SetAnswer|.
srtp_transport_->ResetParams();
}
}
}
// Only update SRTP filter if using DTLS. SDES is handled internally // Only update SRTP filter if using DTLS. SDES is handled internally
// by the SRTP filter. // by the SRTP filter.
// TODO(jbauch): Only update if encrypted extension ids have changed. // TODO(jbauch): Only update if encrypted extension ids have changed.
if (ret && dtls_active() && rtp_dtls_transport_ && if (ret && dtls_keyed_ && rtp_dtls_transport_ &&
rtp_dtls_transport_->dtls_state() == DTLS_TRANSPORT_CONNECTED) { rtp_dtls_transport_->dtls_state() == DTLS_TRANSPORT_CONNECTED) {
bool rtcp = false; bool rtcp = false;
ret = SetupDtlsSrtp_n(rtcp); ret = SetupDtlsSrtp_n(rtcp);
@ -1143,6 +1174,7 @@ bool BaseChannel::SetRtcpMux_n(bool enable,
transport_name_.empty() transport_name_.empty()
? rtp_transport_->rtp_packet_transport()->debug_name() ? rtp_transport_->rtp_packet_transport()->debug_name()
: transport_name_; : transport_name_;
;
LOG(LS_INFO) << "Enabling rtcp-mux for " << content_name() LOG(LS_INFO) << "Enabling rtcp-mux for " << content_name()
<< "; no longer need RTCP transport for " << debug_name; << "; no longer need RTCP transport for " << debug_name;
if (rtp_transport_->rtcp_packet_transport()) { if (rtp_transport_->rtcp_packet_transport()) {
@ -1371,13 +1403,7 @@ void BaseChannel::MaybeCacheRtpAbsSendTimeHeaderExtension_w(
void BaseChannel::CacheRtpAbsSendTimeHeaderExtension_n( void BaseChannel::CacheRtpAbsSendTimeHeaderExtension_n(
int rtp_abs_sendtime_extn_id) { int rtp_abs_sendtime_extn_id) {
if (srtp_transport_) { rtp_abs_sendtime_extn_id_ = rtp_abs_sendtime_extn_id;
srtp_transport_->CacheRtpAbsSendTimeHeaderExtension(
rtp_abs_sendtime_extn_id);
} else {
LOG(LS_WARNING) << "Trying to cache the Absolute Send Time extension id "
"but the SRTP is not active.";
}
} }
void BaseChannel::OnMessage(rtc::Message *pmsg) { void BaseChannel::OnMessage(rtc::Message *pmsg) {
@ -1661,9 +1687,9 @@ int BaseChannel::GetTransportOverheadPerPacket() const {
? kTcpOverhaed ? kTcpOverhaed
: kUdpOverhaed; : kUdpOverhaed;
if (sdes_active()) { if (secure()) {
int srtp_overhead = 0; int srtp_overhead = 0;
if (srtp_transport_->GetSrtpOverhead(&srtp_overhead)) if (srtp_filter_.GetSrtpOverhead(&srtp_overhead))
transport_overhead_per_packet += srtp_overhead; transport_overhead_per_packet += srtp_overhead;
} }

View File

@ -33,6 +33,7 @@
#include "pc/mediamonitor.h" #include "pc/mediamonitor.h"
#include "pc/mediasession.h" #include "pc/mediasession.h"
#include "pc/rtcpmuxfilter.h" #include "pc/rtcpmuxfilter.h"
#include "pc/rtptransportinternal.h"
#include "pc/srtpfilter.h" #include "pc/srtpfilter.h"
#include "rtc_base/asyncinvoker.h" #include "rtc_base/asyncinvoker.h"
#include "rtc_base/asyncudpsocket.h" #include "rtc_base/asyncudpsocket.h"
@ -43,8 +44,6 @@
namespace webrtc { namespace webrtc {
class AudioSinkInterface; class AudioSinkInterface;
class RtpTransportInternal;
class SrtpTransport;
} // namespace webrtc } // namespace webrtc
namespace cricket { namespace cricket {
@ -100,12 +99,12 @@ class BaseChannel
const std::string& transport_name() const { return transport_name_; } const std::string& transport_name() const { return transport_name_; }
bool enabled() const { return enabled_; } bool enabled() const { return enabled_; }
// This function returns true if we are using SDES. // This function returns true if we are using SRTP.
bool sdes_active() const { return sdes_negotiator_.IsActive(); } bool secure() const { return srtp_filter_.IsActive(); }
// The following function returns true if we are using DTLS-based keying. // The following function returns true if we are using
bool dtls_active() const { return dtls_active_; } // DTLS-based keying. If you turned off SRTP later, however
// This function returns true if using SRTP (DTLS-based keying or SDES). // you could have secure() == false and dtls_secure() == true.
bool srtp_active() const { return sdes_active() || dtls_active(); } bool secure_dtls() const { return dtls_keyed_; }
bool writable() const { return writable_; } bool writable() const { return writable_; }
@ -183,6 +182,8 @@ class BaseChannel
override; override;
int SetOption_n(SocketType type, rtc::Socket::Option o, int val); int SetOption_n(SocketType type, rtc::Socket::Option o, int val);
SrtpFilter* srtp_filter() { return &srtp_filter_; }
virtual cricket::MediaType media_type() = 0; virtual cricket::MediaType media_type() = 0;
// This function returns true if we require SRTP for call setup. // This function returns true if we require SRTP for call setup.
@ -368,8 +369,6 @@ class BaseChannel
void CacheRtpAbsSendTimeHeaderExtension_n(int rtp_abs_sendtime_extn_id); void CacheRtpAbsSendTimeHeaderExtension_n(int rtp_abs_sendtime_extn_id);
int GetTransportOverheadPerPacket() const; int GetTransportOverheadPerPacket() const;
void UpdateTransportOverhead(); void UpdateTransportOverhead();
// Wraps the existing RtpTransport in an SrtpTransport.
void EnableSrtpTransport_n();
rtc::Thread* const worker_thread_; rtc::Thread* const worker_thread_;
rtc::Thread* const network_thread_; rtc::Thread* const network_thread_;
@ -390,16 +389,16 @@ class BaseChannel
DtlsTransportInternal* rtp_dtls_transport_ = nullptr; DtlsTransportInternal* rtp_dtls_transport_ = nullptr;
DtlsTransportInternal* rtcp_dtls_transport_ = nullptr; DtlsTransportInternal* rtcp_dtls_transport_ = nullptr;
std::unique_ptr<webrtc::RtpTransportInternal> rtp_transport_; std::unique_ptr<webrtc::RtpTransportInternal> rtp_transport_;
webrtc::SrtpTransport* srtp_transport_ = nullptr;
std::vector<std::pair<rtc::Socket::Option, int> > socket_options_; std::vector<std::pair<rtc::Socket::Option, int> > socket_options_;
std::vector<std::pair<rtc::Socket::Option, int> > rtcp_socket_options_; std::vector<std::pair<rtc::Socket::Option, int> > rtcp_socket_options_;
SrtpFilter sdes_negotiator_; SrtpFilter srtp_filter_;
RtcpMuxFilter rtcp_mux_filter_; RtcpMuxFilter rtcp_mux_filter_;
bool writable_ = false; bool writable_ = false;
bool was_ever_writable_ = false; bool was_ever_writable_ = false;
bool has_received_packet_ = false; bool has_received_packet_ = false;
bool dtls_active_ = false; bool dtls_keyed_ = false;
const bool srtp_required_ = true; const bool srtp_required_ = true;
int rtp_abs_sendtime_extn_id_ = -1;
// MediaChannel related members that should be accessed from the worker // MediaChannel related members that should be accessed from the worker
// thread. // thread.

View File

@ -577,7 +577,7 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> {
// Basic sanity check. // Basic sanity check.
void TestInit() { void TestInit() {
CreateChannels(0, 0); CreateChannels(0, 0);
EXPECT_FALSE(channel1_->srtp_active()); EXPECT_FALSE(channel1_->secure());
EXPECT_FALSE(media_channel1_->sending()); EXPECT_FALSE(media_channel1_->sending());
if (verify_playout_) { if (verify_playout_) {
EXPECT_FALSE(media_channel1_->playout()); EXPECT_FALSE(media_channel1_->playout());
@ -892,8 +892,8 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> {
EXPECT_TRUE(channel2_->SetRemoteContent(&content4, CA_ANSWER, NULL)); EXPECT_TRUE(channel2_->SetRemoteContent(&content4, CA_ANSWER, NULL));
EXPECT_EQ(0u, media_channel2_->recv_streams().size()); EXPECT_EQ(0u, media_channel2_->recv_streams().size());
EXPECT_TRUE(channel1_->srtp_active()); EXPECT_TRUE(channel1_->secure());
EXPECT_TRUE(channel2_->srtp_active()); EXPECT_TRUE(channel2_->secure());
SendCustomRtp2(kSsrc2, 0); SendCustomRtp2(kSsrc2, 0);
WaitForThreads(); WaitForThreads();
EXPECT_TRUE(CheckCustomRtp1(kSsrc2, 0)); EXPECT_TRUE(CheckCustomRtp1(kSsrc2, 0));
@ -1249,14 +1249,14 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> {
// Test setting up a call. // Test setting up a call.
void TestCallSetup() { void TestCallSetup() {
CreateChannels(0, 0); CreateChannels(0, 0);
EXPECT_FALSE(channel1_->srtp_active()); EXPECT_FALSE(channel1_->secure());
EXPECT_TRUE(SendInitiate()); EXPECT_TRUE(SendInitiate());
if (verify_playout_) { if (verify_playout_) {
EXPECT_TRUE(media_channel1_->playout()); EXPECT_TRUE(media_channel1_->playout());
} }
EXPECT_FALSE(media_channel1_->sending()); EXPECT_FALSE(media_channel1_->sending());
EXPECT_TRUE(SendAccept()); EXPECT_TRUE(SendAccept());
EXPECT_FALSE(channel1_->srtp_active()); EXPECT_FALSE(channel1_->secure());
EXPECT_TRUE(media_channel1_->sending()); EXPECT_TRUE(media_channel1_->sending());
EXPECT_EQ(1U, media_channel1_->codecs().size()); EXPECT_EQ(1U, media_channel1_->codecs().size());
if (verify_playout_) { if (verify_playout_) {
@ -1531,17 +1531,17 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> {
bool dtls1 = !!(flags1_in & DTLS); bool dtls1 = !!(flags1_in & DTLS);
bool dtls2 = !!(flags2_in & DTLS); bool dtls2 = !!(flags2_in & DTLS);
CreateChannels(flags1, flags2); CreateChannels(flags1, flags2);
EXPECT_FALSE(channel1_->srtp_active()); EXPECT_FALSE(channel1_->secure());
EXPECT_FALSE(channel2_->srtp_active()); EXPECT_FALSE(channel2_->secure());
EXPECT_TRUE(SendInitiate()); EXPECT_TRUE(SendInitiate());
WaitForThreads(); WaitForThreads();
EXPECT_TRUE(channel1_->writable()); EXPECT_TRUE(channel1_->writable());
EXPECT_TRUE(channel2_->writable()); EXPECT_TRUE(channel2_->writable());
EXPECT_TRUE(SendAccept()); EXPECT_TRUE(SendAccept());
EXPECT_TRUE(channel1_->srtp_active()); EXPECT_TRUE(channel1_->secure());
EXPECT_TRUE(channel2_->srtp_active()); EXPECT_TRUE(channel2_->secure());
EXPECT_EQ(dtls1 && dtls2, channel1_->dtls_active()); EXPECT_EQ(dtls1 && dtls2, channel1_->secure_dtls());
EXPECT_EQ(dtls1 && dtls2, channel2_->dtls_active()); EXPECT_EQ(dtls1 && dtls2, channel2_->secure_dtls());
SendRtp1(); SendRtp1();
SendRtp2(); SendRtp2();
SendRtcp1(); SendRtcp1();
@ -1560,12 +1560,12 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> {
// Test that we properly handling SRTP negotiating down to RTP. // Test that we properly handling SRTP negotiating down to RTP.
void SendSrtpToRtp() { void SendSrtpToRtp() {
CreateChannels(SECURE, 0); CreateChannels(SECURE, 0);
EXPECT_FALSE(channel1_->srtp_active()); EXPECT_FALSE(channel1_->secure());
EXPECT_FALSE(channel2_->srtp_active()); EXPECT_FALSE(channel2_->secure());
EXPECT_TRUE(SendInitiate()); EXPECT_TRUE(SendInitiate());
EXPECT_TRUE(SendAccept()); EXPECT_TRUE(SendAccept());
EXPECT_FALSE(channel1_->srtp_active()); EXPECT_FALSE(channel1_->secure());
EXPECT_FALSE(channel2_->srtp_active()); EXPECT_FALSE(channel2_->secure());
SendRtp1(); SendRtp1();
SendRtp2(); SendRtp2();
SendRtcp1(); SendRtcp1();
@ -1590,8 +1590,8 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> {
SSRC_MUX | RTCP_MUX | SECURE); SSRC_MUX | RTCP_MUX | SECURE);
EXPECT_TRUE(SendOffer()); EXPECT_TRUE(SendOffer());
EXPECT_TRUE(SendProvisionalAnswer()); EXPECT_TRUE(SendProvisionalAnswer());
EXPECT_TRUE(channel1_->srtp_active()); EXPECT_TRUE(channel1_->secure());
EXPECT_TRUE(channel2_->srtp_active()); EXPECT_TRUE(channel2_->secure());
EXPECT_TRUE(channel1_->NeedsRtcpTransport()); EXPECT_TRUE(channel1_->NeedsRtcpTransport());
EXPECT_TRUE(channel2_->NeedsRtcpTransport()); EXPECT_TRUE(channel2_->NeedsRtcpTransport());
WaitForThreads(); // Wait for 'sending' flag go through network thread. WaitForThreads(); // Wait for 'sending' flag go through network thread.
@ -1616,8 +1616,8 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> {
EXPECT_FALSE(channel2_->NeedsRtcpTransport()); EXPECT_FALSE(channel2_->NeedsRtcpTransport());
EXPECT_EQ(1, rtcp_mux_activated_callbacks1_); EXPECT_EQ(1, rtcp_mux_activated_callbacks1_);
EXPECT_EQ(1, rtcp_mux_activated_callbacks2_); EXPECT_EQ(1, rtcp_mux_activated_callbacks2_);
EXPECT_TRUE(channel1_->srtp_active()); EXPECT_TRUE(channel1_->secure());
EXPECT_TRUE(channel2_->srtp_active()); EXPECT_TRUE(channel2_->secure());
SendCustomRtcp1(kSsrc1); SendCustomRtcp1(kSsrc1);
SendCustomRtp1(kSsrc1, ++sequence_number1_1); SendCustomRtp1(kSsrc1, ++sequence_number1_1);
SendCustomRtcp2(kSsrc2); SendCustomRtcp2(kSsrc2);

View File

@ -76,18 +76,6 @@ bool RtpTransport::IsWritable(bool rtcp) const {
return transport && transport->writable(); return transport && transport->writable();
} }
bool RtpTransport::SendRtpPacket(rtc::CopyOnWriteBuffer* packet,
const rtc::PacketOptions& options,
int flags) {
return SendPacket(false, packet, options, flags);
}
bool RtpTransport::SendRtcpPacket(rtc::CopyOnWriteBuffer* packet,
const rtc::PacketOptions& options,
int flags) {
return SendPacket(true, packet, options, flags);
}
bool RtpTransport::SendPacket(bool rtcp, bool RtpTransport::SendPacket(bool rtcp,
rtc::CopyOnWriteBuffer* packet, rtc::CopyOnWriteBuffer* packet,
const rtc::PacketOptions& options, const rtc::PacketOptions& options,

View File

@ -56,11 +56,8 @@ class RtpTransport : public RtpTransportInternal {
bool IsWritable(bool rtcp) const override; bool IsWritable(bool rtcp) const override;
bool SendRtpPacket(rtc::CopyOnWriteBuffer* packet, bool SendPacket(bool rtcp,
const rtc::PacketOptions& options, rtc::CopyOnWriteBuffer* packet,
int flags) override;
bool SendRtcpPacket(rtc::CopyOnWriteBuffer* packet,
const rtc::PacketOptions& options, const rtc::PacketOptions& options,
int flags) override; int flags) override;
@ -83,11 +80,6 @@ class RtpTransport : public RtpTransportInternal {
void MaybeSignalReadyToSend(); void MaybeSignalReadyToSend();
bool SendPacket(bool rtcp,
rtc::CopyOnWriteBuffer* packet,
const rtc::PacketOptions& options,
int flags);
void OnReadPacket(rtc::PacketTransportInternal* transport, void OnReadPacket(rtc::PacketTransportInternal* transport,
const char* data, const char* data,
size_t len, size_t len,

View File

@ -54,11 +54,8 @@ class RtpTransportInternal : public RtpTransportInterface,
virtual bool IsWritable(bool rtcp) const = 0; virtual bool IsWritable(bool rtcp) const = 0;
virtual bool SendRtpPacket(rtc::CopyOnWriteBuffer* packet, virtual bool SendPacket(bool rtcp,
const rtc::PacketOptions& options, rtc::CopyOnWriteBuffer* packet,
int flags) = 0;
virtual bool SendRtcpPacket(rtc::CopyOnWriteBuffer* packet,
const rtc::PacketOptions& options, const rtc::PacketOptions& options,
int flags) = 0; int flags) = 0;

View File

@ -17,6 +17,7 @@
#include "media/base/rtputils.h" #include "media/base/rtputils.h"
#include "pc/srtpsession.h" #include "pc/srtpsession.h"
#include "rtc_base/base64.h" #include "rtc_base/base64.h"
#include "rtc_base/buffer.h"
#include "rtc_base/byteorder.h" #include "rtc_base/byteorder.h"
#include "rtc_base/checks.h" #include "rtc_base/checks.h"
#include "rtc_base/logging.h" #include "rtc_base/logging.h"
@ -62,6 +63,210 @@ bool SrtpFilter::SetProvisionalAnswer(
return DoSetAnswer(answer_params, source, false); return DoSetAnswer(answer_params, source, false);
} }
bool SrtpFilter::SetRtpParams(int send_cs,
const uint8_t* send_key,
int send_key_len,
int recv_cs,
const uint8_t* recv_key,
int recv_key_len) {
if (IsActive()) {
LOG(LS_ERROR) << "Tried to set SRTP Params when filter already active";
return false;
}
CreateSrtpSessions();
send_session_->SetEncryptedHeaderExtensionIds(
send_encrypted_header_extension_ids_);
if (!send_session_->SetSend(send_cs, send_key, send_key_len)) {
return false;
}
recv_session_->SetEncryptedHeaderExtensionIds(
recv_encrypted_header_extension_ids_);
if (!recv_session_->SetRecv(recv_cs, recv_key, recv_key_len)) {
return false;
}
state_ = ST_ACTIVE;
LOG(LS_INFO) << "SRTP activated with negotiated parameters:"
<< " send cipher_suite " << send_cs << " recv cipher_suite "
<< recv_cs;
return true;
}
bool SrtpFilter::UpdateRtpParams(int send_cs,
const uint8_t* send_key,
int send_key_len,
int recv_cs,
const uint8_t* recv_key,
int recv_key_len) {
if (!IsActive()) {
LOG(LS_ERROR) << "Tried to update SRTP Params when filter is not active";
return false;
}
send_session_->SetEncryptedHeaderExtensionIds(
send_encrypted_header_extension_ids_);
if (!send_session_->UpdateSend(send_cs, send_key, send_key_len)) {
return false;
}
recv_session_->SetEncryptedHeaderExtensionIds(
recv_encrypted_header_extension_ids_);
if (!recv_session_->UpdateRecv(recv_cs, recv_key, recv_key_len)) {
return false;
}
LOG(LS_INFO) << "SRTP updated with negotiated parameters:"
<< " send cipher_suite " << send_cs << " recv cipher_suite "
<< recv_cs;
return true;
}
// This function is provided separately because DTLS-SRTP behaves
// differently in RTP/RTCP mux and non-mux modes.
//
// - In the non-muxed case, RTP and RTCP are keyed with different
// keys (from different DTLS handshakes), and so we need a new
// SrtpSession.
// - In the muxed case, they are keyed with the same keys, so
// this function is not needed
bool SrtpFilter::SetRtcpParams(int send_cs,
const uint8_t* send_key,
int send_key_len,
int recv_cs,
const uint8_t* recv_key,
int recv_key_len) {
// This can only be called once, but can be safely called after
// SetRtpParams
if (send_rtcp_session_ || recv_rtcp_session_) {
LOG(LS_ERROR) << "Tried to set SRTCP Params when filter already active";
return false;
}
send_rtcp_session_.reset(new SrtpSession());
if (!send_rtcp_session_->SetRecv(send_cs, send_key, send_key_len)) {
return false;
}
recv_rtcp_session_.reset(new SrtpSession());
if (!recv_rtcp_session_->SetRecv(recv_cs, recv_key, recv_key_len)) {
return false;
}
LOG(LS_INFO) << "SRTCP activated with negotiated parameters:"
<< " send cipher_suite " << send_cs << " recv cipher_suite "
<< recv_cs;
return true;
}
bool SrtpFilter::ProtectRtp(void* p, int in_len, int max_len, int* out_len) {
if (!IsActive()) {
LOG(LS_WARNING) << "Failed to ProtectRtp: SRTP not active";
return false;
}
RTC_CHECK(send_session_);
return send_session_->ProtectRtp(p, in_len, max_len, out_len);
}
bool SrtpFilter::ProtectRtp(void* p,
int in_len,
int max_len,
int* out_len,
int64_t* index) {
if (!IsActive()) {
LOG(LS_WARNING) << "Failed to ProtectRtp: SRTP not active";
return false;
}
RTC_CHECK(send_session_);
return send_session_->ProtectRtp(p, in_len, max_len, out_len, index);
}
bool SrtpFilter::ProtectRtcp(void* p, int in_len, int max_len, int* out_len) {
if (!IsActive()) {
LOG(LS_WARNING) << "Failed to ProtectRtcp: SRTP not active";
return false;
}
if (send_rtcp_session_) {
return send_rtcp_session_->ProtectRtcp(p, in_len, max_len, out_len);
} else {
RTC_CHECK(send_session_);
return send_session_->ProtectRtcp(p, in_len, max_len, out_len);
}
}
bool SrtpFilter::UnprotectRtp(void* p, int in_len, int* out_len) {
if (!IsActive()) {
LOG(LS_WARNING) << "Failed to UnprotectRtp: SRTP not active";
return false;
}
RTC_CHECK(recv_session_);
return recv_session_->UnprotectRtp(p, in_len, out_len);
}
bool SrtpFilter::UnprotectRtcp(void* p, int in_len, int* out_len) {
if (!IsActive()) {
LOG(LS_WARNING) << "Failed to UnprotectRtcp: SRTP not active";
return false;
}
if (recv_rtcp_session_) {
return recv_rtcp_session_->UnprotectRtcp(p, in_len, out_len);
} else {
RTC_CHECK(recv_session_);
return recv_session_->UnprotectRtcp(p, in_len, out_len);
}
}
bool SrtpFilter::GetRtpAuthParams(uint8_t** key, int* key_len, int* tag_len) {
if (!IsActive()) {
LOG(LS_WARNING) << "Failed to GetRtpAuthParams: SRTP not active";
return false;
}
RTC_CHECK(send_session_);
return send_session_->GetRtpAuthParams(key, key_len, tag_len);
}
bool SrtpFilter::GetSrtpOverhead(int* srtp_overhead) const {
if (!IsActive()) {
LOG(LS_WARNING) << "Failed to GetSrtpOverhead: SRTP not active";
return false;
}
RTC_CHECK(send_session_);
*srtp_overhead = send_session_->GetSrtpOverhead();
return true;
}
void SrtpFilter::EnableExternalAuth() {
RTC_DCHECK(!IsActive());
external_auth_enabled_ = true;
}
bool SrtpFilter::IsExternalAuthEnabled() const {
return external_auth_enabled_;
}
bool SrtpFilter::IsExternalAuthActive() const {
if (!IsActive()) {
LOG(LS_WARNING) << "Failed to check IsExternalAuthActive: SRTP not active";
return false;
}
RTC_CHECK(send_session_);
return send_session_->IsExternalAuthActive();
}
void SrtpFilter::SetEncryptedHeaderExtensionIds(
ContentSource source,
const std::vector<int>& extension_ids) {
if (source == CS_LOCAL) {
recv_encrypted_header_extension_ids_ = extension_ids;
} else {
send_encrypted_header_extension_ids_ = extension_ids;
}
}
bool SrtpFilter::ExpectOffer(ContentSource source) { bool SrtpFilter::ExpectOffer(ContentSource source) {
return ((state_ == ST_INIT) || return ((state_ == ST_INIT) ||
(state_ == ST_ACTIVE) || (state_ == ST_ACTIVE) ||
@ -119,16 +324,13 @@ bool SrtpFilter::DoSetAnswer(const std::vector<CryptoParams>& answer_params,
CryptoParams selected_params; CryptoParams selected_params;
if (!NegotiateParams(answer_params, &selected_params)) if (!NegotiateParams(answer_params, &selected_params))
return false; return false;
const CryptoParams& send_params =
const CryptoParams& new_send_params =
(source == CS_REMOTE) ? selected_params : answer_params[0]; (source == CS_REMOTE) ? selected_params : answer_params[0];
const CryptoParams& new_recv_params = const CryptoParams& recv_params =
(source == CS_REMOTE) ? answer_params[0] : selected_params; (source == CS_REMOTE) ? answer_params[0] : selected_params;
if (!ApplySendParams(new_send_params) || !ApplyRecvParams(new_recv_params)) { if (!ApplyParams(send_params, recv_params)) {
return false; return false;
} }
applied_send_params_ = new_send_params;
applied_recv_params_ = new_recv_params;
if (final) { if (final) {
offer_params_.clear(); offer_params_.clear();
@ -140,6 +342,17 @@ bool SrtpFilter::DoSetAnswer(const std::vector<CryptoParams>& answer_params,
return true; return true;
} }
void SrtpFilter::CreateSrtpSessions() {
send_session_.reset(new SrtpSession());
applied_send_params_ = CryptoParams();
recv_session_.reset(new SrtpSession());
applied_recv_params_ = CryptoParams();
if (external_auth_enabled_) {
send_session_->EnableExternalAuth();
}
}
bool SrtpFilter::NegotiateParams(const std::vector<CryptoParams>& answer_params, bool SrtpFilter::NegotiateParams(const std::vector<CryptoParams>& answer_params,
CryptoParams* selected_params) { CryptoParams* selected_params) {
// We're processing an accept. We should have exactly one set of params, // We're processing an accept. We should have exactly one set of params,
@ -167,76 +380,85 @@ bool SrtpFilter::NegotiateParams(const std::vector<CryptoParams>& answer_params,
return ret; return ret;
} }
bool SrtpFilter::ResetParams() { bool SrtpFilter::ApplyParams(const CryptoParams& send_params,
offer_params_.clear(); const CryptoParams& recv_params) {
applied_send_params_ = CryptoParams(); // TODO(jiayl): Split this method to apply send and receive CryptoParams
applied_recv_params_ = CryptoParams(); // independently, so that we can skip one method when either send or receive
send_cipher_suite_ = rtc::Optional<int>(); // CryptoParams is unchanged.
recv_cipher_suite_ = rtc::Optional<int>();
send_key_.Clear();
recv_key_.Clear();
state_ = ST_INIT;
return true;
}
bool SrtpFilter::ApplySendParams(const CryptoParams& send_params) {
if (applied_send_params_.cipher_suite == send_params.cipher_suite && if (applied_send_params_.cipher_suite == send_params.cipher_suite &&
applied_send_params_.key_params == send_params.key_params) { applied_send_params_.key_params == send_params.key_params &&
LOG(LS_INFO) << "Applying the same SRTP send parameters again. No-op."; applied_recv_params_.cipher_suite == recv_params.cipher_suite &&
applied_recv_params_.key_params == recv_params.key_params) {
LOG(LS_INFO) << "Applying the same SRTP parameters again. No-op.";
// We do not want to reset the ROC if the keys are the same. So just return. // We do not want to reset the ROC if the keys are the same. So just return.
return true; return true;
} }
send_cipher_suite_ = rtc::Optional<int>( int send_suite = rtc::SrtpCryptoSuiteFromName(send_params.cipher_suite);
rtc::SrtpCryptoSuiteFromName(send_params.cipher_suite)); int recv_suite = rtc::SrtpCryptoSuiteFromName(recv_params.cipher_suite);
if (send_cipher_suite_ == rtc::SRTP_INVALID_CRYPTO_SUITE) { if (send_suite == rtc::SRTP_INVALID_CRYPTO_SUITE ||
recv_suite == rtc::SRTP_INVALID_CRYPTO_SUITE) {
LOG(LS_WARNING) << "Unknown crypto suite(s) received:" LOG(LS_WARNING) << "Unknown crypto suite(s) received:"
<< " send cipher_suite " << send_params.cipher_suite; << " send cipher_suite " << send_params.cipher_suite
<< " recv cipher_suite " << recv_params.cipher_suite;
return false; return false;
} }
int send_key_len, send_salt_len; int send_key_len, send_salt_len;
if (!rtc::GetSrtpKeyAndSaltLengths(*send_cipher_suite_, &send_key_len,
&send_salt_len)) {
LOG(LS_WARNING) << "Could not get lengths for crypto suite(s):"
<< " send cipher_suite " << send_params.cipher_suite;
return false;
}
send_key_ = rtc::Buffer(send_key_len + send_salt_len);
return ParseKeyParams(send_params.key_params, send_key_.data(),
send_key_.size());
}
bool SrtpFilter::ApplyRecvParams(const CryptoParams& recv_params) {
if (applied_recv_params_.cipher_suite == recv_params.cipher_suite &&
applied_recv_params_.key_params == recv_params.key_params) {
LOG(LS_INFO) << "Applying the same SRTP recv parameters again. No-op.";
// We do not want to reset the ROC if the keys are the same. So just return.
return true;
}
recv_cipher_suite_ = rtc::Optional<int>(
rtc::SrtpCryptoSuiteFromName(recv_params.cipher_suite));
if (recv_cipher_suite_ == rtc::SRTP_INVALID_CRYPTO_SUITE) {
LOG(LS_WARNING) << "Unknown crypto suite(s) received:"
<< " recv cipher_suite " << recv_params.cipher_suite;
return false;
}
int recv_key_len, recv_salt_len; int recv_key_len, recv_salt_len;
if (!rtc::GetSrtpKeyAndSaltLengths(*recv_cipher_suite_, &recv_key_len, if (!rtc::GetSrtpKeyAndSaltLengths(send_suite, &send_key_len,
&send_salt_len) ||
!rtc::GetSrtpKeyAndSaltLengths(recv_suite, &recv_key_len,
&recv_salt_len)) { &recv_salt_len)) {
LOG(LS_WARNING) << "Could not get lengths for crypto suite(s):" LOG(LS_WARNING) << "Could not get lengths for crypto suite(s):"
<< " send cipher_suite " << send_params.cipher_suite
<< " recv cipher_suite " << recv_params.cipher_suite; << " recv cipher_suite " << recv_params.cipher_suite;
return false; return false;
} }
recv_key_ = rtc::Buffer(recv_key_len + recv_salt_len); // TODO(juberti): Zero these buffers after use.
return ParseKeyParams(recv_params.key_params, recv_key_.data(), bool ret;
recv_key_.size()); rtc::Buffer send_key(send_key_len + send_salt_len);
rtc::Buffer recv_key(recv_key_len + recv_salt_len);
ret = (ParseKeyParams(send_params.key_params, send_key.data(),
send_key.size()) &&
ParseKeyParams(recv_params.key_params, recv_key.data(),
recv_key.size()));
if (ret) {
CreateSrtpSessions();
send_session_->SetEncryptedHeaderExtensionIds(
send_encrypted_header_extension_ids_);
recv_session_->SetEncryptedHeaderExtensionIds(
recv_encrypted_header_extension_ids_);
ret = (send_session_->SetSend(
rtc::SrtpCryptoSuiteFromName(send_params.cipher_suite),
send_key.data(), send_key.size()) &&
recv_session_->SetRecv(
rtc::SrtpCryptoSuiteFromName(recv_params.cipher_suite),
recv_key.data(), recv_key.size()));
}
if (ret) {
LOG(LS_INFO) << "SRTP activated with negotiated parameters:"
<< " send cipher_suite " << send_params.cipher_suite
<< " recv cipher_suite " << recv_params.cipher_suite;
applied_send_params_ = send_params;
applied_recv_params_ = recv_params;
} else {
LOG(LS_WARNING) << "Failed to apply negotiated SRTP parameters";
}
return ret;
}
bool SrtpFilter::ResetParams() {
offer_params_.clear();
state_ = ST_INIT;
send_session_ = nullptr;
recv_session_ = nullptr;
send_rtcp_session_ = nullptr;
recv_rtcp_session_ = nullptr;
LOG(LS_INFO) << "SRTP reset to init state";
return true;
} }
bool SrtpFilter::ParseKeyParams(const std::string& key_params, bool SrtpFilter::ParseKeyParams(const std::string& key_params,

View File

@ -17,11 +17,9 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "api/optional.h"
#include "media/base/cryptoparams.h" #include "media/base/cryptoparams.h"
#include "p2p/base/sessiondescription.h" #include "p2p/base/sessiondescription.h"
#include "rtc_base/basictypes.h" #include "rtc_base/basictypes.h"
#include "rtc_base/buffer.h"
#include "rtc_base/constructormagic.h" #include "rtc_base/constructormagic.h"
#include "rtc_base/criticalsection.h" #include "rtc_base/criticalsection.h"
#include "rtc_base/sslstreamadapter.h" #include "rtc_base/sslstreamadapter.h"
@ -33,10 +31,16 @@ struct srtp_ctx_t_;
namespace cricket { namespace cricket {
class SrtpSession;
void ShutdownSrtp(); void ShutdownSrtp();
// A helper class used to negotiate SDES crypto params. // Class to transform SRTP to/from RTP.
// TODO(zhihuang): Find a better name for this class, like "SdesNegotiator". // Initialize by calling SetSend with the local security params, then
// call
// SetRecv once the remote security params are received. At that point
// Protect/UnprotectRt(c)p can be called to encrypt/decrypt data.
// TODO: Figure out concurrency policy for SrtpFilter.
class SrtpFilter { class SrtpFilter {
public: public:
enum Mode { enum Mode {
@ -73,14 +77,66 @@ class SrtpFilter {
bool SetAnswer(const std::vector<CryptoParams>& answer_params, bool SetAnswer(const std::vector<CryptoParams>& answer_params,
ContentSource source); ContentSource source);
// Set the header extension ids that should be encrypted for the given
// source.
void SetEncryptedHeaderExtensionIds(ContentSource source,
const std::vector<int>& extension_ids);
// Just set up both sets of keys directly.
// Used with DTLS-SRTP.
bool SetRtpParams(int send_cs,
const uint8_t* send_key,
int send_key_len,
int recv_cs,
const uint8_t* recv_key,
int recv_key_len);
bool UpdateRtpParams(int send_cs,
const uint8_t* send_key,
int send_key_len,
int recv_cs,
const uint8_t* recv_key,
int recv_key_len);
bool SetRtcpParams(int send_cs,
const uint8_t* send_key,
int send_key_len,
int recv_cs,
const uint8_t* recv_key,
int recv_key_len);
// Encrypts/signs an individual RTP/RTCP packet, in-place.
// If an HMAC is used, this will increase the packet size.
bool ProtectRtp(void* data, int in_len, int max_len, int* out_len);
// Overloaded version, outputs packet index.
bool ProtectRtp(void* data,
int in_len,
int max_len,
int* out_len,
int64_t* index);
bool ProtectRtcp(void* data, int in_len, int max_len, int* out_len);
// Decrypts/verifies an invidiual RTP/RTCP packet.
// If an HMAC is used, this will decrease the packet size.
bool UnprotectRtp(void* data, int in_len, int* out_len);
bool UnprotectRtcp(void* data, int in_len, int* out_len);
// Returns rtp auth params from srtp context.
bool GetRtpAuthParams(uint8_t** key, int* key_len, int* tag_len);
// Returns srtp overhead for rtp packets.
bool GetSrtpOverhead(int* srtp_overhead) const;
// If external auth is enabled, SRTP will write a dummy auth tag that
// then
// later must get replaced before the packet is sent out. Only
// supported for
// non-GCM cipher suites and can be checked through
// "IsExternalAuthActive"
// if it is actually used. This method is only valid before the RTP
// params
// have been set.
void EnableExternalAuth();
bool IsExternalAuthEnabled() const;
// A SRTP filter supports external creation of the auth tag if a non-GCM
// cipher is used. This method is only valid after the RTP params have
// been set.
bool IsExternalAuthActive() const;
bool ResetParams(); bool ResetParams();
rtc::Optional<int> send_cipher_suite() { return send_cipher_suite_; }
rtc::Optional<int> recv_cipher_suite() { return recv_cipher_suite_; }
const rtc::Buffer& send_key() { return send_key_; }
const rtc::Buffer& recv_key() { return recv_key_; }
protected: protected:
bool ExpectOffer(ContentSource source); bool ExpectOffer(ContentSource source);
@ -93,18 +149,17 @@ class SrtpFilter {
ContentSource source, ContentSource source,
bool final); bool final);
void CreateSrtpSessions();
bool NegotiateParams(const std::vector<CryptoParams>& answer_params, bool NegotiateParams(const std::vector<CryptoParams>& answer_params,
CryptoParams* selected_params); CryptoParams* selected_params);
private: bool ApplyParams(const CryptoParams& send_params,
bool ApplySendParams(const CryptoParams& send_params); const CryptoParams& recv_params);
bool ApplyRecvParams(const CryptoParams& recv_params);
static bool ParseKeyParams(const std::string& params, static bool ParseKeyParams(const std::string& params,
uint8_t* key, uint8_t* key,
size_t len); size_t len);
private:
enum State { enum State {
ST_INIT, // SRTP filter unused. ST_INIT, // SRTP filter unused.
ST_SENTOFFER, // Offer with SRTP parameters sent. ST_SENTOFFER, // Offer with SRTP parameters sent.
@ -129,13 +184,16 @@ class SrtpFilter {
ST_RECEIVEDPRANSWER ST_RECEIVEDPRANSWER
}; };
State state_ = ST_INIT; State state_ = ST_INIT;
bool external_auth_enabled_ = false;
std::vector<CryptoParams> offer_params_; std::vector<CryptoParams> offer_params_;
std::unique_ptr<SrtpSession> send_session_;
std::unique_ptr<SrtpSession> recv_session_;
std::unique_ptr<SrtpSession> send_rtcp_session_;
std::unique_ptr<SrtpSession> recv_rtcp_session_;
CryptoParams applied_send_params_; CryptoParams applied_send_params_;
CryptoParams applied_recv_params_; CryptoParams applied_recv_params_;
rtc::Optional<int> send_cipher_suite_; std::vector<int> send_encrypted_header_extension_ids_;
rtc::Optional<int> recv_cipher_suite_; std::vector<int> recv_encrypted_header_extension_ids_;
rtc::Buffer send_key_;
rtc::Buffer recv_key_;
}; };
} // namespace cricket } // namespace cricket

View File

@ -13,7 +13,14 @@
#include "pc/srtpfilter.h" #include "pc/srtpfilter.h"
#include "media/base/cryptoparams.h" #include "media/base/cryptoparams.h"
#include "media/base/fakertp.h"
#include "p2p/base/sessiondescription.h"
#include "pc/srtptestutil.h"
#include "rtc_base/buffer.h"
#include "rtc_base/byteorder.h"
#include "rtc_base/constructormagic.h"
#include "rtc_base/gunit.h" #include "rtc_base/gunit.h"
#include "rtc_base/thread.h"
using cricket::CryptoParams; using cricket::CryptoParams;
using cricket::CS_LOCAL; using cricket::CS_LOCAL;
@ -21,6 +28,14 @@ using cricket::CS_REMOTE;
namespace rtc { namespace rtc {
static const uint8_t kTestKeyGcm128_1[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ12";
static const uint8_t kTestKeyGcm128_2[] = "21ZYXWVUTSRQPONMLKJIHGFEDCBA";
static const int kTestKeyGcm128Len = 28; // 128 bits key + 96 bits salt.
static const uint8_t kTestKeyGcm256_1[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqr";
static const uint8_t kTestKeyGcm256_2[] =
"rqponmlkjihgfedcbaZYXWVUTSRQPONMLKJIHGFEDCBA";
static const int kTestKeyGcm256Len = 44; // 256 bits key + 96 bits salt.
static const std::string kTestKeyParams1 = static const std::string kTestKeyParams1 =
"inline:WVNfX19zZW1jdGwgKCkgewkyMjA7fQp9CnVubGVz"; "inline:WVNfX19zZW1jdGwgKCkgewkyMjA7fQp9CnVubGVz";
static const std::string kTestKeyParams2 = static const std::string kTestKeyParams2 =
@ -52,13 +67,14 @@ static const cricket::CryptoParams kTestCryptoParamsGcm4(
class SrtpFilterTest : public testing::Test { class SrtpFilterTest : public testing::Test {
protected: protected:
SrtpFilterTest() {} SrtpFilterTest()
// Need to initialize |sequence_number_|, the value does not matter.
: sequence_number_(1) {}
static std::vector<CryptoParams> MakeVector(const CryptoParams& params) { static std::vector<CryptoParams> MakeVector(const CryptoParams& params) {
std::vector<CryptoParams> vec; std::vector<CryptoParams> vec;
vec.push_back(params); vec.push_back(params);
return vec; return vec;
} }
void TestSetParams(const std::vector<CryptoParams>& params1, void TestSetParams(const std::vector<CryptoParams>& params1,
const std::vector<CryptoParams>& params2) { const std::vector<CryptoParams>& params2) {
EXPECT_TRUE(f1_.SetOffer(params1, CS_LOCAL)); EXPECT_TRUE(f1_.SetOffer(params1, CS_LOCAL));
@ -70,16 +86,186 @@ class SrtpFilterTest : public testing::Test {
EXPECT_TRUE(f1_.IsActive()); EXPECT_TRUE(f1_.IsActive());
EXPECT_TRUE(f2_.IsActive()); EXPECT_TRUE(f2_.IsActive());
} }
void TestRtpAuthParams(cricket::SrtpFilter* filter, const std::string& cs) {
void VerifyCryptoParamsMatch(const std::string& cs1, const std::string& cs2) { int overhead;
EXPECT_EQ(rtc::SrtpCryptoSuiteFromName(cs1), f1_.send_cipher_suite()); EXPECT_TRUE(filter->GetSrtpOverhead(&overhead));
EXPECT_EQ(rtc::SrtpCryptoSuiteFromName(cs2), f2_.send_cipher_suite()); switch (SrtpCryptoSuiteFromName(cs)) {
EXPECT_TRUE(f1_.send_key() == f2_.recv_key()); case SRTP_AES128_CM_SHA1_32:
EXPECT_TRUE(f2_.send_key() == f1_.recv_key()); EXPECT_EQ(32 / 8, overhead); // 32-bit tag.
break;
case SRTP_AES128_CM_SHA1_80:
EXPECT_EQ(80 / 8, overhead); // 80-bit tag.
break;
default:
RTC_NOTREACHED();
break;
} }
uint8_t* auth_key = nullptr;
int key_len = 0;
int tag_len = 0;
EXPECT_TRUE(filter->GetRtpAuthParams(&auth_key, &key_len, &tag_len));
EXPECT_NE(nullptr, auth_key);
EXPECT_EQ(160 / 8, key_len); // Length of SHA-1 is 160 bits.
EXPECT_EQ(overhead, tag_len);
}
void TestProtectUnprotect(const std::string& cs1, const std::string& cs2) {
Buffer rtp_buffer(sizeof(kPcmuFrame) + rtp_auth_tag_len(cs1));
char* rtp_packet = rtp_buffer.data<char>();
char original_rtp_packet[sizeof(kPcmuFrame)];
Buffer rtcp_buffer(sizeof(kRtcpReport) + 4 + rtcp_auth_tag_len(cs2));
char* rtcp_packet = rtcp_buffer.data<char>();
int rtp_len = sizeof(kPcmuFrame), rtcp_len = sizeof(kRtcpReport), out_len;
memcpy(rtp_packet, kPcmuFrame, rtp_len);
// In order to be able to run this test function multiple times we can not
// use the same sequence number twice. Increase the sequence number by one.
SetBE16(reinterpret_cast<uint8_t*>(rtp_packet) + 2, ++sequence_number_);
memcpy(original_rtp_packet, rtp_packet, rtp_len);
memcpy(rtcp_packet, kRtcpReport, rtcp_len);
EXPECT_TRUE(f1_.ProtectRtp(rtp_packet, rtp_len,
static_cast<int>(rtp_buffer.size()), &out_len));
EXPECT_EQ(out_len, rtp_len + rtp_auth_tag_len(cs1));
EXPECT_NE(0, memcmp(rtp_packet, original_rtp_packet, rtp_len));
if (!f1_.IsExternalAuthActive()) {
EXPECT_TRUE(f2_.UnprotectRtp(rtp_packet, out_len, &out_len));
EXPECT_EQ(rtp_len, out_len);
EXPECT_EQ(0, memcmp(rtp_packet, original_rtp_packet, rtp_len));
} else {
// With external auth enabled, SRTP doesn't write the auth tag and
// unprotect would fail. Check accessing the information about the
// tag instead, similar to what the actual code would do that relies
// on external auth.
TestRtpAuthParams(&f1_, cs1);
}
EXPECT_TRUE(f2_.ProtectRtp(rtp_packet, rtp_len,
static_cast<int>(rtp_buffer.size()), &out_len));
EXPECT_EQ(out_len, rtp_len + rtp_auth_tag_len(cs2));
EXPECT_NE(0, memcmp(rtp_packet, original_rtp_packet, rtp_len));
if (!f2_.IsExternalAuthActive()) {
EXPECT_TRUE(f1_.UnprotectRtp(rtp_packet, out_len, &out_len));
EXPECT_EQ(rtp_len, out_len);
EXPECT_EQ(0, memcmp(rtp_packet, original_rtp_packet, rtp_len));
} else {
TestRtpAuthParams(&f2_, cs2);
}
EXPECT_TRUE(f1_.ProtectRtcp(
rtcp_packet, rtcp_len, static_cast<int>(rtcp_buffer.size()), &out_len));
EXPECT_EQ(out_len, rtcp_len + 4 + rtcp_auth_tag_len(cs1)); // NOLINT
EXPECT_NE(0, memcmp(rtcp_packet, kRtcpReport, rtcp_len));
EXPECT_TRUE(f2_.UnprotectRtcp(rtcp_packet, out_len, &out_len));
EXPECT_EQ(rtcp_len, out_len);
EXPECT_EQ(0, memcmp(rtcp_packet, kRtcpReport, rtcp_len));
EXPECT_TRUE(f2_.ProtectRtcp(
rtcp_packet, rtcp_len, static_cast<int>(rtcp_buffer.size()), &out_len));
EXPECT_EQ(out_len, rtcp_len + 4 + rtcp_auth_tag_len(cs2)); // NOLINT
EXPECT_NE(0, memcmp(rtcp_packet, kRtcpReport, rtcp_len));
EXPECT_TRUE(f1_.UnprotectRtcp(rtcp_packet, out_len, &out_len));
EXPECT_EQ(rtcp_len, out_len);
EXPECT_EQ(0, memcmp(rtcp_packet, kRtcpReport, rtcp_len));
}
void TestProtectUnprotectHeaderEncryption(
const std::string& cs1,
const std::string& cs2,
const std::vector<int>& encrypted_header_ids) {
Buffer rtp_buffer(sizeof(kPcmuFrameWithExtensions) + rtp_auth_tag_len(cs1));
char* rtp_packet = rtp_buffer.data<char>();
size_t rtp_packet_size = rtp_buffer.size();
char original_rtp_packet[sizeof(kPcmuFrameWithExtensions)];
size_t original_rtp_packet_size = sizeof(original_rtp_packet);
int rtp_len = sizeof(kPcmuFrameWithExtensions), out_len;
memcpy(rtp_packet, kPcmuFrameWithExtensions, rtp_len);
// In order to be able to run this test function multiple times we can not
// use the same sequence number twice. Increase the sequence number by one.
SetBE16(reinterpret_cast<uint8_t*>(rtp_packet) + 2, ++sequence_number_);
memcpy(original_rtp_packet, rtp_packet, rtp_len);
EXPECT_TRUE(f1_.ProtectRtp(rtp_packet, rtp_len,
static_cast<int>(rtp_buffer.size()), &out_len));
EXPECT_EQ(out_len, rtp_len + rtp_auth_tag_len(cs1));
EXPECT_NE(0, memcmp(rtp_packet, original_rtp_packet, rtp_len));
CompareHeaderExtensions(rtp_packet, rtp_packet_size, original_rtp_packet,
original_rtp_packet_size, encrypted_header_ids,
false);
EXPECT_TRUE(f2_.UnprotectRtp(rtp_packet, out_len, &out_len));
EXPECT_EQ(rtp_len, out_len);
EXPECT_EQ(0, memcmp(rtp_packet, original_rtp_packet, rtp_len));
CompareHeaderExtensions(rtp_packet, rtp_packet_size, original_rtp_packet,
original_rtp_packet_size, encrypted_header_ids,
true);
EXPECT_TRUE(f2_.ProtectRtp(rtp_packet, rtp_len,
static_cast<int>(rtp_buffer.size()), &out_len));
EXPECT_EQ(out_len, rtp_len + rtp_auth_tag_len(cs2));
EXPECT_NE(0, memcmp(rtp_packet, original_rtp_packet, rtp_len));
CompareHeaderExtensions(rtp_packet, rtp_packet_size, original_rtp_packet,
original_rtp_packet_size, encrypted_header_ids,
false);
EXPECT_TRUE(f1_.UnprotectRtp(rtp_packet, out_len, &out_len));
EXPECT_EQ(rtp_len, out_len);
EXPECT_EQ(0, memcmp(rtp_packet, original_rtp_packet, rtp_len));
CompareHeaderExtensions(rtp_packet, rtp_packet_size, original_rtp_packet,
original_rtp_packet_size, encrypted_header_ids,
true);
}
void TestProtectSetParamsDirect(bool enable_external_auth,
int cs,
const uint8_t* key1,
int key1_len,
const uint8_t* key2,
int key2_len,
const std::string& cs_name) {
EXPECT_EQ(key1_len, key2_len);
EXPECT_EQ(cs_name, SrtpCryptoSuiteToName(cs));
if (enable_external_auth) {
f1_.EnableExternalAuth();
f2_.EnableExternalAuth();
}
EXPECT_TRUE(f1_.SetRtpParams(cs, key1, key1_len, cs, key2, key2_len));
EXPECT_TRUE(f2_.SetRtpParams(cs, key2, key2_len, cs, key1, key1_len));
EXPECT_TRUE(f1_.SetRtcpParams(cs, key1, key1_len, cs, key2, key2_len));
EXPECT_TRUE(f2_.SetRtcpParams(cs, key2, key2_len, cs, key1, key1_len));
EXPECT_TRUE(f1_.IsActive());
EXPECT_TRUE(f2_.IsActive());
if (IsGcmCryptoSuite(cs)) {
EXPECT_FALSE(f1_.IsExternalAuthActive());
EXPECT_FALSE(f2_.IsExternalAuthActive());
} else if (enable_external_auth) {
EXPECT_TRUE(f1_.IsExternalAuthActive());
EXPECT_TRUE(f2_.IsExternalAuthActive());
}
TestProtectUnprotect(cs_name, cs_name);
}
void TestProtectSetParamsDirectHeaderEncryption(int cs,
const uint8_t* key1,
int key1_len,
const uint8_t* key2,
int key2_len,
const std::string& cs_name) {
std::vector<int> encrypted_headers;
encrypted_headers.push_back(1);
// Don't encrypt header ids 2 and 3.
encrypted_headers.push_back(4);
EXPECT_EQ(key1_len, key2_len);
EXPECT_EQ(cs_name, SrtpCryptoSuiteToName(cs));
f1_.SetEncryptedHeaderExtensionIds(CS_LOCAL, encrypted_headers);
f1_.SetEncryptedHeaderExtensionIds(CS_REMOTE, encrypted_headers);
f2_.SetEncryptedHeaderExtensionIds(CS_LOCAL, encrypted_headers);
f2_.SetEncryptedHeaderExtensionIds(CS_REMOTE, encrypted_headers);
EXPECT_TRUE(f1_.SetRtpParams(cs, key1, key1_len, cs, key2, key2_len));
EXPECT_TRUE(f2_.SetRtpParams(cs, key2, key2_len, cs, key1, key1_len));
EXPECT_TRUE(f1_.IsActive());
EXPECT_TRUE(f2_.IsActive());
EXPECT_FALSE(f1_.IsExternalAuthActive());
EXPECT_FALSE(f2_.IsExternalAuthActive());
TestProtectUnprotectHeaderEncryption(cs_name, cs_name, encrypted_headers);
}
cricket::SrtpFilter f1_; cricket::SrtpFilter f1_;
cricket::SrtpFilter f2_; cricket::SrtpFilter f2_;
int sequence_number_;
}; };
// Test that we can set up the session and keys properly. // Test that we can set up the session and keys properly.
@ -293,6 +479,21 @@ TEST_F(SrtpFilterTest, TestUnsupportedOptions) {
EXPECT_FALSE(f1_.IsActive()); EXPECT_FALSE(f1_.IsActive());
} }
// Test that we can encrypt/decrypt after setting the same CryptoParams again on
// one side.
TEST_F(SrtpFilterTest, TestSettingSameKeyOnOneSide) {
std::vector<CryptoParams> offer(MakeVector(kTestCryptoParams1));
std::vector<CryptoParams> answer(MakeVector(kTestCryptoParams2));
TestSetParams(offer, answer);
TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_80, CS_AES_CM_128_HMAC_SHA1_80);
// Re-applying the same keys on one end and it should not reset the ROC.
EXPECT_TRUE(f2_.SetOffer(offer, CS_REMOTE));
EXPECT_TRUE(f2_.SetAnswer(answer, CS_LOCAL));
TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_80, CS_AES_CM_128_HMAC_SHA1_80);
}
// Test that we can encrypt/decrypt after negotiating AES_CM_128_HMAC_SHA1_80. // Test that we can encrypt/decrypt after negotiating AES_CM_128_HMAC_SHA1_80.
TEST_F(SrtpFilterTest, TestProtect_AES_CM_128_HMAC_SHA1_80) { TEST_F(SrtpFilterTest, TestProtect_AES_CM_128_HMAC_SHA1_80) {
std::vector<CryptoParams> offer(MakeVector(kTestCryptoParams1)); std::vector<CryptoParams> offer(MakeVector(kTestCryptoParams1));
@ -301,8 +502,7 @@ TEST_F(SrtpFilterTest, TestProtect_AES_CM_128_HMAC_SHA1_80) {
offer[1].tag = 2; offer[1].tag = 2;
offer[1].cipher_suite = CS_AES_CM_128_HMAC_SHA1_32; offer[1].cipher_suite = CS_AES_CM_128_HMAC_SHA1_32;
TestSetParams(offer, answer); TestSetParams(offer, answer);
VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_80, TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_80, CS_AES_CM_128_HMAC_SHA1_80);
CS_AES_CM_128_HMAC_SHA1_80);
} }
// Test that we can encrypt/decrypt after negotiating AES_CM_128_HMAC_SHA1_32. // Test that we can encrypt/decrypt after negotiating AES_CM_128_HMAC_SHA1_32.
@ -315,8 +515,7 @@ TEST_F(SrtpFilterTest, TestProtect_AES_CM_128_HMAC_SHA1_32) {
answer[0].tag = 2; answer[0].tag = 2;
answer[0].cipher_suite = CS_AES_CM_128_HMAC_SHA1_32; answer[0].cipher_suite = CS_AES_CM_128_HMAC_SHA1_32;
TestSetParams(offer, answer); TestSetParams(offer, answer);
VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_32, TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_32, CS_AES_CM_128_HMAC_SHA1_32);
CS_AES_CM_128_HMAC_SHA1_32);
} }
// Test that we can change encryption parameters. // Test that we can change encryption parameters.
@ -325,8 +524,7 @@ TEST_F(SrtpFilterTest, TestChangeParameters) {
std::vector<CryptoParams> answer(MakeVector(kTestCryptoParams2)); std::vector<CryptoParams> answer(MakeVector(kTestCryptoParams2));
TestSetParams(offer, answer); TestSetParams(offer, answer);
VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_80, TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_80, CS_AES_CM_128_HMAC_SHA1_80);
CS_AES_CM_128_HMAC_SHA1_80);
// Change the key parameters and cipher_suite. // Change the key parameters and cipher_suite.
offer[0].key_params = kTestKeyParams3; offer[0].key_params = kTestKeyParams3;
@ -340,15 +538,13 @@ TEST_F(SrtpFilterTest, TestChangeParameters) {
EXPECT_TRUE(f1_.IsActive()); EXPECT_TRUE(f1_.IsActive());
// Test that the old keys are valid until the negotiation is complete. // Test that the old keys are valid until the negotiation is complete.
VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_80, TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_80, CS_AES_CM_128_HMAC_SHA1_80);
CS_AES_CM_128_HMAC_SHA1_80);
// Complete the negotiation and test that we can still understand each other. // Complete the negotiation and test that we can still understand each other.
EXPECT_TRUE(f2_.SetAnswer(answer, CS_LOCAL)); EXPECT_TRUE(f2_.SetAnswer(answer, CS_LOCAL));
EXPECT_TRUE(f1_.SetAnswer(answer, CS_REMOTE)); EXPECT_TRUE(f1_.SetAnswer(answer, CS_REMOTE));
VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_32, TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_32, CS_AES_CM_128_HMAC_SHA1_32);
CS_AES_CM_128_HMAC_SHA1_32);
} }
// Test that we can send and receive provisional answers with crypto enabled. // Test that we can send and receive provisional answers with crypto enabled.
@ -368,8 +564,7 @@ TEST_F(SrtpFilterTest, TestProvisionalAnswer) {
EXPECT_TRUE(f1_.SetProvisionalAnswer(answer, CS_REMOTE)); EXPECT_TRUE(f1_.SetProvisionalAnswer(answer, CS_REMOTE));
EXPECT_TRUE(f1_.IsActive()); EXPECT_TRUE(f1_.IsActive());
EXPECT_TRUE(f2_.IsActive()); EXPECT_TRUE(f2_.IsActive());
VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_80, TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_80, CS_AES_CM_128_HMAC_SHA1_80);
CS_AES_CM_128_HMAC_SHA1_80);
answer[0].key_params = kTestKeyParams4; answer[0].key_params = kTestKeyParams4;
answer[0].tag = 2; answer[0].tag = 2;
@ -378,8 +573,7 @@ TEST_F(SrtpFilterTest, TestProvisionalAnswer) {
EXPECT_TRUE(f1_.SetAnswer(answer, CS_REMOTE)); EXPECT_TRUE(f1_.SetAnswer(answer, CS_REMOTE));
EXPECT_TRUE(f1_.IsActive()); EXPECT_TRUE(f1_.IsActive());
EXPECT_TRUE(f2_.IsActive()); EXPECT_TRUE(f2_.IsActive());
VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_32, TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_32, CS_AES_CM_128_HMAC_SHA1_32);
CS_AES_CM_128_HMAC_SHA1_32);
} }
// Test that a provisional answer doesn't need to contain a crypto. // Test that a provisional answer doesn't need to contain a crypto.
@ -401,8 +595,7 @@ TEST_F(SrtpFilterTest, TestProvisionalAnswerWithoutCrypto) {
EXPECT_TRUE(f1_.SetAnswer(answer, CS_REMOTE)); EXPECT_TRUE(f1_.SetAnswer(answer, CS_REMOTE));
EXPECT_TRUE(f1_.IsActive()); EXPECT_TRUE(f1_.IsActive());
EXPECT_TRUE(f2_.IsActive()); EXPECT_TRUE(f2_.IsActive());
VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_80, TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_80, CS_AES_CM_128_HMAC_SHA1_80);
CS_AES_CM_128_HMAC_SHA1_80);
} }
// Test that if we get a new local offer after a provisional answer // Test that if we get a new local offer after a provisional answer
@ -429,8 +622,7 @@ TEST_F(SrtpFilterTest, TestLocalOfferAfterProvisionalAnswerWithoutCrypto) {
EXPECT_TRUE(f1_.SetAnswer(answer, CS_REMOTE)); EXPECT_TRUE(f1_.SetAnswer(answer, CS_REMOTE));
EXPECT_TRUE(f1_.IsActive()); EXPECT_TRUE(f1_.IsActive());
EXPECT_TRUE(f2_.IsActive()); EXPECT_TRUE(f2_.IsActive());
VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_80, TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_80, CS_AES_CM_128_HMAC_SHA1_80);
CS_AES_CM_128_HMAC_SHA1_80);
} }
// Test that we can disable encryption. // Test that we can disable encryption.
@ -439,8 +631,7 @@ TEST_F(SrtpFilterTest, TestDisableEncryption) {
std::vector<CryptoParams> answer(MakeVector(kTestCryptoParams2)); std::vector<CryptoParams> answer(MakeVector(kTestCryptoParams2));
TestSetParams(offer, answer); TestSetParams(offer, answer);
VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_80, TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_80, CS_AES_CM_128_HMAC_SHA1_80);
CS_AES_CM_128_HMAC_SHA1_80);
offer.clear(); offer.clear();
answer.clear(); answer.clear();
@ -450,8 +641,7 @@ TEST_F(SrtpFilterTest, TestDisableEncryption) {
EXPECT_TRUE(f2_.IsActive()); EXPECT_TRUE(f2_.IsActive());
// Test that the old keys are valid until the negotiation is complete. // Test that the old keys are valid until the negotiation is complete.
VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_80, TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_80, CS_AES_CM_128_HMAC_SHA1_80);
CS_AES_CM_128_HMAC_SHA1_80);
// Complete the negotiation. // Complete the negotiation.
EXPECT_TRUE(f2_.SetAnswer(answer, CS_LOCAL)); EXPECT_TRUE(f2_.SetAnswer(answer, CS_LOCAL));
@ -461,4 +651,85 @@ TEST_F(SrtpFilterTest, TestDisableEncryption) {
EXPECT_FALSE(f2_.IsActive()); EXPECT_FALSE(f2_.IsActive());
} }
class SrtpFilterProtectSetParamsDirectTest
: public SrtpFilterTest,
public testing::WithParamInterface<bool> {};
// Test directly setting the params with AES_CM_128_HMAC_SHA1_80.
TEST_P(SrtpFilterProtectSetParamsDirectTest, Test_AES_CM_128_HMAC_SHA1_80) {
bool enable_external_auth = GetParam();
TestProtectSetParamsDirect(enable_external_auth, SRTP_AES128_CM_SHA1_80,
kTestKey1, kTestKeyLen, kTestKey2, kTestKeyLen,
CS_AES_CM_128_HMAC_SHA1_80);
}
TEST_F(SrtpFilterTest,
TestProtectSetParamsDirectHeaderEncryption_AES_CM_128_HMAC_SHA1_80) {
TestProtectSetParamsDirectHeaderEncryption(
SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, kTestKey2, kTestKeyLen,
CS_AES_CM_128_HMAC_SHA1_80);
}
// Test directly setting the params with AES_CM_128_HMAC_SHA1_32.
TEST_P(SrtpFilterProtectSetParamsDirectTest, Test_AES_CM_128_HMAC_SHA1_32) {
bool enable_external_auth = GetParam();
TestProtectSetParamsDirect(enable_external_auth, SRTP_AES128_CM_SHA1_32,
kTestKey1, kTestKeyLen, kTestKey2, kTestKeyLen,
CS_AES_CM_128_HMAC_SHA1_32);
}
TEST_F(SrtpFilterTest,
TestProtectSetParamsDirectHeaderEncryption_AES_CM_128_HMAC_SHA1_32) {
TestProtectSetParamsDirectHeaderEncryption(
SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen, kTestKey2, kTestKeyLen,
CS_AES_CM_128_HMAC_SHA1_32);
}
// Test directly setting the params with SRTP_AEAD_AES_128_GCM.
TEST_P(SrtpFilterProtectSetParamsDirectTest, Test_SRTP_AEAD_AES_128_GCM) {
bool enable_external_auth = GetParam();
TestProtectSetParamsDirect(enable_external_auth, SRTP_AEAD_AES_128_GCM,
kTestKeyGcm128_1, kTestKeyGcm128Len,
kTestKeyGcm128_2, kTestKeyGcm128Len,
CS_AEAD_AES_128_GCM);
}
TEST_F(SrtpFilterTest,
TestProtectSetParamsDirectHeaderEncryption_SRTP_AEAD_AES_128_GCM) {
TestProtectSetParamsDirectHeaderEncryption(
SRTP_AEAD_AES_128_GCM, kTestKeyGcm128_1, kTestKeyGcm128Len,
kTestKeyGcm128_2, kTestKeyGcm128Len, CS_AEAD_AES_128_GCM);
}
// Test directly setting the params with SRTP_AEAD_AES_256_GCM.
TEST_P(SrtpFilterProtectSetParamsDirectTest, Test_SRTP_AEAD_AES_256_GCM) {
bool enable_external_auth = GetParam();
TestProtectSetParamsDirect(enable_external_auth, SRTP_AEAD_AES_256_GCM,
kTestKeyGcm256_1, kTestKeyGcm256Len,
kTestKeyGcm256_2, kTestKeyGcm256Len,
CS_AEAD_AES_256_GCM);
}
TEST_F(SrtpFilterTest,
TestProtectSetParamsDirectHeaderEncryption_SRTP_AEAD_AES_256_GCM) {
TestProtectSetParamsDirectHeaderEncryption(
SRTP_AEAD_AES_256_GCM, kTestKeyGcm256_1, kTestKeyGcm256Len,
kTestKeyGcm256_2, kTestKeyGcm256Len, CS_AEAD_AES_256_GCM);
}
// Run all tests both with and without external auth enabled.
INSTANTIATE_TEST_CASE_P(ExternalAuth,
SrtpFilterProtectSetParamsDirectTest,
::testing::Values(true, false));
// Test directly setting the params with bogus keys.
TEST_F(SrtpFilterTest, TestSetParamsKeyTooShort) {
EXPECT_FALSE(f1_.SetRtpParams(SRTP_AES128_CM_SHA1_80, kTestKey1,
kTestKeyLen - 1, SRTP_AES128_CM_SHA1_80,
kTestKey1, kTestKeyLen - 1));
EXPECT_FALSE(f1_.SetRtcpParams(SRTP_AES128_CM_SHA1_80, kTestKey1,
kTestKeyLen - 1, SRTP_AES128_CM_SHA1_80,
kTestKey1, kTestKeyLen - 1));
}
} // namespace rtc } // namespace rtc

View File

@ -16,7 +16,6 @@
#include "pc/rtptransport.h" #include "pc/rtptransport.h"
#include "pc/srtpsession.h" #include "pc/srtpsession.h"
#include "rtc_base/asyncpacketsocket.h" #include "rtc_base/asyncpacketsocket.h"
#include "rtc_base/base64.h"
#include "rtc_base/copyonwritebuffer.h" #include "rtc_base/copyonwritebuffer.h"
#include "rtc_base/ptr_util.h" #include "rtc_base/ptr_util.h"
#include "rtc_base/trace_event.h" #include "rtc_base/trace_event.h"
@ -43,322 +42,21 @@ void SrtpTransport::ConnectToRtpTransport() {
&SrtpTransport::OnReadyToSend); &SrtpTransport::OnReadyToSend);
} }
bool SrtpTransport::SendRtpPacket(rtc::CopyOnWriteBuffer* packet,
const rtc::PacketOptions& options,
int flags) {
return SendPacket(false, packet, options, flags);
}
bool SrtpTransport::SendRtcpPacket(rtc::CopyOnWriteBuffer* packet,
const rtc::PacketOptions& options,
int flags) {
return SendPacket(true, packet, options, flags);
}
bool SrtpTransport::SendPacket(bool rtcp, bool SrtpTransport::SendPacket(bool rtcp,
rtc::CopyOnWriteBuffer* packet, rtc::CopyOnWriteBuffer* packet,
const rtc::PacketOptions& options, const rtc::PacketOptions& options,
int flags) { int flags) {
if (!IsActive()) { // TODO(zstein): Protect packet.
LOG(LS_ERROR)
<< "Failed to send the packet because SRTP transport is inactive.";
return false;
}
rtc::PacketOptions updated_options = options; return rtp_transport_->SendPacket(rtcp, packet, options, flags);
rtc::CopyOnWriteBuffer cp = *packet;
TRACE_EVENT0("webrtc", "SRTP Encode");
bool res;
uint8_t* data = packet->data();
int len = static_cast<int>(packet->size());
if (!rtcp) {
// If ENABLE_EXTERNAL_AUTH flag is on then packet authentication is not done
// inside libsrtp for a RTP packet. A external HMAC module will be writing
// a fake HMAC value. This is ONLY done for a RTP packet.
// Socket layer will update rtp sendtime extension header if present in
// packet with current time before updating the HMAC.
#if !defined(ENABLE_EXTERNAL_AUTH)
res = ProtectRtp(data, len, static_cast<int>(packet->capacity()), &len);
#else
if (!IsExternalAuthActive()) {
res = ProtectRtp(data, len, static_cast<int>(packet->capacity()), &len);
} else {
updated_options.packet_time_params.rtp_sendtime_extension_id =
rtp_abs_sendtime_extn_id_;
res = ProtectRtp(data, len, static_cast<int>(packet->capacity()), &len,
&updated_options.packet_time_params.srtp_packet_index);
// If protection succeeds, let's get auth params from srtp.
if (res) {
uint8_t* auth_key = NULL;
int key_len;
res = GetRtpAuthParams(
&auth_key, &key_len,
&updated_options.packet_time_params.srtp_auth_tag_len);
if (res) {
updated_options.packet_time_params.srtp_auth_key.resize(key_len);
updated_options.packet_time_params.srtp_auth_key.assign(
auth_key, auth_key + key_len);
}
}
}
#endif
if (!res) {
int seq_num = -1;
uint32_t ssrc = 0;
cricket::GetRtpSeqNum(data, len, &seq_num);
cricket::GetRtpSsrc(data, len, &ssrc);
LOG(LS_ERROR) << "Failed to protect " << content_name_
<< " RTP packet: size=" << len << ", seqnum=" << seq_num
<< ", SSRC=" << ssrc;
return false;
}
} else {
res = ProtectRtcp(data, len, static_cast<int>(packet->capacity()), &len);
if (!res) {
int type = -1;
cricket::GetRtcpType(data, len, &type);
LOG(LS_ERROR) << "Failed to protect " << content_name_
<< " RTCP packet: size=" << len << ", type=" << type;
return false;
}
}
// Update the length of the packet now that we've added the auth tag.
packet->SetSize(len);
return rtcp ? rtp_transport_->SendRtcpPacket(packet, updated_options, flags)
: rtp_transport_->SendRtpPacket(packet, updated_options, flags);
} }
void SrtpTransport::OnPacketReceived(bool rtcp, void SrtpTransport::OnPacketReceived(bool rtcp,
rtc::CopyOnWriteBuffer* packet, rtc::CopyOnWriteBuffer* packet,
const rtc::PacketTime& packet_time) { const rtc::PacketTime& packet_time) {
if (!IsActive()) { // TODO(zstein): Unprotect packet.
LOG(LS_WARNING) << "Inactive SRTP transport received a packet. Drop it.";
return;
}
TRACE_EVENT0("webrtc", "SRTP Decode");
char* data = packet->data<char>();
int len = static_cast<int>(packet->size());
bool res;
if (!rtcp) {
res = UnprotectRtp(data, len, &len);
if (!res) {
int seq_num = -1;
uint32_t ssrc = 0;
cricket::GetRtpSeqNum(data, len, &seq_num);
cricket::GetRtpSsrc(data, len, &ssrc);
LOG(LS_ERROR) << "Failed to unprotect " << content_name_
<< " RTP packet: size=" << len << ", seqnum=" << seq_num
<< ", SSRC=" << ssrc;
return;
}
} else {
res = UnprotectRtcp(data, len, &len);
if (!res) {
int type = -1;
cricket::GetRtcpType(data, len, &type);
LOG(LS_ERROR) << "Failed to unprotect " << content_name_
<< " RTCP packet: size=" << len << ", type=" << type;
return;
}
}
packet->SetSize(len);
SignalPacketReceived(rtcp, packet, packet_time); SignalPacketReceived(rtcp, packet, packet_time);
} }
bool SrtpTransport::SetRtpParams(int send_cs,
const uint8_t* send_key,
int send_key_len,
int recv_cs,
const uint8_t* recv_key,
int recv_key_len) {
CreateSrtpSessions();
send_session_->SetEncryptedHeaderExtensionIds(
send_encrypted_header_extension_ids_);
if (external_auth_enabled_) {
send_session_->EnableExternalAuth();
}
if (!send_session_->SetSend(send_cs, send_key, send_key_len)) {
ResetParams();
return false;
}
recv_session_->SetEncryptedHeaderExtensionIds(
recv_encrypted_header_extension_ids_);
if (!recv_session_->SetRecv(recv_cs, recv_key, recv_key_len)) {
ResetParams();
return false;
}
LOG(LS_INFO) << "SRTP activated with negotiated parameters:"
<< " send cipher_suite " << send_cs << " recv cipher_suite "
<< recv_cs;
return true;
}
bool SrtpTransport::SetRtcpParams(int send_cs,
const uint8_t* send_key,
int send_key_len,
int recv_cs,
const uint8_t* recv_key,
int recv_key_len) {
// This can only be called once, but can be safely called after
// SetRtpParams
if (send_rtcp_session_ || recv_rtcp_session_) {
LOG(LS_ERROR) << "Tried to set SRTCP Params when filter already active";
return false;
}
send_rtcp_session_.reset(new cricket::SrtpSession());
if (!send_rtcp_session_->SetRecv(send_cs, send_key, send_key_len)) {
return false;
}
recv_rtcp_session_.reset(new cricket::SrtpSession());
if (!recv_rtcp_session_->SetRecv(recv_cs, recv_key, recv_key_len)) {
return false;
}
LOG(LS_INFO) << "SRTCP activated with negotiated parameters:"
<< " send cipher_suite " << send_cs << " recv cipher_suite "
<< recv_cs;
return true;
}
bool SrtpTransport::IsActive() const {
return send_session_ && recv_session_;
}
void SrtpTransport::ResetParams() {
send_session_ = nullptr;
recv_session_ = nullptr;
send_rtcp_session_ = nullptr;
recv_rtcp_session_ = nullptr;
LOG(LS_INFO) << "The params in SRTP transport are reset.";
}
void SrtpTransport::SetEncryptedHeaderExtensionIds(
cricket::ContentSource source,
const std::vector<int>& extension_ids) {
if (source == cricket::CS_LOCAL) {
recv_encrypted_header_extension_ids_ = extension_ids;
} else {
send_encrypted_header_extension_ids_ = extension_ids;
}
}
void SrtpTransport::CreateSrtpSessions() {
send_session_.reset(new cricket::SrtpSession());
recv_session_.reset(new cricket::SrtpSession());
if (external_auth_enabled_) {
send_session_->EnableExternalAuth();
}
}
bool SrtpTransport::ProtectRtp(void* p, int in_len, int max_len, int* out_len) {
if (!IsActive()) {
LOG(LS_WARNING) << "Failed to ProtectRtp: SRTP not active";
return false;
}
RTC_CHECK(send_session_);
return send_session_->ProtectRtp(p, in_len, max_len, out_len);
}
bool SrtpTransport::ProtectRtp(void* p,
int in_len,
int max_len,
int* out_len,
int64_t* index) {
if (!IsActive()) {
LOG(LS_WARNING) << "Failed to ProtectRtp: SRTP not active";
return false;
}
RTC_CHECK(send_session_);
return send_session_->ProtectRtp(p, in_len, max_len, out_len, index);
}
bool SrtpTransport::ProtectRtcp(void* p,
int in_len,
int max_len,
int* out_len) {
if (!IsActive()) {
LOG(LS_WARNING) << "Failed to ProtectRtcp: SRTP not active";
return false;
}
if (send_rtcp_session_) {
return send_rtcp_session_->ProtectRtcp(p, in_len, max_len, out_len);
} else {
RTC_CHECK(send_session_);
return send_session_->ProtectRtcp(p, in_len, max_len, out_len);
}
}
bool SrtpTransport::UnprotectRtp(void* p, int in_len, int* out_len) {
if (!IsActive()) {
LOG(LS_WARNING) << "Failed to UnprotectRtp: SRTP not active";
return false;
}
RTC_CHECK(recv_session_);
return recv_session_->UnprotectRtp(p, in_len, out_len);
}
bool SrtpTransport::UnprotectRtcp(void* p, int in_len, int* out_len) {
if (!IsActive()) {
LOG(LS_WARNING) << "Failed to UnprotectRtcp: SRTP not active";
return false;
}
if (recv_rtcp_session_) {
return recv_rtcp_session_->UnprotectRtcp(p, in_len, out_len);
} else {
RTC_CHECK(recv_session_);
return recv_session_->UnprotectRtcp(p, in_len, out_len);
}
}
bool SrtpTransport::GetRtpAuthParams(uint8_t** key,
int* key_len,
int* tag_len) {
if (!IsActive()) {
LOG(LS_WARNING) << "Failed to GetRtpAuthParams: SRTP not active";
return false;
}
RTC_CHECK(send_session_);
return send_session_->GetRtpAuthParams(key, key_len, tag_len);
}
bool SrtpTransport::GetSrtpOverhead(int* srtp_overhead) const {
if (!IsActive()) {
LOG(LS_WARNING) << "Failed to GetSrtpOverhead: SRTP not active";
return false;
}
RTC_CHECK(send_session_);
*srtp_overhead = send_session_->GetSrtpOverhead();
return true;
}
void SrtpTransport::EnableExternalAuth() {
RTC_DCHECK(!IsActive());
external_auth_enabled_ = true;
}
bool SrtpTransport::IsExternalAuthEnabled() const {
return external_auth_enabled_;
}
bool SrtpTransport::IsExternalAuthActive() const {
if (!IsActive()) {
LOG(LS_WARNING) << "Failed to check IsExternalAuthActive: SRTP not active";
return false;
}
RTC_CHECK(send_session_);
return send_session_->IsExternalAuthActive();
}
} // namespace webrtc } // namespace webrtc

View File

@ -17,17 +17,20 @@
#include "pc/rtptransportinternal.h" #include "pc/rtptransportinternal.h"
#include "pc/srtpfilter.h" #include "pc/srtpfilter.h"
#include "pc/srtpsession.h"
#include "rtc_base/checks.h" #include "rtc_base/checks.h"
namespace webrtc { namespace webrtc {
// This class will eventually be a wrapper around RtpTransportInternal // This class will eventually be a wrapper around RtpTransportInternal
// that protects and unprotects sent and received RTP packets. // that protects and unprotects sent and received RTP packets. This
// functionality is currently implemented by SrtpFilter and BaseChannel, but
// will be moved here in the future.
class SrtpTransport : public RtpTransportInternal { class SrtpTransport : public RtpTransportInternal {
public: public:
SrtpTransport(bool rtcp_mux_enabled, const std::string& content_name); SrtpTransport(bool rtcp_mux_enabled, const std::string& content_name);
// TODO(zstein): Consider taking an RtpTransport instead of an
// RtpTransportInternal.
SrtpTransport(std::unique_ptr<RtpTransportInternal> transport, SrtpTransport(std::unique_ptr<RtpTransportInternal> transport,
const std::string& content_name); const std::string& content_name);
@ -58,21 +61,14 @@ class SrtpTransport : public RtpTransportInternal {
return rtp_transport_->GetRtcpPacketTransport(); return rtp_transport_->GetRtcpPacketTransport();
} }
bool SendRtpPacket(rtc::CopyOnWriteBuffer* packet,
const rtc::PacketOptions& options,
int flags) override;
bool SendRtcpPacket(rtc::CopyOnWriteBuffer* packet,
const rtc::PacketOptions& options,
int flags) override;
bool IsWritable(bool rtcp) const override { bool IsWritable(bool rtcp) const override {
return rtp_transport_->IsWritable(rtcp); return rtp_transport_->IsWritable(rtcp);
} }
// The transport becomes active if the send_session_ and recv_session_ are bool SendPacket(bool rtcp,
// created. rtc::CopyOnWriteBuffer* packet,
bool IsActive() const; const rtc::PacketOptions& options,
int flags) override;
bool HandlesPayloadType(int payload_type) const override { bool HandlesPayloadType(int payload_type) const override {
return rtp_transport_->HandlesPayloadType(payload_type); return rtp_transport_->HandlesPayloadType(payload_type);
@ -93,104 +89,18 @@ class SrtpTransport : public RtpTransportInternal {
// TODO(zstein): Remove this when we remove RtpTransportAdapter. // TODO(zstein): Remove this when we remove RtpTransportAdapter.
RtpTransportAdapter* GetInternal() override { return nullptr; } RtpTransportAdapter* GetInternal() override { return nullptr; }
// Create new send/recv sessions and set the negotiated crypto keys for RTP
// packet encryption. The keys can either come from SDES negotiation or DTLS
// handshake.
bool SetRtpParams(int send_cs,
const uint8_t* send_key,
int send_key_len,
int recv_cs,
const uint8_t* recv_key,
int recv_key_len);
// Create new send/recv sessions and set the negotiated crypto keys for RTCP
// packet encryption. The keys can either come from SDES negotiation or DTLS
// handshake.
bool SetRtcpParams(int send_cs,
const uint8_t* send_key,
int send_key_len,
int recv_cs,
const uint8_t* recv_key,
int recv_key_len);
void ResetParams();
// Set the header extension ids that should be encrypted for the given source.
// This method doesn't immediately update the SRTP session with the new IDs,
// and you need to call SetRtpParams for that to happen.
void SetEncryptedHeaderExtensionIds(cricket::ContentSource source,
const std::vector<int>& extension_ids);
// If external auth is enabled, SRTP will write a dummy auth tag that then
// later must get replaced before the packet is sent out. Only supported for
// non-GCM cipher suites and can be checked through "IsExternalAuthActive"
// if it is actually used. This method is only valid before the RTP params
// have been set.
void EnableExternalAuth();
bool IsExternalAuthEnabled() const;
// A SrtpTransport supports external creation of the auth tag if a non-GCM
// cipher is used. This method is only valid after the RTP params have
// been set.
bool IsExternalAuthActive() const;
// Returns srtp overhead for rtp packets.
bool GetSrtpOverhead(int* srtp_overhead) const;
// Returns rtp auth params from srtp context.
bool GetRtpAuthParams(uint8_t** key, int* key_len, int* tag_len);
// Helper method to get RTP Absoulute SendTime extension header id if
// present in remote supported extensions list.
void CacheRtpAbsSendTimeHeaderExtension(int rtp_abs_sendtime_extn_id) {
rtp_abs_sendtime_extn_id_ = rtp_abs_sendtime_extn_id;
}
private: private:
void CreateSrtpSessions();
void ConnectToRtpTransport(); void ConnectToRtpTransport();
bool SendPacket(bool rtcp,
rtc::CopyOnWriteBuffer* packet,
const rtc::PacketOptions& options,
int flags);
void OnPacketReceived(bool rtcp, void OnPacketReceived(bool rtcp,
rtc::CopyOnWriteBuffer* packet, rtc::CopyOnWriteBuffer* packet,
const rtc::PacketTime& packet_time); const rtc::PacketTime& packet_time);
void OnReadyToSend(bool ready) { SignalReadyToSend(ready); } void OnReadyToSend(bool ready) { SignalReadyToSend(ready); }
bool ProtectRtp(void* data, int in_len, int max_len, int* out_len);
// Overloaded version, outputs packet index.
bool ProtectRtp(void* data,
int in_len,
int max_len,
int* out_len,
int64_t* index);
bool ProtectRtcp(void* data, int in_len, int max_len, int* out_len);
// Decrypts/verifies an invidiual RTP/RTCP packet.
// If an HMAC is used, this will decrease the packet size.
bool UnprotectRtp(void* data, int in_len, int* out_len);
bool UnprotectRtcp(void* data, int in_len, int* out_len);
const std::string content_name_; const std::string content_name_;
std::unique_ptr<RtpTransportInternal> rtp_transport_; std::unique_ptr<RtpTransportInternal> rtp_transport_;
std::unique_ptr<cricket::SrtpSession> send_session_;
std::unique_ptr<cricket::SrtpSession> recv_session_;
std::unique_ptr<cricket::SrtpSession> send_rtcp_session_;
std::unique_ptr<cricket::SrtpSession> recv_rtcp_session_;
std::vector<int> send_encrypted_header_extension_ids_;
std::vector<int> recv_encrypted_header_extension_ids_;
bool external_auth_enabled_ = false;
int rtp_abs_sendtime_extn_id_ = -1;
}; };
} // namespace webrtc } // namespace webrtc

View File

@ -10,413 +10,67 @@
#include "pc/srtptransport.h" #include "pc/srtptransport.h"
#include "media/base/fakertp.h"
#include "p2p/base/dtlstransportinternal.h"
#include "p2p/base/fakepackettransport.h"
#include "pc/rtptransport.h" #include "pc/rtptransport.h"
#include "pc/rtptransporttestutil.h" #include "pc/rtptransporttestutil.h"
#include "pc/srtptestutil.h"
#include "rtc_base/asyncpacketsocket.h" #include "rtc_base/asyncpacketsocket.h"
#include "rtc_base/gunit.h" #include "rtc_base/gunit.h"
#include "rtc_base/ptr_util.h" #include "rtc_base/ptr_util.h"
#include "rtc_base/sslstreamadapter.h" #include "test/gmock.h"
using rtc::kTestKey1;
using rtc::kTestKey2;
using rtc::kTestKeyLen;
using rtc::SRTP_AEAD_AES_128_GCM;
namespace webrtc { namespace webrtc {
static const uint8_t kTestKeyGcm128_1[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ12";
static const uint8_t kTestKeyGcm128_2[] = "21ZYXWVUTSRQPONMLKJIHGFEDCBA";
static const int kTestKeyGcm128Len = 28; // 128 bits key + 96 bits salt.
static const uint8_t kTestKeyGcm256_1[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqr";
static const uint8_t kTestKeyGcm256_2[] =
"rqponmlkjihgfedcbaZYXWVUTSRQPONMLKJIHGFEDCBA";
static const int kTestKeyGcm256Len = 44; // 256 bits key + 96 bits salt.
class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { using testing::_;
protected: using testing::Return;
SrtpTransportTest() {
bool rtcp_mux_enabled = true;
auto rtp_transport1 = rtc::MakeUnique<RtpTransport>(rtcp_mux_enabled);
auto rtp_transport2 = rtc::MakeUnique<RtpTransport>(rtcp_mux_enabled);
rtp_packet_transport1_ = class MockRtpTransport : public RtpTransport {
rtc::MakeUnique<rtc::FakePacketTransport>("fake_packet_transport1"); public:
rtp_packet_transport2_ = MockRtpTransport() : RtpTransport(true) {}
rtc::MakeUnique<rtc::FakePacketTransport>("fake_packet_transport2");
bool asymmetric = false; MOCK_METHOD4(SendPacket,
rtp_packet_transport1_->SetDestination(rtp_packet_transport2_.get(), bool(bool rtcp,
asymmetric);
rtp_transport1->SetRtpPacketTransport(rtp_packet_transport1_.get());
rtp_transport2->SetRtpPacketTransport(rtp_packet_transport2_.get());
// Add payload type for RTP packet and RTCP packet.
rtp_transport1->AddHandledPayloadType(0x00);
rtp_transport2->AddHandledPayloadType(0x00);
rtp_transport1->AddHandledPayloadType(0xc9);
rtp_transport2->AddHandledPayloadType(0xc9);
srtp_transport1_ =
rtc::MakeUnique<SrtpTransport>(std::move(rtp_transport1), "content");
srtp_transport2_ =
rtc::MakeUnique<SrtpTransport>(std::move(rtp_transport2), "content");
srtp_transport1_->SignalPacketReceived.connect(
this, &SrtpTransportTest::OnPacketReceived1);
srtp_transport2_->SignalPacketReceived.connect(
this, &SrtpTransportTest::OnPacketReceived2);
}
void OnPacketReceived1(bool rtcp,
rtc::CopyOnWriteBuffer* packet, rtc::CopyOnWriteBuffer* packet,
const rtc::PacketTime& packet_time) { const rtc::PacketOptions& options,
LOG(LS_INFO) << "SrtpTransport1 Received a packet."; int flags));
last_recv_packet1_ = *packet;
void PretendReceivedPacket() {
bool rtcp = false;
rtc::CopyOnWriteBuffer buffer;
rtc::PacketTime time;
SignalPacketReceived(rtcp, &buffer, time);
} }
void OnPacketReceived2(bool rtcp,
rtc::CopyOnWriteBuffer* packet,
const rtc::PacketTime& packet_time) {
LOG(LS_INFO) << "SrtpTransport2 Received a packet.";
last_recv_packet2_ = *packet;
}
// With external auth enabled, SRTP doesn't write the auth tag and
// unprotect would fail. Check accessing the information about the
// tag instead, similar to what the actual code would do that relies
// on external auth.
void TestRtpAuthParams(SrtpTransport* transport, const std::string& cs) {
int overhead;
EXPECT_TRUE(transport->GetSrtpOverhead(&overhead));
switch (rtc::SrtpCryptoSuiteFromName(cs)) {
case rtc::SRTP_AES128_CM_SHA1_32:
EXPECT_EQ(32 / 8, overhead); // 32-bit tag.
break;
case rtc::SRTP_AES128_CM_SHA1_80:
EXPECT_EQ(80 / 8, overhead); // 80-bit tag.
break;
default:
RTC_NOTREACHED();
break;
}
uint8_t* auth_key = nullptr;
int key_len = 0;
int tag_len = 0;
EXPECT_TRUE(transport->GetRtpAuthParams(&auth_key, &key_len, &tag_len));
EXPECT_NE(nullptr, auth_key);
EXPECT_EQ(160 / 8, key_len); // Length of SHA-1 is 160 bits.
EXPECT_EQ(overhead, tag_len);
}
void TestSendRecvRtpPacket(const std::string& cipher_suite_name) {
size_t rtp_len = sizeof(kPcmuFrame);
size_t packet_size = rtp_len + rtc::rtp_auth_tag_len(cipher_suite_name);
rtc::Buffer rtp_packet_buffer(packet_size);
char* rtp_packet_data = rtp_packet_buffer.data<char>();
memcpy(rtp_packet_data, kPcmuFrame, rtp_len);
// In order to be able to run this test function multiple times we can not
// use the same sequence number twice. Increase the sequence number by one.
rtc::SetBE16(reinterpret_cast<uint8_t*>(rtp_packet_data) + 2,
++sequence_number_);
rtc::CopyOnWriteBuffer rtp_packet1to2(rtp_packet_data, rtp_len,
packet_size);
rtc::CopyOnWriteBuffer rtp_packet2to1(rtp_packet_data, rtp_len,
packet_size);
char original_rtp_data[sizeof(kPcmuFrame)];
memcpy(original_rtp_data, rtp_packet_data, rtp_len);
rtc::PacketOptions options;
// Send a packet from |srtp_transport1_| to |srtp_transport2_| and verify
// that the packet can be successfully received and decrypted.
ASSERT_TRUE(srtp_transport1_->SendRtpPacket(&rtp_packet1to2, options,
cricket::PF_SRTP_BYPASS));
if (srtp_transport1_->IsExternalAuthActive()) {
TestRtpAuthParams(srtp_transport1_.get(), cipher_suite_name);
} else {
ASSERT_TRUE(last_recv_packet2_.data());
EXPECT_TRUE(
memcmp(last_recv_packet2_.data(), original_rtp_data, rtp_len) == 0);
// Get the encrypted packet from underneath packet transport and verify
// the data is actually encrypted.
auto fake_rtp_packet_transport = static_cast<rtc::FakePacketTransport*>(
srtp_transport1_->rtp_packet_transport());
EXPECT_FALSE(memcmp(fake_rtp_packet_transport->last_sent_packet()->data(),
original_rtp_data, rtp_len) == 0);
}
// Do the same thing in the opposite direction;
ASSERT_TRUE(srtp_transport2_->SendRtpPacket(&rtp_packet2to1, options,
cricket::PF_SRTP_BYPASS));
if (srtp_transport2_->IsExternalAuthActive()) {
TestRtpAuthParams(srtp_transport2_.get(), cipher_suite_name);
} else {
ASSERT_TRUE(last_recv_packet1_.data());
EXPECT_TRUE(
memcmp(last_recv_packet1_.data(), original_rtp_data, rtp_len) == 0);
auto fake_rtp_packet_transport = static_cast<rtc::FakePacketTransport*>(
srtp_transport2_->rtp_packet_transport());
EXPECT_FALSE(memcmp(fake_rtp_packet_transport->last_sent_packet()->data(),
original_rtp_data, rtp_len) == 0);
}
}
void TestSendRecvRtcpPacket(const std::string& cipher_suite_name) {
size_t rtcp_len = sizeof(kRtcpReport);
size_t packet_size =
rtcp_len + 4 + rtc::rtcp_auth_tag_len(cipher_suite_name);
rtc::Buffer rtcp_packet_buffer(packet_size);
char* rtcp_packet_data = rtcp_packet_buffer.data<char>();
memcpy(rtcp_packet_data, kRtcpReport, rtcp_len);
rtc::CopyOnWriteBuffer rtcp_packet1to2(rtcp_packet_data, rtcp_len,
packet_size);
rtc::CopyOnWriteBuffer rtcp_packet2to1(rtcp_packet_data, rtcp_len,
packet_size);
rtc::PacketOptions options;
// Send a packet from |srtp_transport1_| to |srtp_transport2_| and verify
// that the packet can be successfully received and decrypted.
ASSERT_TRUE(srtp_transport1_->SendRtcpPacket(&rtcp_packet1to2, options,
cricket::PF_SRTP_BYPASS));
ASSERT_TRUE(last_recv_packet2_.data());
EXPECT_TRUE(memcmp(last_recv_packet2_.data(), rtcp_packet_data, rtcp_len) ==
0);
// Get the encrypted packet from underneath packet transport and verify the
// data is actually encrypted.
auto fake_rtp_packet_transport = static_cast<rtc::FakePacketTransport*>(
srtp_transport1_->rtp_packet_transport());
EXPECT_FALSE(memcmp(fake_rtp_packet_transport->last_sent_packet()->data(),
rtcp_packet_data, rtcp_len) == 0);
// Do the same thing in the opposite direction;
ASSERT_TRUE(srtp_transport2_->SendRtcpPacket(&rtcp_packet2to1, options,
cricket::PF_SRTP_BYPASS));
ASSERT_TRUE(last_recv_packet1_.data());
EXPECT_TRUE(memcmp(last_recv_packet1_.data(), rtcp_packet_data, rtcp_len) ==
0);
fake_rtp_packet_transport = static_cast<rtc::FakePacketTransport*>(
srtp_transport2_->rtp_packet_transport());
EXPECT_FALSE(memcmp(fake_rtp_packet_transport->last_sent_packet()->data(),
rtcp_packet_data, rtcp_len) == 0);
}
void TestSendRecvPacket(bool enable_external_auth,
int cs,
const uint8_t* key1,
int key1_len,
const uint8_t* key2,
int key2_len,
const std::string& cipher_suite_name) {
EXPECT_EQ(key1_len, key2_len);
EXPECT_EQ(cipher_suite_name, rtc::SrtpCryptoSuiteToName(cs));
if (enable_external_auth) {
srtp_transport1_->EnableExternalAuth();
srtp_transport2_->EnableExternalAuth();
}
EXPECT_TRUE(
srtp_transport1_->SetRtpParams(cs, key1, key1_len, cs, key2, key2_len));
EXPECT_TRUE(
srtp_transport2_->SetRtpParams(cs, key2, key2_len, cs, key1, key1_len));
EXPECT_TRUE(srtp_transport1_->SetRtcpParams(cs, key1, key1_len, cs, key2,
key2_len));
EXPECT_TRUE(srtp_transport2_->SetRtcpParams(cs, key2, key2_len, cs, key1,
key1_len));
EXPECT_TRUE(srtp_transport1_->IsActive());
EXPECT_TRUE(srtp_transport2_->IsActive());
if (rtc::IsGcmCryptoSuite(cs)) {
EXPECT_FALSE(srtp_transport1_->IsExternalAuthActive());
EXPECT_FALSE(srtp_transport2_->IsExternalAuthActive());
} else if (enable_external_auth) {
EXPECT_TRUE(srtp_transport1_->IsExternalAuthActive());
EXPECT_TRUE(srtp_transport2_->IsExternalAuthActive());
}
TestSendRecvRtpPacket(cipher_suite_name);
TestSendRecvRtcpPacket(cipher_suite_name);
}
void TestSendRecvPacketWithEncryptedHeaderExtension(
const std::string& cs,
const std::vector<int>& encrypted_header_ids) {
size_t rtp_len = sizeof(kPcmuFrameWithExtensions);
size_t packet_size = rtp_len + rtc::rtp_auth_tag_len(cs);
rtc::Buffer rtp_packet_buffer(packet_size);
char* rtp_packet_data = rtp_packet_buffer.data<char>();
memcpy(rtp_packet_data, kPcmuFrameWithExtensions, rtp_len);
// In order to be able to run this test function multiple times we can not
// use the same sequence number twice. Increase the sequence number by one.
rtc::SetBE16(reinterpret_cast<uint8_t*>(rtp_packet_data) + 2,
++sequence_number_);
rtc::CopyOnWriteBuffer rtp_packet1to2(rtp_packet_data, rtp_len,
packet_size);
rtc::CopyOnWriteBuffer rtp_packet2to1(rtp_packet_data, rtp_len,
packet_size);
char original_rtp_data[sizeof(kPcmuFrameWithExtensions)];
memcpy(original_rtp_data, rtp_packet_data, rtp_len);
rtc::PacketOptions options;
// Send a packet from |srtp_transport1_| to |srtp_transport2_| and verify
// that the packet can be successfully received and decrypted.
ASSERT_TRUE(srtp_transport1_->SendRtpPacket(&rtp_packet1to2, options,
cricket::PF_SRTP_BYPASS));
ASSERT_TRUE(last_recv_packet2_.data());
EXPECT_TRUE(memcmp(last_recv_packet2_.data(), original_rtp_data, rtp_len) ==
0);
// Get the encrypted packet from underneath packet transport and verify the
// data and header extension are actually encrypted.
auto fake_rtp_packet_transport = static_cast<rtc::FakePacketTransport*>(
srtp_transport1_->rtp_packet_transport());
EXPECT_FALSE(memcmp(fake_rtp_packet_transport->last_sent_packet()->data(),
original_rtp_data, rtp_len) == 0);
CompareHeaderExtensions(
reinterpret_cast<const char*>(
fake_rtp_packet_transport->last_sent_packet()->data()),
fake_rtp_packet_transport->last_sent_packet()->size(),
original_rtp_data, rtp_len, encrypted_header_ids, false);
// Do the same thing in the opposite direction;
ASSERT_TRUE(srtp_transport2_->SendRtpPacket(&rtp_packet2to1, options,
cricket::PF_SRTP_BYPASS));
ASSERT_TRUE(last_recv_packet1_.data());
EXPECT_TRUE(memcmp(last_recv_packet1_.data(), original_rtp_data, rtp_len) ==
0);
fake_rtp_packet_transport = static_cast<rtc::FakePacketTransport*>(
srtp_transport2_->rtp_packet_transport());
EXPECT_FALSE(memcmp(fake_rtp_packet_transport->last_sent_packet()->data(),
original_rtp_data, rtp_len) == 0);
CompareHeaderExtensions(
reinterpret_cast<const char*>(
fake_rtp_packet_transport->last_sent_packet()->data()),
fake_rtp_packet_transport->last_sent_packet()->size(),
original_rtp_data, rtp_len, encrypted_header_ids, false);
}
void TestSendRecvEncryptedHeaderExtension(int cs,
const uint8_t* key1,
int key1_len,
const uint8_t* key2,
int key2_len,
const std::string& cs_name) {
std::vector<int> encrypted_headers;
encrypted_headers.push_back(1);
// Don't encrypt header ids 2 and 3.
encrypted_headers.push_back(4);
EXPECT_EQ(key1_len, key2_len);
EXPECT_EQ(cs_name, rtc::SrtpCryptoSuiteToName(cs));
srtp_transport1_->SetEncryptedHeaderExtensionIds(cricket::CS_LOCAL,
encrypted_headers);
srtp_transport1_->SetEncryptedHeaderExtensionIds(cricket::CS_REMOTE,
encrypted_headers);
srtp_transport2_->SetEncryptedHeaderExtensionIds(cricket::CS_LOCAL,
encrypted_headers);
srtp_transport2_->SetEncryptedHeaderExtensionIds(cricket::CS_REMOTE,
encrypted_headers);
EXPECT_TRUE(
srtp_transport1_->SetRtpParams(cs, key1, key1_len, cs, key2, key2_len));
EXPECT_TRUE(
srtp_transport2_->SetRtpParams(cs, key2, key2_len, cs, key1, key1_len));
EXPECT_TRUE(srtp_transport1_->IsActive());
EXPECT_TRUE(srtp_transport2_->IsActive());
EXPECT_FALSE(srtp_transport1_->IsExternalAuthActive());
EXPECT_FALSE(srtp_transport2_->IsExternalAuthActive());
TestSendRecvPacketWithEncryptedHeaderExtension(cs_name, encrypted_headers);
}
std::unique_ptr<SrtpTransport> srtp_transport1_;
std::unique_ptr<SrtpTransport> srtp_transport2_;
std::unique_ptr<rtc::FakePacketTransport> rtp_packet_transport1_;
std::unique_ptr<rtc::FakePacketTransport> rtp_packet_transport2_;
rtc::CopyOnWriteBuffer last_recv_packet1_;
rtc::CopyOnWriteBuffer last_recv_packet2_;
int sequence_number_ = 0;
}; };
class SrtpTransportTestWithExternalAuth TEST(SrtpTransportTest, SendPacket) {
: public SrtpTransportTest, auto rtp_transport = rtc::MakeUnique<MockRtpTransport>();
public testing::WithParamInterface<bool> {}; EXPECT_CALL(*rtp_transport, SendPacket(_, _, _, _)).WillOnce(Return(true));
TEST_P(SrtpTransportTestWithExternalAuth, SrtpTransport srtp_transport(std::move(rtp_transport), "a");
SendAndRecvPacket_AES_CM_128_HMAC_SHA1_80) {
bool enable_external_auth = GetParam(); const bool rtcp = false;
TestSendRecvPacket(enable_external_auth, rtc::SRTP_AES128_CM_SHA1_80, rtc::CopyOnWriteBuffer packet;
kTestKey1, kTestKeyLen, kTestKey2, kTestKeyLen, rtc::PacketOptions options;
rtc::CS_AES_CM_128_HMAC_SHA1_80); int flags = 0;
EXPECT_TRUE(srtp_transport.SendPacket(rtcp, &packet, options, flags));
// TODO(zstein): Also verify that the packet received by RtpTransport has been
// protected once SrtpTransport handles that.
} }
TEST_F(SrtpTransportTest, // Test that SrtpTransport fires SignalPacketReceived when the underlying
SendAndRecvPacketWithHeaderExtension_AES_CM_128_HMAC_SHA1_80) { // RtpTransport fires SignalPacketReceived.
TestSendRecvEncryptedHeaderExtension(rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, TEST(SrtpTransportTest, SignalPacketReceived) {
kTestKeyLen, kTestKey2, kTestKeyLen, auto rtp_transport = rtc::MakeUnique<MockRtpTransport>();
rtc::CS_AES_CM_128_HMAC_SHA1_80); MockRtpTransport* rtp_transport_raw = rtp_transport.get();
} SrtpTransport srtp_transport(std::move(rtp_transport), "a");
TEST_P(SrtpTransportTestWithExternalAuth, SignalPacketReceivedCounter counter(&srtp_transport);
SendAndRecvPacket_AES_CM_128_HMAC_SHA1_32) {
bool enable_external_auth = GetParam();
TestSendRecvPacket(enable_external_auth, rtc::SRTP_AES128_CM_SHA1_32,
kTestKey1, kTestKeyLen, kTestKey2, kTestKeyLen,
rtc::CS_AES_CM_128_HMAC_SHA1_32);
}
TEST_F(SrtpTransportTest, rtp_transport_raw->PretendReceivedPacket();
SendAndRecvPacketWithHeaderExtension_AES_CM_128_HMAC_SHA1_32) {
TestSendRecvEncryptedHeaderExtension(rtc::SRTP_AES128_CM_SHA1_32, kTestKey1,
kTestKeyLen, kTestKey2, kTestKeyLen,
rtc::CS_AES_CM_128_HMAC_SHA1_32);
}
TEST_P(SrtpTransportTestWithExternalAuth, EXPECT_EQ(1, counter.rtp_count());
SendAndRecvPacket_SRTP_AEAD_AES_128_GCM) {
bool enable_external_auth = GetParam();
TestSendRecvPacket(enable_external_auth, rtc::SRTP_AEAD_AES_128_GCM,
kTestKeyGcm128_1, kTestKeyGcm128Len, kTestKeyGcm128_2,
kTestKeyGcm128Len, rtc::CS_AEAD_AES_128_GCM);
}
TEST_F(SrtpTransportTest, // TODO(zstein): Also verify that the packet is unprotected once SrtpTransport
SendAndRecvPacketWithHeaderExtension_SRTP_AEAD_AES_128_GCM) { // handles that.
TestSendRecvEncryptedHeaderExtension(
rtc::SRTP_AEAD_AES_128_GCM, kTestKeyGcm128_1, kTestKeyGcm128Len,
kTestKeyGcm128_2, kTestKeyGcm128Len, rtc::CS_AEAD_AES_128_GCM);
}
TEST_P(SrtpTransportTestWithExternalAuth,
SendAndRecvPacket_SRTP_AEAD_AES_256_GCM) {
bool enable_external_auth = GetParam();
TestSendRecvPacket(enable_external_auth, rtc::SRTP_AEAD_AES_256_GCM,
kTestKeyGcm256_1, kTestKeyGcm256Len, kTestKeyGcm256_2,
kTestKeyGcm256Len, rtc::CS_AEAD_AES_256_GCM);
}
TEST_F(SrtpTransportTest,
SendAndRecvPacketWithHeaderExtension_SRTP_AEAD_AES_256_GCM) {
TestSendRecvEncryptedHeaderExtension(
rtc::SRTP_AEAD_AES_256_GCM, kTestKeyGcm256_1, kTestKeyGcm256Len,
kTestKeyGcm256_2, kTestKeyGcm256Len, rtc::CS_AEAD_AES_256_GCM);
}
// Run all tests both with and without external auth enabled.
INSTANTIATE_TEST_CASE_P(ExternalAuth,
SrtpTransportTestWithExternalAuth,
::testing::Values(true, false));
// Test directly setting the params with bogus keys.
TEST_F(SrtpTransportTest, TestSetParamsKeyTooShort) {
EXPECT_FALSE(srtp_transport1_->SetRtpParams(
rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1,
rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1));
EXPECT_FALSE(srtp_transport1_->SetRtcpParams(
rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1,
rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1));
} }
} // namespace webrtc } // namespace webrtc