diff --git a/api/transport/stun.cc b/api/transport/stun.cc index 80b7b82d9a..4ae834c0c4 100644 --- a/api/transport/stun.cc +++ b/api/transport/stun.cc @@ -370,6 +370,28 @@ bool StunMessage::ValidateFingerprint(const char* data, size_t size) { rtc::ComputeCrc32(data, size - fingerprint_attr_size)); } +bool StunMessage::IsStunMethod(rtc::ArrayView methods, + const char* data, + size_t size) { + // Check the message length. + if (size % 4 != 0 || size < kStunHeaderSize) + return false; + + // Skip the rest if the magic cookie isn't present. + const char* magic_cookie = + data + kStunTransactionIdOffset - kStunMagicCookieLength; + if (rtc::GetBE32(magic_cookie) != kStunMagicCookie) + return false; + + int method = rtc::GetBE16(data); + for (int m : methods) { + if (m == method) { + return true; + } + } + return false; +} + bool StunMessage::AddFingerprint() { // Add the attribute with a dummy value. Since this is a known attribute, // it can't fail. @@ -557,6 +579,44 @@ bool StunMessage::IsValidTransactionId(const std::string& transaction_id) { transaction_id.size() == kStunLegacyTransactionIdLength; } +bool StunMessage::EqualAttributes( + const StunMessage* other, + std::function attribute_type_mask) const { + RTC_DCHECK(other != nullptr); + rtc::ByteBufferWriter tmp_buffer_ptr1; + rtc::ByteBufferWriter tmp_buffer_ptr2; + for (const auto& attr : attrs_) { + if (attribute_type_mask(attr->type())) { + const StunAttribute* other_attr = other->GetAttribute(attr->type()); + if (other_attr == nullptr) { + return false; + } + tmp_buffer_ptr1.Clear(); + tmp_buffer_ptr2.Clear(); + attr->Write(&tmp_buffer_ptr1); + other_attr->Write(&tmp_buffer_ptr2); + if (tmp_buffer_ptr1.Length() != tmp_buffer_ptr2.Length()) { + return false; + } + if (memcmp(tmp_buffer_ptr1.Data(), tmp_buffer_ptr2.Data(), + tmp_buffer_ptr1.Length()) != 0) { + return false; + } + } + } + + for (const auto& attr : other->attrs_) { + if (attribute_type_mask(attr->type())) { + const StunAttribute* own_attr = GetAttribute(attr->type()); + if (own_attr == nullptr) { + return false; + } + // we have already compared all values... + } + } + return true; +} + // StunAttribute StunAttribute::StunAttribute(uint16_t type, uint16_t length) @@ -1205,4 +1265,20 @@ StunMessage* IceMessage::CreateNew() const { return new IceMessage(); } +std::unique_ptr StunMessage::Clone() const { + std::unique_ptr copy(CreateNew()); + if (!copy) { + return nullptr; + } + rtc::ByteBufferWriter buf; + if (!Write(&buf)) { + return nullptr; + } + rtc::ByteBufferReader reader(buf); + if (!copy->Read(&reader)) { + return nullptr; + } + return copy; +} + } // namespace cricket diff --git a/api/transport/stun.h b/api/transport/stun.h index 857a381078..1c2cb804d0 100644 --- a/api/transport/stun.h +++ b/api/transport/stun.h @@ -202,6 +202,12 @@ class StunMessage { // current message. bool AddMessageIntegrity32(absl::string_view password); + // Verify that a buffer has stun magic cookie and one of the specified + // methods. Note that it does not check for the existance of FINGERPRINT. + static bool IsStunMethod(rtc::ArrayView methods, + const char* data, + size_t size); + // Verifies that a given buffer is STUN by checking for a correct FINGERPRINT. static bool ValidateFingerprint(const char* data, size_t size); @@ -223,10 +229,20 @@ class StunMessage { // This is used for testing. void SetStunMagicCookie(uint32_t val); + // Contruct a copy of |this|. + std::unique_ptr Clone() const; + + // Check if the attributes of this StunMessage equals those of |other| + // for all attributes that |attribute_type_mask| return true + bool EqualAttributes(const StunMessage* other, + std::function attribute_type_mask) const; + protected: // Verifies that the given attribute is allowed for this message. virtual StunAttributeValueType GetAttributeValueType(int type) const; + std::vector> attrs_; + private: StunAttribute* CreateAttribute(int type, size_t length) /* const*/; const StunAttribute* GetAttribute(int type) const; @@ -245,7 +261,6 @@ class StunMessage { uint16_t length_; std::string transaction_id_; uint32_t reduced_transaction_id_; - std::vector> attrs_; uint32_t stun_magic_cookie_; }; diff --git a/api/transport/stun_unittest.cc b/api/transport/stun_unittest.cc index c75fb90500..84a61f5774 100644 --- a/api/transport/stun_unittest.cc +++ b/api/transport/stun_unittest.cc @@ -1751,6 +1751,109 @@ TEST_F(StunTest, CopyAttribute) { } } +// Test Clone +TEST_F(StunTest, Clone) { + IceMessage msg; + { + auto errorcode = StunAttribute::CreateErrorCode(); + errorcode->SetCode(kTestErrorCode); + errorcode->SetReason(kTestErrorReason); + msg.AddAttribute(std::move(errorcode)); + } + { + auto bytes2 = StunAttribute::CreateByteString(STUN_ATTR_USERNAME); + bytes2->CopyBytes("abcdefghijkl"); + msg.AddAttribute(std::move(bytes2)); + } + { + auto uval2 = StunAttribute::CreateUInt32(STUN_ATTR_RETRANSMIT_COUNT); + uval2->SetValue(11); + msg.AddAttribute(std::move(uval2)); + } + { + auto addr = StunAttribute::CreateAddress(STUN_ATTR_MAPPED_ADDRESS); + addr->SetIP(rtc::IPAddress(kIPv6TestAddress1)); + addr->SetPort(kTestMessagePort1); + msg.AddAttribute(std::move(addr)); + } + auto copy = msg.Clone(); + ASSERT_NE(nullptr, copy.get()); + + msg.SetTransactionID("0123456789ab"); + copy->SetTransactionID("0123456789ab"); + + rtc::ByteBufferWriter out1; + EXPECT_TRUE(msg.Write(&out1)); + rtc::ByteBufferWriter out2; + EXPECT_TRUE(copy->Write(&out2)); + + ASSERT_EQ(out1.Length(), out2.Length()); + EXPECT_EQ(0, memcmp(out1.Data(), out2.Data(), out1.Length())); +} + +// Test EqualAttributes +TEST_F(StunTest, EqualAttributes) { + IceMessage msg; + { + auto errorcode = StunAttribute::CreateErrorCode(); + errorcode->SetCode(kTestErrorCode); + errorcode->SetReason(kTestErrorReason); + msg.AddAttribute(std::move(errorcode)); + } + { + auto bytes2 = StunAttribute::CreateByteString(STUN_ATTR_USERNAME); + bytes2->CopyBytes("abcdefghijkl"); + msg.AddAttribute(std::move(bytes2)); + } + { + auto uval2 = StunAttribute::CreateUInt32(STUN_ATTR_RETRANSMIT_COUNT); + uval2->SetValue(11); + msg.AddAttribute(std::move(uval2)); + } + { + auto addr = StunAttribute::CreateAddress(STUN_ATTR_MAPPED_ADDRESS); + addr->SetIP(rtc::IPAddress(kIPv6TestAddress1)); + addr->SetPort(kTestMessagePort1); + msg.AddAttribute(std::move(addr)); + } + auto copy = msg.Clone(); + ASSERT_NE(nullptr, copy.get()); + + EXPECT_TRUE(copy->EqualAttributes(&msg, [](int type) { return true; })); + + { + auto attr = StunAttribute::CreateByteString(STUN_ATTR_NONCE); + attr->CopyBytes("keso"); + msg.AddAttribute(std::move(attr)); + EXPECT_FALSE(copy->EqualAttributes(&msg, [](int type) { return true; })); + EXPECT_TRUE(copy->EqualAttributes( + &msg, [](int type) { return type != STUN_ATTR_NONCE; })); + } + + { + auto attr = StunAttribute::CreateByteString(STUN_ATTR_NONCE); + attr->CopyBytes("keso"); + copy->AddAttribute(std::move(attr)); + EXPECT_TRUE(copy->EqualAttributes(&msg, [](int type) { return true; })); + } + { + copy->RemoveAttribute(STUN_ATTR_NONCE); + auto attr = StunAttribute::CreateByteString(STUN_ATTR_NONCE); + attr->CopyBytes("kent"); + copy->AddAttribute(std::move(attr)); + EXPECT_FALSE(copy->EqualAttributes(&msg, [](int type) { return true; })); + EXPECT_TRUE(copy->EqualAttributes( + &msg, [](int type) { return type != STUN_ATTR_NONCE; })); + } + + { + msg.RemoveAttribute(STUN_ATTR_NONCE); + EXPECT_FALSE(copy->EqualAttributes(&msg, [](int type) { return true; })); + EXPECT_TRUE(copy->EqualAttributes( + &msg, [](int type) { return type != STUN_ATTR_NONCE; })); + } +} + TEST_F(StunTest, ReduceTransactionIdIsHostOrderIndependent) { std::string transaction_id = "abcdefghijkl"; StunMessage message; @@ -1793,4 +1896,11 @@ TEST_F(StunTest, GoogMiscInfo) { EXPECT_EQ(0xAB0CU, types->GetType(2)); } +TEST_F(StunTest, IsStunMethod) { + int methods[] = {STUN_BINDING_REQUEST}; + EXPECT_TRUE(StunMessage::IsStunMethod( + methods, reinterpret_cast(kRfc5769SampleRequest), + sizeof(kRfc5769SampleRequest))); +} + } // namespace cricket