From f2a2bf4ae448dff45974a7f38f8db005f643f3a3 Mon Sep 17 00:00:00 2001 From: jbauch Date: Wed, 3 Feb 2016 16:45:32 -0800 Subject: [PATCH] Stay writable after partial socket writes. This CL fixes an issue where the "writable" flag didn't stay set after ::send or ::sendto only sent a partial buffer. Also SocketTest::TcpInternal has been updated to use rtc::Buffer instead of manually allocating data. BUG=webrtc:4898 Review URL: https://codereview.webrtc.org/1616153007 Cr-Commit-Position: refs/heads/master@{#11480} --- webrtc/base/physicalsocketserver.cc | 25 ++- webrtc/base/physicalsocketserver.h | 11 +- webrtc/base/physicalsocketserver_unittest.cc | 85 ++++++++++- webrtc/base/socket_unittest.cc | 153 ++++++++++--------- webrtc/base/socket_unittest.h | 5 +- 5 files changed, 194 insertions(+), 85 deletions(-) diff --git a/webrtc/base/physicalsocketserver.cc b/webrtc/base/physicalsocketserver.cc index 67fbea0a2f..4cea0403df 100644 --- a/webrtc/base/physicalsocketserver.cc +++ b/webrtc/base/physicalsocketserver.cc @@ -271,7 +271,8 @@ int PhysicalSocket::SetOption(Option opt, int value) { } int PhysicalSocket::Send(const void* pv, size_t cb) { - int sent = ::send(s_, reinterpret_cast(pv), (int)cb, + int sent = DoSend(s_, reinterpret_cast(pv), + static_cast(cb), #if defined(WEBRTC_LINUX) && !defined(WEBRTC_ANDROID) // Suppress SIGPIPE. Without this, attempting to send on a socket whose // other end is closed will result in a SIGPIPE signal being raised to @@ -287,7 +288,8 @@ int PhysicalSocket::Send(const void* pv, size_t cb) { MaybeRemapSendError(); // We have seen minidumps where this may be false. ASSERT(sent <= static_cast(cb)); - if ((sent < 0) && IsBlockingError(GetError())) { + if ((sent > 0 && sent < static_cast(cb)) || + (sent < 0 && IsBlockingError(GetError()))) { enabled_events_ |= DE_WRITE; } return sent; @@ -298,7 +300,7 @@ int PhysicalSocket::SendTo(const void* buffer, const SocketAddress& addr) { sockaddr_storage saddr; size_t len = addr.ToSockAddrStorage(&saddr); - int sent = ::sendto( + int sent = DoSendTo( s_, static_cast(buffer), static_cast(length), #if defined(WEBRTC_LINUX) && !defined(WEBRTC_ANDROID) // Suppress SIGPIPE. See above for explanation. @@ -311,7 +313,8 @@ int PhysicalSocket::SendTo(const void* buffer, MaybeRemapSendError(); // We have seen minidumps where this may be false. ASSERT(sent <= static_cast(length)); - if ((sent < 0) && IsBlockingError(GetError())) { + if ((sent > 0 && sent < static_cast(length)) || + (sent < 0 && IsBlockingError(GetError()))) { enabled_events_ |= DE_WRITE; } return sent; @@ -474,13 +477,25 @@ int PhysicalSocket::EstimateMTU(uint16_t* mtu) { #endif } - SOCKET PhysicalSocket::DoAccept(SOCKET socket, sockaddr* addr, socklen_t* addrlen) { return ::accept(socket, addr, addrlen); } +int PhysicalSocket::DoSend(SOCKET socket, const char* buf, int len, int flags) { + return ::send(socket, buf, len, flags); +} + +int PhysicalSocket::DoSendTo(SOCKET socket, + const char* buf, + int len, + int flags, + const struct sockaddr* dest_addr, + socklen_t addrlen) { + return ::sendto(socket, buf, len, flags, dest_addr, addrlen); +} + void PhysicalSocket::OnResolveResult(AsyncResolverInterface* resolver) { if (resolver != resolver_) { return; diff --git a/webrtc/base/physicalsocketserver.h b/webrtc/base/physicalsocketserver.h index cbe6580b62..583306c31d 100644 --- a/webrtc/base/physicalsocketserver.h +++ b/webrtc/base/physicalsocketserver.h @@ -68,8 +68,8 @@ class PhysicalSocketServer : public SocketServer { AsyncSocket* CreateAsyncSocket(int type) override; AsyncSocket* CreateAsyncSocket(int family, int type) override; - // Internal Factory for Accept - AsyncSocket* WrapSocket(SOCKET s); + // Internal Factory for Accept (virtual so it can be overwritten in tests). + virtual AsyncSocket* WrapSocket(SOCKET s); // SocketServer: bool Wait(int cms, bool process_io) override; @@ -161,6 +161,13 @@ class PhysicalSocket : public AsyncSocket, public sigslot::has_slots<> { // Make virtual so ::accept can be overwritten in tests. virtual SOCKET DoAccept(SOCKET socket, sockaddr* addr, socklen_t* addrlen); + // Make virtual so ::send can be overwritten in tests. + virtual int DoSend(SOCKET socket, const char* buf, int len, int flags); + + // Make virtual so ::sendto can be overwritten in tests. + virtual int DoSendTo(SOCKET socket, const char* buf, int len, int flags, + const struct sockaddr* dest_addr, socklen_t addrlen); + void OnResolveResult(AsyncResolverInterface* resolver); void UpdateLastError(); diff --git a/webrtc/base/physicalsocketserver_unittest.cc b/webrtc/base/physicalsocketserver_unittest.cc index a2fde80b42..c53441d1a0 100644 --- a/webrtc/base/physicalsocketserver_unittest.cc +++ b/webrtc/base/physicalsocketserver_unittest.cc @@ -29,8 +29,15 @@ class FakeSocketDispatcher : public SocketDispatcher { : SocketDispatcher(ss) { } + FakeSocketDispatcher(SOCKET s, PhysicalSocketServer* ss) + : SocketDispatcher(s, ss) { + } + protected: SOCKET DoAccept(SOCKET socket, sockaddr* addr, socklen_t* addrlen) override; + int DoSend(SOCKET socket, const char* buf, int len, int flags) override; + int DoSendTo(SOCKET socket, const char* buf, int len, int flags, + const struct sockaddr* dest_addr, socklen_t addrlen) override; }; class FakePhysicalSocketServer : public PhysicalSocketServer { @@ -41,22 +48,29 @@ class FakePhysicalSocketServer : public PhysicalSocketServer { AsyncSocket* CreateAsyncSocket(int type) override { SocketDispatcher* dispatcher = new FakeSocketDispatcher(this); - if (dispatcher->Create(type)) { - return dispatcher; - } else { + if (!dispatcher->Create(type)) { delete dispatcher; return nullptr; } + return dispatcher; } AsyncSocket* CreateAsyncSocket(int family, int type) override { SocketDispatcher* dispatcher = new FakeSocketDispatcher(this); - if (dispatcher->Create(family, type)) { - return dispatcher; - } else { + if (!dispatcher->Create(family, type)) { delete dispatcher; return nullptr; } + return dispatcher; + } + + AsyncSocket* WrapSocket(SOCKET s) override { + SocketDispatcher* dispatcher = new FakeSocketDispatcher(s, this); + if (!dispatcher->Initialize()) { + delete dispatcher; + return nullptr; + } + return dispatcher; } PhysicalSocketTest* GetTest() const { return test_; } @@ -71,18 +85,25 @@ class PhysicalSocketTest : public SocketTest { void SetFailAccept(bool fail) { fail_accept_ = fail; } bool FailAccept() const { return fail_accept_; } + // Maximum size to ::send to a socket. Set to < 0 to disable limiting. + void SetMaxSendSize(int max_size) { max_send_size_ = max_size; } + int MaxSendSize() const { return max_send_size_; } + protected: PhysicalSocketTest() : server_(new FakePhysicalSocketServer(this)), scope_(server_.get()), - fail_accept_(false) { + fail_accept_(false), + max_send_size_(-1) { } void ConnectInternalAcceptError(const IPAddress& loopback); + void WritableAfterPartialWrite(const IPAddress& loopback); rtc::scoped_ptr server_; SocketServerScope scope_; bool fail_accept_; + int max_send_size_; }; SOCKET FakeSocketDispatcher::DoAccept(SOCKET socket, @@ -97,6 +118,29 @@ SOCKET FakeSocketDispatcher::DoAccept(SOCKET socket, return SocketDispatcher::DoAccept(socket, addr, addrlen); } +int FakeSocketDispatcher::DoSend(SOCKET socket, const char* buf, int len, + int flags) { + FakePhysicalSocketServer* ss = + static_cast(socketserver()); + if (ss->GetTest()->MaxSendSize() >= 0) { + len = std::min(len, ss->GetTest()->MaxSendSize()); + } + + return SocketDispatcher::DoSend(socket, buf, len, flags); +} + +int FakeSocketDispatcher::DoSendTo(SOCKET socket, const char* buf, int len, + int flags, const struct sockaddr* dest_addr, socklen_t addrlen) { + FakePhysicalSocketServer* ss = + static_cast(socketserver()); + if (ss->GetTest()->MaxSendSize() >= 0) { + len = std::min(len, ss->GetTest()->MaxSendSize()); + } + + return SocketDispatcher::DoSendTo(socket, buf, len, flags, dest_addr, + addrlen); +} + TEST_F(PhysicalSocketTest, TestConnectIPv4) { SocketTest::TestConnectIPv4(); } @@ -209,6 +253,33 @@ TEST_F(PhysicalSocketTest, MAYBE_TestConnectAcceptErrorIPv6) { ConnectInternalAcceptError(kIPv6Loopback); } +void PhysicalSocketTest::WritableAfterPartialWrite(const IPAddress& loopback) { + // Simulate a really small maximum send size. + const int kMaxSendSize = 128; + SetMaxSendSize(kMaxSendSize); + + // Run the default send/receive socket tests with a smaller amount of data + // to avoid long running times due to the small maximum send size. + const size_t kDataSize = 128 * 1024; + TcpInternal(loopback, kDataSize, kMaxSendSize); +} + +TEST_F(PhysicalSocketTest, TestWritableAfterPartialWriteIPv4) { + WritableAfterPartialWrite(kIPv4Loopback); +} + +// Crashes on Linux. See webrtc:4923. +#if defined(WEBRTC_LINUX) +#define MAYBE_TestWritableAfterPartialWriteIPv6 \ + DISABLED_TestWritableAfterPartialWriteIPv6 +#else +#define MAYBE_TestWritableAfterPartialWriteIPv6 \ + TestWritableAfterPartialWriteIPv6 +#endif +TEST_F(PhysicalSocketTest, MAYBE_TestWritableAfterPartialWriteIPv6) { + WritableAfterPartialWrite(kIPv6Loopback); +} + // Crashes on Linux. See webrtc:4923. #if defined(WEBRTC_LINUX) #define MAYBE_TestConnectFailIPv6 DISABLED_TestConnectFailIPv6 diff --git a/webrtc/base/socket_unittest.cc b/webrtc/base/socket_unittest.cc index 8143823b86..d1369e2f78 100644 --- a/webrtc/base/socket_unittest.cc +++ b/webrtc/base/socket_unittest.cc @@ -11,6 +11,7 @@ #include "webrtc/base/socket_unittest.h" #include "webrtc/base/arraysize.h" +#include "webrtc/base/buffer.h" #include "webrtc/base/asyncudpsocket.h" #include "webrtc/base/gunit.h" #include "webrtc/base/nethelpers.h" @@ -21,6 +22,9 @@ namespace rtc { +// Data size to be used in TcpInternal tests. +static const size_t kTcpInternalDataSize = 1024 * 1024; // bytes + #define MAYBE_SKIP_IPV6 \ if (!HasIPv6Enabled()) { \ LOG(LS_INFO) << "No IPv6... skipping"; \ @@ -129,12 +133,12 @@ void SocketTest::TestSocketServerWaitIPv6() { } void SocketTest::TestTcpIPv4() { - TcpInternal(kIPv4Loopback); + TcpInternal(kIPv4Loopback, kTcpInternalDataSize, -1); } void SocketTest::TestTcpIPv6() { MAYBE_SKIP_IPV6; - TcpInternal(kIPv6Loopback); + TcpInternal(kIPv6Loopback, kTcpInternalDataSize, -1); } void SocketTest::TestSingleFlowControlCallbackIPv4() { @@ -671,24 +675,15 @@ void SocketTest::SocketServerWaitInternal(const IPAddress& loopback) { EXPECT_LT(0, accepted->Recv(buf, 1024)); } -void SocketTest::TcpInternal(const IPAddress& loopback) { +void SocketTest::TcpInternal(const IPAddress& loopback, size_t data_size, + ssize_t max_send_size) { testing::StreamSink sink; SocketAddress accept_addr; - // Create test data. - const size_t kDataSize = 1024 * 1024; - scoped_ptr send_buffer(new char[kDataSize]); - scoped_ptr recv_buffer(new char[kDataSize]); - size_t send_pos = 0, recv_pos = 0; - for (size_t i = 0; i < kDataSize; ++i) { - send_buffer[i] = static_cast(i % 256); - recv_buffer[i] = 0; - } - - // Create client. - scoped_ptr client( + // Create receiving client. + scoped_ptr receiver( ss_->CreateAsyncSocket(loopback.family(), SOCK_STREAM)); - sink.Monitor(client.get()); + sink.Monitor(receiver.get()); // Create server and listen. scoped_ptr server( @@ -698,97 +693,115 @@ void SocketTest::TcpInternal(const IPAddress& loopback) { EXPECT_EQ(0, server->Listen(5)); // Attempt connection. - EXPECT_EQ(0, client->Connect(server->GetLocalAddress())); + EXPECT_EQ(0, receiver->Connect(server->GetLocalAddress())); - // Accept connection. + // Accept connection which will be used for sending. EXPECT_TRUE_WAIT((sink.Check(server.get(), testing::SSE_READ)), kTimeout); - scoped_ptr accepted(server->Accept(&accept_addr)); - ASSERT_TRUE(accepted); - sink.Monitor(accepted.get()); + scoped_ptr sender(server->Accept(&accept_addr)); + ASSERT_TRUE(sender); + sink.Monitor(sender.get()); // Both sides are now connected. - EXPECT_EQ_WAIT(AsyncSocket::CS_CONNECTED, client->GetState(), kTimeout); - EXPECT_TRUE(sink.Check(client.get(), testing::SSE_OPEN)); - EXPECT_EQ(client->GetRemoteAddress(), accepted->GetLocalAddress()); - EXPECT_EQ(accepted->GetRemoteAddress(), client->GetLocalAddress()); + EXPECT_EQ_WAIT(AsyncSocket::CS_CONNECTED, receiver->GetState(), kTimeout); + EXPECT_TRUE(sink.Check(receiver.get(), testing::SSE_OPEN)); + EXPECT_EQ(receiver->GetRemoteAddress(), sender->GetLocalAddress()); + EXPECT_EQ(sender->GetRemoteAddress(), receiver->GetLocalAddress()); + + // Create test data. + rtc::Buffer send_buffer(0, data_size); + rtc::Buffer recv_buffer(0, data_size); + for (size_t i = 0; i < data_size; ++i) { + char ch = static_cast(i % 256); + send_buffer.AppendData(&ch, sizeof(ch)); + } // Send and receive a bunch of data. - bool send_waiting_for_writability = false; - bool send_expect_success = true; - bool recv_waiting_for_readability = true; - bool recv_expect_success = false; - int data_in_flight = 0; - while (recv_pos < kDataSize) { - // Send as much as we can if we've been cleared to send. - while (!send_waiting_for_writability && send_pos < kDataSize) { - int tosend = static_cast(kDataSize - send_pos); - int sent = accepted->Send(send_buffer.get() + send_pos, tosend); - if (send_expect_success) { + size_t sent_size = 0; + bool writable = true; + bool send_called = false; + bool readable = false; + bool recv_called = false; + while (recv_buffer.size() < send_buffer.size()) { + // Send as much as we can while we're cleared to send. + while (writable && sent_size < send_buffer.size()) { + int unsent_size = static_cast(send_buffer.size() - sent_size); + int sent = sender->Send(send_buffer.data() + sent_size, unsent_size); + if (!send_called) { // The first Send() after connecting or getting writability should // succeed and send some data. EXPECT_GT(sent, 0); - send_expect_success = false; + send_called = true; } if (sent >= 0) { - EXPECT_LE(sent, tosend); - send_pos += sent; - data_in_flight += sent; + EXPECT_LE(sent, unsent_size); + sent_size += sent; + if (max_send_size >= 0) { + EXPECT_LE(static_cast(sent), max_send_size); + if (sent < unsent_size) { + // If max_send_size is limiting the amount to send per call such + // that the sent amount is less than the unsent amount, we simulate + // that the socket is no longer writable. + writable = false; + } + } } else { - ASSERT_TRUE(accepted->IsBlocking()); - send_waiting_for_writability = true; + ASSERT_TRUE(sender->IsBlocking()); + writable = false; } } // Read all the sent data. - while (data_in_flight > 0) { - if (recv_waiting_for_readability) { + while (recv_buffer.size() < sent_size) { + if (!readable) { // Wait until data is available. - EXPECT_TRUE_WAIT(sink.Check(client.get(), testing::SSE_READ), kTimeout); - recv_waiting_for_readability = false; - recv_expect_success = true; + EXPECT_TRUE_WAIT(sink.Check(receiver.get(), testing::SSE_READ), + kTimeout); + readable = true; + recv_called = false; } // Receive as much as we can get in a single recv call. - int rcvd = client->Recv(recv_buffer.get() + recv_pos, - kDataSize - recv_pos); + char recved_data[data_size]; + int recved_size = receiver->Recv(recved_data, data_size); - if (recv_expect_success) { + if (!recv_called) { // The first Recv() after getting readability should succeed and receive // some data. // TODO: The following line is disabled due to flakey pulse // builds. Re-enable if/when possible. - // EXPECT_GT(rcvd, 0); - recv_expect_success = false; + // EXPECT_GT(recved_size, 0); + recv_called = true; } - if (rcvd >= 0) { - EXPECT_LE(rcvd, data_in_flight); - recv_pos += rcvd; - data_in_flight -= rcvd; + if (recved_size >= 0) { + EXPECT_LE(static_cast(recved_size), + sent_size - recv_buffer.size()); + recv_buffer.AppendData(recved_data, recved_size); } else { - ASSERT_TRUE(client->IsBlocking()); - recv_waiting_for_readability = true; + ASSERT_TRUE(receiver->IsBlocking()); + readable = false; } } - // Once all that we've sent has been rcvd, expect to be able to send again. - if (send_waiting_for_writability) { - EXPECT_TRUE_WAIT(sink.Check(accepted.get(), testing::SSE_WRITE), + // Once all that we've sent has been received, expect to be able to send + // again. + if (!writable) { + EXPECT_TRUE_WAIT(sink.Check(sender.get(), testing::SSE_WRITE), kTimeout); - send_waiting_for_writability = false; - send_expect_success = true; + writable = true; + send_called = false; } } // The received data matches the sent data. - EXPECT_EQ(kDataSize, send_pos); - EXPECT_EQ(kDataSize, recv_pos); - EXPECT_EQ(0, memcmp(recv_buffer.get(), send_buffer.get(), kDataSize)); + EXPECT_EQ(data_size, sent_size); + EXPECT_EQ(data_size, recv_buffer.size()); + EXPECT_EQ(recv_buffer, send_buffer); // Close down. - accepted->Close(); - EXPECT_EQ_WAIT(AsyncSocket::CS_CLOSED, client->GetState(), kTimeout); - EXPECT_TRUE(sink.Check(client.get(), testing::SSE_CLOSE)); - client->Close(); + sender->Close(); + EXPECT_EQ_WAIT(AsyncSocket::CS_CLOSED, receiver->GetState(), kTimeout); + EXPECT_TRUE(sink.Check(receiver.get(), testing::SSE_CLOSE)); + receiver->Close(); } void SocketTest::SingleFlowControlCallbackInternal(const IPAddress& loopback) { diff --git a/webrtc/base/socket_unittest.h b/webrtc/base/socket_unittest.h index e4a6b32705..adc69f1465 100644 --- a/webrtc/base/socket_unittest.h +++ b/webrtc/base/socket_unittest.h @@ -62,6 +62,10 @@ class SocketTest : public testing::Test { const IPAddress kIPv4Loopback; const IPAddress kIPv6Loopback; + protected: + void TcpInternal(const IPAddress& loopback, size_t data_size, + ssize_t max_send_size); + private: void ConnectInternal(const IPAddress& loopback); void ConnectWithDnsLookupInternal(const IPAddress& loopback, @@ -76,7 +80,6 @@ class SocketTest : public testing::Test { void ServerCloseInternal(const IPAddress& loopback); void CloseInClosedCallbackInternal(const IPAddress& loopback); void SocketServerWaitInternal(const IPAddress& loopback); - void TcpInternal(const IPAddress& loopback); void SingleFlowControlCallbackInternal(const IPAddress& loopback); void UdpInternal(const IPAddress& loopback); void UdpReadyToSend(const IPAddress& loopback);