Stay writable after partial socket writes.

This CL fixes an issue where the "writable" flag didn't stay set after
::send or ::sendto only sent a partial buffer.

Also SocketTest::TcpInternal has been updated to use rtc::Buffer instead
of manually allocating data.

BUG=webrtc:4898

Review URL: https://codereview.webrtc.org/1616153007

Cr-Commit-Position: refs/heads/master@{#11480}
This commit is contained in:
jbauch
2016-02-03 16:45:32 -08:00
committed by Commit bot
parent 14d024d882
commit f2a2bf4ae4
5 changed files with 194 additions and 85 deletions

View File

@ -271,7 +271,8 @@ int PhysicalSocket::SetOption(Option opt, int value) {
}
int PhysicalSocket::Send(const void* pv, size_t cb) {
int sent = ::send(s_, reinterpret_cast<const char *>(pv), (int)cb,
int sent = DoSend(s_, reinterpret_cast<const char *>(pv),
static_cast<int>(cb),
#if defined(WEBRTC_LINUX) && !defined(WEBRTC_ANDROID)
// Suppress SIGPIPE. Without this, attempting to send on a socket whose
// other end is closed will result in a SIGPIPE signal being raised to
@ -287,7 +288,8 @@ int PhysicalSocket::Send(const void* pv, size_t cb) {
MaybeRemapSendError();
// We have seen minidumps where this may be false.
ASSERT(sent <= static_cast<int>(cb));
if ((sent < 0) && IsBlockingError(GetError())) {
if ((sent > 0 && sent < static_cast<int>(cb)) ||
(sent < 0 && IsBlockingError(GetError()))) {
enabled_events_ |= DE_WRITE;
}
return sent;
@ -298,7 +300,7 @@ int PhysicalSocket::SendTo(const void* buffer,
const SocketAddress& addr) {
sockaddr_storage saddr;
size_t len = addr.ToSockAddrStorage(&saddr);
int sent = ::sendto(
int sent = DoSendTo(
s_, static_cast<const char *>(buffer), static_cast<int>(length),
#if defined(WEBRTC_LINUX) && !defined(WEBRTC_ANDROID)
// Suppress SIGPIPE. See above for explanation.
@ -311,7 +313,8 @@ int PhysicalSocket::SendTo(const void* buffer,
MaybeRemapSendError();
// We have seen minidumps where this may be false.
ASSERT(sent <= static_cast<int>(length));
if ((sent < 0) && IsBlockingError(GetError())) {
if ((sent > 0 && sent < static_cast<int>(length)) ||
(sent < 0 && IsBlockingError(GetError()))) {
enabled_events_ |= DE_WRITE;
}
return sent;
@ -474,13 +477,25 @@ int PhysicalSocket::EstimateMTU(uint16_t* mtu) {
#endif
}
SOCKET PhysicalSocket::DoAccept(SOCKET socket,
sockaddr* addr,
socklen_t* addrlen) {
return ::accept(socket, addr, addrlen);
}
int PhysicalSocket::DoSend(SOCKET socket, const char* buf, int len, int flags) {
return ::send(socket, buf, len, flags);
}
int PhysicalSocket::DoSendTo(SOCKET socket,
const char* buf,
int len,
int flags,
const struct sockaddr* dest_addr,
socklen_t addrlen) {
return ::sendto(socket, buf, len, flags, dest_addr, addrlen);
}
void PhysicalSocket::OnResolveResult(AsyncResolverInterface* resolver) {
if (resolver != resolver_) {
return;

View File

@ -68,8 +68,8 @@ class PhysicalSocketServer : public SocketServer {
AsyncSocket* CreateAsyncSocket(int type) override;
AsyncSocket* CreateAsyncSocket(int family, int type) override;
// Internal Factory for Accept
AsyncSocket* WrapSocket(SOCKET s);
// Internal Factory for Accept (virtual so it can be overwritten in tests).
virtual AsyncSocket* WrapSocket(SOCKET s);
// SocketServer:
bool Wait(int cms, bool process_io) override;
@ -161,6 +161,13 @@ class PhysicalSocket : public AsyncSocket, public sigslot::has_slots<> {
// Make virtual so ::accept can be overwritten in tests.
virtual SOCKET DoAccept(SOCKET socket, sockaddr* addr, socklen_t* addrlen);
// Make virtual so ::send can be overwritten in tests.
virtual int DoSend(SOCKET socket, const char* buf, int len, int flags);
// Make virtual so ::sendto can be overwritten in tests.
virtual int DoSendTo(SOCKET socket, const char* buf, int len, int flags,
const struct sockaddr* dest_addr, socklen_t addrlen);
void OnResolveResult(AsyncResolverInterface* resolver);
void UpdateLastError();

View File

@ -29,8 +29,15 @@ class FakeSocketDispatcher : public SocketDispatcher {
: SocketDispatcher(ss) {
}
FakeSocketDispatcher(SOCKET s, PhysicalSocketServer* ss)
: SocketDispatcher(s, ss) {
}
protected:
SOCKET DoAccept(SOCKET socket, sockaddr* addr, socklen_t* addrlen) override;
int DoSend(SOCKET socket, const char* buf, int len, int flags) override;
int DoSendTo(SOCKET socket, const char* buf, int len, int flags,
const struct sockaddr* dest_addr, socklen_t addrlen) override;
};
class FakePhysicalSocketServer : public PhysicalSocketServer {
@ -41,22 +48,29 @@ class FakePhysicalSocketServer : public PhysicalSocketServer {
AsyncSocket* CreateAsyncSocket(int type) override {
SocketDispatcher* dispatcher = new FakeSocketDispatcher(this);
if (dispatcher->Create(type)) {
return dispatcher;
} else {
if (!dispatcher->Create(type)) {
delete dispatcher;
return nullptr;
}
return dispatcher;
}
AsyncSocket* CreateAsyncSocket(int family, int type) override {
SocketDispatcher* dispatcher = new FakeSocketDispatcher(this);
if (dispatcher->Create(family, type)) {
return dispatcher;
} else {
if (!dispatcher->Create(family, type)) {
delete dispatcher;
return nullptr;
}
return dispatcher;
}
AsyncSocket* WrapSocket(SOCKET s) override {
SocketDispatcher* dispatcher = new FakeSocketDispatcher(s, this);
if (!dispatcher->Initialize()) {
delete dispatcher;
return nullptr;
}
return dispatcher;
}
PhysicalSocketTest* GetTest() const { return test_; }
@ -71,18 +85,25 @@ class PhysicalSocketTest : public SocketTest {
void SetFailAccept(bool fail) { fail_accept_ = fail; }
bool FailAccept() const { return fail_accept_; }
// Maximum size to ::send to a socket. Set to < 0 to disable limiting.
void SetMaxSendSize(int max_size) { max_send_size_ = max_size; }
int MaxSendSize() const { return max_send_size_; }
protected:
PhysicalSocketTest()
: server_(new FakePhysicalSocketServer(this)),
scope_(server_.get()),
fail_accept_(false) {
fail_accept_(false),
max_send_size_(-1) {
}
void ConnectInternalAcceptError(const IPAddress& loopback);
void WritableAfterPartialWrite(const IPAddress& loopback);
rtc::scoped_ptr<FakePhysicalSocketServer> server_;
SocketServerScope scope_;
bool fail_accept_;
int max_send_size_;
};
SOCKET FakeSocketDispatcher::DoAccept(SOCKET socket,
@ -97,6 +118,29 @@ SOCKET FakeSocketDispatcher::DoAccept(SOCKET socket,
return SocketDispatcher::DoAccept(socket, addr, addrlen);
}
int FakeSocketDispatcher::DoSend(SOCKET socket, const char* buf, int len,
int flags) {
FakePhysicalSocketServer* ss =
static_cast<FakePhysicalSocketServer*>(socketserver());
if (ss->GetTest()->MaxSendSize() >= 0) {
len = std::min(len, ss->GetTest()->MaxSendSize());
}
return SocketDispatcher::DoSend(socket, buf, len, flags);
}
int FakeSocketDispatcher::DoSendTo(SOCKET socket, const char* buf, int len,
int flags, const struct sockaddr* dest_addr, socklen_t addrlen) {
FakePhysicalSocketServer* ss =
static_cast<FakePhysicalSocketServer*>(socketserver());
if (ss->GetTest()->MaxSendSize() >= 0) {
len = std::min(len, ss->GetTest()->MaxSendSize());
}
return SocketDispatcher::DoSendTo(socket, buf, len, flags, dest_addr,
addrlen);
}
TEST_F(PhysicalSocketTest, TestConnectIPv4) {
SocketTest::TestConnectIPv4();
}
@ -209,6 +253,33 @@ TEST_F(PhysicalSocketTest, MAYBE_TestConnectAcceptErrorIPv6) {
ConnectInternalAcceptError(kIPv6Loopback);
}
void PhysicalSocketTest::WritableAfterPartialWrite(const IPAddress& loopback) {
// Simulate a really small maximum send size.
const int kMaxSendSize = 128;
SetMaxSendSize(kMaxSendSize);
// Run the default send/receive socket tests with a smaller amount of data
// to avoid long running times due to the small maximum send size.
const size_t kDataSize = 128 * 1024;
TcpInternal(loopback, kDataSize, kMaxSendSize);
}
TEST_F(PhysicalSocketTest, TestWritableAfterPartialWriteIPv4) {
WritableAfterPartialWrite(kIPv4Loopback);
}
// Crashes on Linux. See webrtc:4923.
#if defined(WEBRTC_LINUX)
#define MAYBE_TestWritableAfterPartialWriteIPv6 \
DISABLED_TestWritableAfterPartialWriteIPv6
#else
#define MAYBE_TestWritableAfterPartialWriteIPv6 \
TestWritableAfterPartialWriteIPv6
#endif
TEST_F(PhysicalSocketTest, MAYBE_TestWritableAfterPartialWriteIPv6) {
WritableAfterPartialWrite(kIPv6Loopback);
}
// Crashes on Linux. See webrtc:4923.
#if defined(WEBRTC_LINUX)
#define MAYBE_TestConnectFailIPv6 DISABLED_TestConnectFailIPv6

View File

@ -11,6 +11,7 @@
#include "webrtc/base/socket_unittest.h"
#include "webrtc/base/arraysize.h"
#include "webrtc/base/buffer.h"
#include "webrtc/base/asyncudpsocket.h"
#include "webrtc/base/gunit.h"
#include "webrtc/base/nethelpers.h"
@ -21,6 +22,9 @@
namespace rtc {
// Data size to be used in TcpInternal tests.
static const size_t kTcpInternalDataSize = 1024 * 1024; // bytes
#define MAYBE_SKIP_IPV6 \
if (!HasIPv6Enabled()) { \
LOG(LS_INFO) << "No IPv6... skipping"; \
@ -129,12 +133,12 @@ void SocketTest::TestSocketServerWaitIPv6() {
}
void SocketTest::TestTcpIPv4() {
TcpInternal(kIPv4Loopback);
TcpInternal(kIPv4Loopback, kTcpInternalDataSize, -1);
}
void SocketTest::TestTcpIPv6() {
MAYBE_SKIP_IPV6;
TcpInternal(kIPv6Loopback);
TcpInternal(kIPv6Loopback, kTcpInternalDataSize, -1);
}
void SocketTest::TestSingleFlowControlCallbackIPv4() {
@ -671,24 +675,15 @@ void SocketTest::SocketServerWaitInternal(const IPAddress& loopback) {
EXPECT_LT(0, accepted->Recv(buf, 1024));
}
void SocketTest::TcpInternal(const IPAddress& loopback) {
void SocketTest::TcpInternal(const IPAddress& loopback, size_t data_size,
ssize_t max_send_size) {
testing::StreamSink sink;
SocketAddress accept_addr;
// Create test data.
const size_t kDataSize = 1024 * 1024;
scoped_ptr<char[]> send_buffer(new char[kDataSize]);
scoped_ptr<char[]> recv_buffer(new char[kDataSize]);
size_t send_pos = 0, recv_pos = 0;
for (size_t i = 0; i < kDataSize; ++i) {
send_buffer[i] = static_cast<char>(i % 256);
recv_buffer[i] = 0;
}
// Create client.
scoped_ptr<AsyncSocket> client(
// Create receiving client.
scoped_ptr<AsyncSocket> receiver(
ss_->CreateAsyncSocket(loopback.family(), SOCK_STREAM));
sink.Monitor(client.get());
sink.Monitor(receiver.get());
// Create server and listen.
scoped_ptr<AsyncSocket> server(
@ -698,97 +693,115 @@ void SocketTest::TcpInternal(const IPAddress& loopback) {
EXPECT_EQ(0, server->Listen(5));
// Attempt connection.
EXPECT_EQ(0, client->Connect(server->GetLocalAddress()));
EXPECT_EQ(0, receiver->Connect(server->GetLocalAddress()));
// Accept connection.
// Accept connection which will be used for sending.
EXPECT_TRUE_WAIT((sink.Check(server.get(), testing::SSE_READ)), kTimeout);
scoped_ptr<AsyncSocket> accepted(server->Accept(&accept_addr));
ASSERT_TRUE(accepted);
sink.Monitor(accepted.get());
scoped_ptr<AsyncSocket> sender(server->Accept(&accept_addr));
ASSERT_TRUE(sender);
sink.Monitor(sender.get());
// Both sides are now connected.
EXPECT_EQ_WAIT(AsyncSocket::CS_CONNECTED, client->GetState(), kTimeout);
EXPECT_TRUE(sink.Check(client.get(), testing::SSE_OPEN));
EXPECT_EQ(client->GetRemoteAddress(), accepted->GetLocalAddress());
EXPECT_EQ(accepted->GetRemoteAddress(), client->GetLocalAddress());
EXPECT_EQ_WAIT(AsyncSocket::CS_CONNECTED, receiver->GetState(), kTimeout);
EXPECT_TRUE(sink.Check(receiver.get(), testing::SSE_OPEN));
EXPECT_EQ(receiver->GetRemoteAddress(), sender->GetLocalAddress());
EXPECT_EQ(sender->GetRemoteAddress(), receiver->GetLocalAddress());
// Create test data.
rtc::Buffer send_buffer(0, data_size);
rtc::Buffer recv_buffer(0, data_size);
for (size_t i = 0; i < data_size; ++i) {
char ch = static_cast<char>(i % 256);
send_buffer.AppendData(&ch, sizeof(ch));
}
// Send and receive a bunch of data.
bool send_waiting_for_writability = false;
bool send_expect_success = true;
bool recv_waiting_for_readability = true;
bool recv_expect_success = false;
int data_in_flight = 0;
while (recv_pos < kDataSize) {
// Send as much as we can if we've been cleared to send.
while (!send_waiting_for_writability && send_pos < kDataSize) {
int tosend = static_cast<int>(kDataSize - send_pos);
int sent = accepted->Send(send_buffer.get() + send_pos, tosend);
if (send_expect_success) {
size_t sent_size = 0;
bool writable = true;
bool send_called = false;
bool readable = false;
bool recv_called = false;
while (recv_buffer.size() < send_buffer.size()) {
// Send as much as we can while we're cleared to send.
while (writable && sent_size < send_buffer.size()) {
int unsent_size = static_cast<int>(send_buffer.size() - sent_size);
int sent = sender->Send(send_buffer.data() + sent_size, unsent_size);
if (!send_called) {
// The first Send() after connecting or getting writability should
// succeed and send some data.
EXPECT_GT(sent, 0);
send_expect_success = false;
send_called = true;
}
if (sent >= 0) {
EXPECT_LE(sent, tosend);
send_pos += sent;
data_in_flight += sent;
EXPECT_LE(sent, unsent_size);
sent_size += sent;
if (max_send_size >= 0) {
EXPECT_LE(static_cast<ssize_t>(sent), max_send_size);
if (sent < unsent_size) {
// If max_send_size is limiting the amount to send per call such
// that the sent amount is less than the unsent amount, we simulate
// that the socket is no longer writable.
writable = false;
}
}
} else {
ASSERT_TRUE(accepted->IsBlocking());
send_waiting_for_writability = true;
ASSERT_TRUE(sender->IsBlocking());
writable = false;
}
}
// Read all the sent data.
while (data_in_flight > 0) {
if (recv_waiting_for_readability) {
while (recv_buffer.size() < sent_size) {
if (!readable) {
// Wait until data is available.
EXPECT_TRUE_WAIT(sink.Check(client.get(), testing::SSE_READ), kTimeout);
recv_waiting_for_readability = false;
recv_expect_success = true;
EXPECT_TRUE_WAIT(sink.Check(receiver.get(), testing::SSE_READ),
kTimeout);
readable = true;
recv_called = false;
}
// Receive as much as we can get in a single recv call.
int rcvd = client->Recv(recv_buffer.get() + recv_pos,
kDataSize - recv_pos);
char recved_data[data_size];
int recved_size = receiver->Recv(recved_data, data_size);
if (recv_expect_success) {
if (!recv_called) {
// The first Recv() after getting readability should succeed and receive
// some data.
// TODO: The following line is disabled due to flakey pulse
// builds. Re-enable if/when possible.
// EXPECT_GT(rcvd, 0);
recv_expect_success = false;
// EXPECT_GT(recved_size, 0);
recv_called = true;
}
if (rcvd >= 0) {
EXPECT_LE(rcvd, data_in_flight);
recv_pos += rcvd;
data_in_flight -= rcvd;
if (recved_size >= 0) {
EXPECT_LE(static_cast<size_t>(recved_size),
sent_size - recv_buffer.size());
recv_buffer.AppendData(recved_data, recved_size);
} else {
ASSERT_TRUE(client->IsBlocking());
recv_waiting_for_readability = true;
ASSERT_TRUE(receiver->IsBlocking());
readable = false;
}
}
// Once all that we've sent has been rcvd, expect to be able to send again.
if (send_waiting_for_writability) {
EXPECT_TRUE_WAIT(sink.Check(accepted.get(), testing::SSE_WRITE),
// Once all that we've sent has been received, expect to be able to send
// again.
if (!writable) {
EXPECT_TRUE_WAIT(sink.Check(sender.get(), testing::SSE_WRITE),
kTimeout);
send_waiting_for_writability = false;
send_expect_success = true;
writable = true;
send_called = false;
}
}
// The received data matches the sent data.
EXPECT_EQ(kDataSize, send_pos);
EXPECT_EQ(kDataSize, recv_pos);
EXPECT_EQ(0, memcmp(recv_buffer.get(), send_buffer.get(), kDataSize));
EXPECT_EQ(data_size, sent_size);
EXPECT_EQ(data_size, recv_buffer.size());
EXPECT_EQ(recv_buffer, send_buffer);
// Close down.
accepted->Close();
EXPECT_EQ_WAIT(AsyncSocket::CS_CLOSED, client->GetState(), kTimeout);
EXPECT_TRUE(sink.Check(client.get(), testing::SSE_CLOSE));
client->Close();
sender->Close();
EXPECT_EQ_WAIT(AsyncSocket::CS_CLOSED, receiver->GetState(), kTimeout);
EXPECT_TRUE(sink.Check(receiver.get(), testing::SSE_CLOSE));
receiver->Close();
}
void SocketTest::SingleFlowControlCallbackInternal(const IPAddress& loopback) {

View File

@ -62,6 +62,10 @@ class SocketTest : public testing::Test {
const IPAddress kIPv4Loopback;
const IPAddress kIPv6Loopback;
protected:
void TcpInternal(const IPAddress& loopback, size_t data_size,
ssize_t max_send_size);
private:
void ConnectInternal(const IPAddress& loopback);
void ConnectWithDnsLookupInternal(const IPAddress& loopback,
@ -76,7 +80,6 @@ class SocketTest : public testing::Test {
void ServerCloseInternal(const IPAddress& loopback);
void CloseInClosedCallbackInternal(const IPAddress& loopback);
void SocketServerWaitInternal(const IPAddress& loopback);
void TcpInternal(const IPAddress& loopback);
void SingleFlowControlCallbackInternal(const IPAddress& loopback);
void UdpInternal(const IPAddress& loopback);
void UdpReadyToSend(const IPAddress& loopback);