diff --git a/api/crypto/frame_decryptor_interface.h b/api/crypto/frame_decryptor_interface.h index ec900ab80a..69f137504c 100644 --- a/api/crypto/frame_decryptor_interface.h +++ b/api/crypto/frame_decryptor_interface.h @@ -34,9 +34,8 @@ class FrameDecryptorInterface : public rtc::RefCountInterface { // returned when attempting to decrypt a frame. kRecoverable indicates that // there was an error with the given frame and so it should not be passed to // the decoder, however it hints that the receive stream is still decryptable - // which is important for determining when to send key frame requests - // kUnknown should never be returned by the implementor. - enum class Status { kOk, kRecoverable, kFailedToDecrypt, kUnknown }; + // which is important for determining when to send key frame requests. + enum class Status { kOk, kRecoverable, kFailedToDecrypt }; struct Result { Result(Status status, size_t bytes_written) @@ -55,15 +54,33 @@ class FrameDecryptorInterface : public rtc::RefCountInterface { // that the frames are in order if SRTP is enabled. The stream is not provided // here and it is up to the implementor to transport this information to the // receiver if they care about it. You must set bytes_written to how many - // bytes you wrote to in the frame buffer. kOk must be returned if successful, - // kRecoverable should be returned if the failure was due to something other - // than a decryption failure. kFailedToDecrypt should be returned in all other - // cases. + // bytes you wrote to in the frame buffer. 0 must be returned if successful + // all other numbers can be selected by the implementer to represent error + // codes. + // TODO(bugs.webrtc.org/10512) - Remove this after implementation rewrite. + virtual int Decrypt(cricket::MediaType media_type, + const std::vector& csrcs, + rtc::ArrayView additional_data, + rtc::ArrayView encrypted_frame, + rtc::ArrayView frame, + size_t* bytes_written) { + bytes_written = 0; + return 1; + } + + // TODO(bugs.webrtc.org/10512) - Remove the other decrypt function and turn + // this to a pure virtual. virtual Result Decrypt(cricket::MediaType media_type, const std::vector& csrcs, rtc::ArrayView additional_data, rtc::ArrayView encrypted_frame, - rtc::ArrayView frame) = 0; + rtc::ArrayView frame) { + size_t bytes_written = 0; + const int status = Decrypt(media_type, csrcs, additional_data, + encrypted_frame, frame, &bytes_written); + return Result(status == 0 ? Status::kOk : Status::kFailedToDecrypt, + bytes_written); + } // Returns the total required length in bytes for the output of the // decryption. This can be larger than the actual number of bytes you need but diff --git a/api/test/fake_frame_decryptor.cc b/api/test/fake_frame_decryptor.cc index 4af42a6b82..b77017fdb4 100644 --- a/api/test/fake_frame_decryptor.cc +++ b/api/test/fake_frame_decryptor.cc @@ -18,14 +18,14 @@ FakeFrameDecryptor::FakeFrameDecryptor(uint8_t fake_key, uint8_t expected_postfix_byte) : fake_key_(fake_key), expected_postfix_byte_(expected_postfix_byte) {} -FakeFrameDecryptor::Result FakeFrameDecryptor::Decrypt( - cricket::MediaType media_type, - const std::vector& csrcs, - rtc::ArrayView additional_data, - rtc::ArrayView encrypted_frame, - rtc::ArrayView frame) { +int FakeFrameDecryptor::Decrypt(cricket::MediaType media_type, + const std::vector& csrcs, + rtc::ArrayView additional_data, + rtc::ArrayView encrypted_frame, + rtc::ArrayView frame, + size_t* bytes_written) { if (fail_decryption_) { - return Result(Status::kFailedToDecrypt, 0); + return static_cast(FakeDecryptStatus::FORCED_FAILURE); } RTC_CHECK_EQ(frame.size() + 1, encrypted_frame.size()); @@ -34,10 +34,11 @@ FakeFrameDecryptor::Result FakeFrameDecryptor::Decrypt( } if (encrypted_frame[frame.size()] != expected_postfix_byte_) { - return Result(Status::kFailedToDecrypt, 0); + return static_cast(FakeDecryptStatus::INVALID_POSTFIX); } - return Result(Status::kOk, frame.size()); + *bytes_written = frame.size(); + return static_cast(FakeDecryptStatus::OK); } size_t FakeFrameDecryptor::GetMaxPlaintextByteSize( diff --git a/api/test/fake_frame_decryptor.h b/api/test/fake_frame_decryptor.h index 05813dbbd0..cb370b905a 100644 --- a/api/test/fake_frame_decryptor.h +++ b/api/test/fake_frame_decryptor.h @@ -35,11 +35,12 @@ class FakeFrameDecryptor final uint8_t expected_postfix_byte = 255); // Fake decryption that just xors the payload with the 1 byte key and checks // the postfix byte. This will always fail if fail_decryption_ is set to true. - Result Decrypt(cricket::MediaType media_type, - const std::vector& csrcs, - rtc::ArrayView additional_data, - rtc::ArrayView encrypted_frame, - rtc::ArrayView frame) override; + int Decrypt(cricket::MediaType media_type, + const std::vector& csrcs, + rtc::ArrayView additional_data, + rtc::ArrayView encrypted_frame, + rtc::ArrayView frame, + size_t* bytes_written) override; // Always returns 1 less than the size of the encrypted frame. size_t GetMaxPlaintextByteSize(cricket::MediaType media_type, size_t encrypted_frame_size) override; diff --git a/api/test/mock_frame_decryptor.h b/api/test/mock_frame_decryptor.h index 77aa4f9147..feac9b3809 100644 --- a/api/test/mock_frame_decryptor.h +++ b/api/test/mock_frame_decryptor.h @@ -23,12 +23,13 @@ class MockFrameDecryptor : public FrameDecryptorInterface { MockFrameDecryptor(); ~MockFrameDecryptor() override; - MOCK_METHOD5(Decrypt, - Result(cricket::MediaType, - const std::vector&, - rtc::ArrayView, - rtc::ArrayView, - rtc::ArrayView)); + MOCK_METHOD6(Decrypt, + int(cricket::MediaType, + const std::vector&, + rtc::ArrayView, + rtc::ArrayView, + rtc::ArrayView, + size_t*)); MOCK_METHOD2(GetMaxPlaintextByteSize, size_t(cricket::MediaType, size_t encrypted_frame_size)); diff --git a/audio/channel_receive.cc b/audio/channel_receive.cc index 4f00d9fdf2..088aa57a07 100644 --- a/audio/channel_receive.cc +++ b/audio/channel_receive.cc @@ -103,8 +103,8 @@ class ChannelReceive : public ChannelReceiveInterface, void StopPlayout() override; // Codecs - absl::optional> GetReceiveCodec() - const override; + absl::optional> + GetReceiveCodec() const override; void ReceivedRTCPPacket(const uint8_t* data, size_t length) override; @@ -627,26 +627,28 @@ bool ChannelReceive::ReceivePacket(const uint8_t* packet, // Keep this buffer around for the lifetime of the OnReceivedPayloadData call. rtc::Buffer decrypted_audio_payload; if (frame_decryptor_ != nullptr) { - const size_t max_plaintext_size = frame_decryptor_->GetMaxPlaintextByteSize( + size_t max_plaintext_size = frame_decryptor_->GetMaxPlaintextByteSize( cricket::MEDIA_TYPE_AUDIO, payload_length); decrypted_audio_payload.SetSize(max_plaintext_size); - const std::vector csrcs(header.arrOfCSRCs, - header.arrOfCSRCs + header.numCSRCs); - const FrameDecryptorInterface::Result decrypt_result = - frame_decryptor_->Decrypt( - cricket::MEDIA_TYPE_AUDIO, csrcs, - /*additional_data=*/nullptr, - rtc::ArrayView(payload, payload_data_length), - decrypted_audio_payload); + size_t bytes_written = 0; + std::vector csrcs(header.arrOfCSRCs, + header.arrOfCSRCs + header.numCSRCs); + int decrypt_status = frame_decryptor_->Decrypt( + cricket::MEDIA_TYPE_AUDIO, csrcs, + /*additional_data=*/nullptr, + rtc::ArrayView(payload, payload_data_length), + decrypted_audio_payload, &bytes_written); - if (decrypt_result.IsOk()) { - decrypted_audio_payload.SetSize(decrypt_result.bytes_written); - } else { - // Interpret failures as a silent frame. - decrypted_audio_payload.SetSize(0); + // In this case just interpret the failure as a silent frame. + if (decrypt_status != 0) { + bytes_written = 0; } + // Resize the decrypted audio payload to the number of bytes actually + // written. + decrypted_audio_payload.SetSize(bytes_written); + // Update the final payload. payload = decrypted_audio_payload.data(); payload_data_length = decrypted_audio_payload.size(); } else if (crypto_options_.sframe.require_frame_encryption) { diff --git a/video/buffered_frame_decryptor.cc b/video/buffered_frame_decryptor.cc index 41eddea17e..2d7ec25098 100644 --- a/video/buffered_frame_decryptor.cc +++ b/video/buffered_frame_decryptor.cc @@ -83,25 +83,25 @@ BufferedFrameDecryptor::FrameDecision BufferedFrameDecryptor::DecryptFrame( } // Attempt to decrypt the video frame. - const FrameDecryptorInterface::Result decrypt_result = - frame_decryptor_->Decrypt(cricket::MEDIA_TYPE_VIDEO, /*csrcs=*/{}, - additional_data, *frame, - inline_decrypted_bitstream); + size_t bytes_written = 0; + const int status = frame_decryptor_->Decrypt( + cricket::MEDIA_TYPE_VIDEO, /*csrcs=*/{}, additional_data, *frame, + inline_decrypted_bitstream, &bytes_written); + // Optionally call the callback if there was a change in status - if (decrypt_result.status != last_status_) { - last_status_ = decrypt_result.status; - decryption_status_change_callback_->OnDecryptionStatusChange( - decrypt_result.status); + if (status != last_status_) { + last_status_ = status; + decryption_status_change_callback_->OnDecryptionStatusChange(status); } - if (!decrypt_result.IsOk()) { + if (status != 0) { // Only stash frames if we have never decrypted a frame before. return first_frame_decrypted_ ? FrameDecision::kDrop : FrameDecision::kStash; } - RTC_CHECK_LE(decrypt_result.bytes_written, max_plaintext_byte_size); + RTC_CHECK_LE(bytes_written, max_plaintext_byte_size); // Update the frame to contain just the written bytes. - frame->set_size(decrypt_result.bytes_written); + frame->set_size(bytes_written); // Indicate that all future fail to decrypt frames should be dropped. if (!first_frame_decrypted_) { diff --git a/video/buffered_frame_decryptor.h b/video/buffered_frame_decryptor.h index 49ab9a7bd9..06992bbfb5 100644 --- a/video/buffered_frame_decryptor.h +++ b/video/buffered_frame_decryptor.h @@ -42,8 +42,7 @@ class OnDecryptionStatusChangeCallback { // blocking so the caller must relinquish the callback quickly. This status // must match what is specified in the FrameDecryptorInterface file. Notably // 0 must indicate success and any positive integer is a failure. - virtual void OnDecryptionStatusChange( - FrameDecryptorInterface::Status status) = 0; + virtual void OnDecryptionStatusChange(int status) = 0; }; // The BufferedFrameDecryptor is responsible for deciding when to pass @@ -92,8 +91,7 @@ class BufferedFrameDecryptor final { const bool generic_descriptor_auth_experiment_; bool first_frame_decrypted_ = false; - FrameDecryptorInterface::Status last_status_ = - FrameDecryptorInterface::Status::kUnknown; + int last_status_ = -1; rtc::scoped_refptr frame_decryptor_; OnDecryptedFrameCallback* const decrypted_frame_callback_; OnDecryptionStatusChangeCallback* const decryption_status_change_callback_; diff --git a/video/buffered_frame_decryptor_unittest.cc b/video/buffered_frame_decryptor_unittest.cc index a056b4a720..7926f2421e 100644 --- a/video/buffered_frame_decryptor_unittest.cc +++ b/video/buffered_frame_decryptor_unittest.cc @@ -55,16 +55,6 @@ class FakePacketBuffer : public video_coding::PacketBuffer { std::map packets_; }; -FrameDecryptorInterface::Result DecryptSuccess() { - return FrameDecryptorInterface::Result(FrameDecryptorInterface::Status::kOk, - 0); -} - -FrameDecryptorInterface::Result DecryptFail() { - return FrameDecryptorInterface::Result( - FrameDecryptorInterface::Status::kFailedToDecrypt, 0); -} - } // namespace class BufferedFrameDecryptorTest @@ -79,7 +69,7 @@ class BufferedFrameDecryptorTest decrypted_frame_call_count_++; } - void OnDecryptionStatusChange(FrameDecryptorInterface::Status status) { + void OnDecryptionStatusChange(int status) { ++decryption_status_change_count_; } @@ -137,9 +127,7 @@ const size_t BufferedFrameDecryptorTest::kMaxStashedFrames = 24; // Callback should always be triggered on a successful decryption. TEST_F(BufferedFrameDecryptorTest, CallbackCalledOnSuccessfulDecryption) { - EXPECT_CALL(*mock_frame_decryptor_, Decrypt) - .Times(1) - .WillOnce(Return(DecryptSuccess())); + EXPECT_CALL(*mock_frame_decryptor_, Decrypt).Times(1).WillOnce(Return(0)); EXPECT_CALL(*mock_frame_decryptor_, GetMaxPlaintextByteSize) .Times(1) .WillOnce(Return(0)); @@ -150,9 +138,7 @@ TEST_F(BufferedFrameDecryptorTest, CallbackCalledOnSuccessfulDecryption) { // An initial fail to decrypt should not trigger the callback. TEST_F(BufferedFrameDecryptorTest, CallbackNotCalledOnFailedDecryption) { - EXPECT_CALL(*mock_frame_decryptor_, Decrypt) - .Times(1) - .WillOnce(Return(DecryptFail())); + EXPECT_CALL(*mock_frame_decryptor_, Decrypt).Times(1).WillOnce(Return(1)); EXPECT_CALL(*mock_frame_decryptor_, GetMaxPlaintextByteSize) .Times(1) .WillOnce(Return(0)); @@ -166,9 +152,9 @@ TEST_F(BufferedFrameDecryptorTest, CallbackNotCalledOnFailedDecryption) { TEST_F(BufferedFrameDecryptorTest, DelayedCallbackOnBufferedFrames) { EXPECT_CALL(*mock_frame_decryptor_, Decrypt) .Times(3) - .WillOnce(Return(DecryptFail())) - .WillOnce(Return(DecryptSuccess())) - .WillOnce(Return(DecryptSuccess())); + .WillOnce(Return(1)) + .WillOnce(Return(0)) + .WillOnce(Return(0)); EXPECT_CALL(*mock_frame_decryptor_, GetMaxPlaintextByteSize) .Times(3) .WillRepeatedly(Return(0)); @@ -188,10 +174,10 @@ TEST_F(BufferedFrameDecryptorTest, DelayedCallbackOnBufferedFrames) { TEST_F(BufferedFrameDecryptorTest, FTDDiscardedAfterFirstSuccess) { EXPECT_CALL(*mock_frame_decryptor_, Decrypt) .Times(4) - .WillOnce(Return(DecryptFail())) - .WillOnce(Return(DecryptSuccess())) - .WillOnce(Return(DecryptSuccess())) - .WillOnce(Return(DecryptFail())); + .WillOnce(Return(1)) + .WillOnce(Return(0)) + .WillOnce(Return(0)) + .WillOnce(Return(1)); EXPECT_CALL(*mock_frame_decryptor_, GetMaxPlaintextByteSize) .Times(4) .WillRepeatedly(Return(0)); @@ -217,7 +203,7 @@ TEST_F(BufferedFrameDecryptorTest, MaximumNumberOfFramesStored) { const size_t failed_to_decrypt_count = kMaxStashedFrames * 2; EXPECT_CALL(*mock_frame_decryptor_, Decrypt) .Times(failed_to_decrypt_count) - .WillRepeatedly(Return(DecryptFail())); + .WillRepeatedly(Return(1)); EXPECT_CALL(*mock_frame_decryptor_, GetMaxPlaintextByteSize) .WillRepeatedly(Return(0)); @@ -229,7 +215,7 @@ TEST_F(BufferedFrameDecryptorTest, MaximumNumberOfFramesStored) { EXPECT_CALL(*mock_frame_decryptor_, Decrypt) .Times(kMaxStashedFrames + 1) - .WillRepeatedly(Return(DecryptSuccess())); + .WillRepeatedly(Return(0)); buffered_frame_decryptor_->ManageEncryptedFrame(CreateRtpFrameObject(true)); EXPECT_EQ(decrypted_frame_call_count_, kMaxStashedFrames + 1); EXPECT_EQ(decryption_status_change_count_, static_cast(2)); @@ -245,7 +231,7 @@ TEST_F(BufferedFrameDecryptorTest, FramesStoredIfDecryptorNull) { EXPECT_CALL(*mock_frame_decryptor_, Decrypt) .Times(kMaxStashedFrames + 1) - .WillRepeatedly(Return(DecryptSuccess())); + .WillRepeatedly(Return(0)); EXPECT_CALL(*mock_frame_decryptor_, GetMaxPlaintextByteSize) .WillRepeatedly(Return(0)); diff --git a/video/rtp_video_stream_receiver.cc b/video/rtp_video_stream_receiver.cc index 0aaefe59d2..0a63c8761e 100644 --- a/video/rtp_video_stream_receiver.cc +++ b/video/rtp_video_stream_receiver.cc @@ -451,11 +451,8 @@ void RtpVideoStreamReceiver::OnDecryptedFrame( reference_finder_->ManageFrame(std::move(frame)); } -void RtpVideoStreamReceiver::OnDecryptionStatusChange( - FrameDecryptorInterface::Status status) { - frames_decryptable_.store( - (status == FrameDecryptorInterface::Status::kOk) || - (status == FrameDecryptorInterface::Status::kRecoverable)); +void RtpVideoStreamReceiver::OnDecryptionStatusChange(int status) { + frames_decryptable_.store(status == 0); } void RtpVideoStreamReceiver::SetFrameDecryptor( diff --git a/video/rtp_video_stream_receiver.h b/video/rtp_video_stream_receiver.h index f2ac9de076..1bc5d8a8b3 100644 --- a/video/rtp_video_stream_receiver.h +++ b/video/rtp_video_stream_receiver.h @@ -153,8 +153,7 @@ class RtpVideoStreamReceiver : public LossNotificationSender, std::unique_ptr frame) override; // Implements OnDecryptionStatusChangeCallback. - void OnDecryptionStatusChange( - FrameDecryptorInterface::Status status) override; + void OnDecryptionStatusChange(int status) override; // Optionally set a frame decryptor after a stream has started. This will not // reset the decoder state.