diff --git a/p2p/base/port_unittest.cc b/p2p/base/port_unittest.cc index 8d6d99e5c1..e2dde8358c 100644 --- a/p2p/base/port_unittest.cc +++ b/p2p/base/port_unittest.cc @@ -955,9 +955,7 @@ void PortTest::TestConnectivity(const char* name1, class FakePacketSocketFactory : public rtc::PacketSocketFactory { public: FakePacketSocketFactory() - : next_udp_socket_(NULL), - next_server_tcp_socket_(NULL), - next_client_tcp_socket_(NULL) {} + : next_udp_socket_(NULL), next_server_tcp_socket_(NULL) {} ~FakePacketSocketFactory() override {} AsyncPacketSocket* CreateUdpSocket(const SocketAddress& address, @@ -985,9 +983,9 @@ class FakePacketSocketFactory : public rtc::PacketSocketFactory { const rtc::ProxyInfo& proxy_info, const std::string& user_agent, const rtc::PacketSocketTcpOptions& opts) override { - EXPECT_TRUE(next_client_tcp_socket_ != NULL); - AsyncPacketSocket* result = next_client_tcp_socket_; - next_client_tcp_socket_ = NULL; + EXPECT_TRUE(next_client_tcp_socket_.has_value()); + AsyncPacketSocket* result = *next_client_tcp_socket_; + next_client_tcp_socket_ = nullptr; return result; } @@ -1005,29 +1003,37 @@ class FakePacketSocketFactory : public rtc::PacketSocketFactory { private: AsyncPacketSocket* next_udp_socket_; AsyncPacketSocket* next_server_tcp_socket_; - AsyncPacketSocket* next_client_tcp_socket_; + absl::optional next_client_tcp_socket_; }; class FakeAsyncPacketSocket : public AsyncPacketSocket { public: // Returns current local address. Address may be set to NULL if the // socket is not bound yet (GetState() returns STATE_BINDING). - virtual SocketAddress GetLocalAddress() const { return SocketAddress(); } + virtual SocketAddress GetLocalAddress() const { return local_address_; } // Returns remote address. Returns zeroes if this is not a client TCP socket. - virtual SocketAddress GetRemoteAddress() const { return SocketAddress(); } + virtual SocketAddress GetRemoteAddress() const { return remote_address_; } // Send a packet. virtual int Send(const void* pv, size_t cb, const rtc::PacketOptions& options) { - return static_cast(cb); + if (error_ == 0) { + return static_cast(cb); + } else { + return -1; + } } virtual int SendTo(const void* pv, size_t cb, const SocketAddress& addr, const rtc::PacketOptions& options) { - return static_cast(cb); + if (error_ == 0) { + return static_cast(cb); + } else { + return -1; + } } virtual int Close() { return 0; } @@ -1035,11 +1041,15 @@ class FakeAsyncPacketSocket : public AsyncPacketSocket { virtual int GetOption(Socket::Option opt, int* value) { return 0; } virtual int SetOption(Socket::Option opt, int value) { return 0; } virtual int GetError() const { return 0; } - virtual void SetError(int error) {} + virtual void SetError(int error) { error_ = error; } void set_state(State state) { state_ = state; } + SocketAddress local_address_; + SocketAddress remote_address_; + private: + int error_ = 0; State state_; }; @@ -1435,6 +1445,52 @@ TEST_F(PortTest, TestDelayedBindingTcp) { EXPECT_EQ(1U, port->Candidates().size()); } +TEST_F(PortTest, TestDisableInterfaceOfTcpPort) { + FakeAsyncPacketSocket* lsocket = new FakeAsyncPacketSocket(); + FakeAsyncPacketSocket* rsocket = new FakeAsyncPacketSocket(); + FakePacketSocketFactory socket_factory; + + socket_factory.set_next_server_tcp_socket(lsocket); + auto lport = CreateTcpPort(kLocalAddr1, &socket_factory); + + socket_factory.set_next_server_tcp_socket(rsocket); + auto rport = CreateTcpPort(kLocalAddr2, &socket_factory); + + lsocket->set_state(AsyncPacketSocket::STATE_BINDING); + lsocket->SignalAddressReady(lsocket, kLocalAddr1); + rsocket->set_state(AsyncPacketSocket::STATE_BINDING); + rsocket->SignalAddressReady(rsocket, kLocalAddr2); + + lport->SetIceRole(cricket::ICEROLE_CONTROLLING); + lport->SetIceTiebreaker(kTiebreaker1); + rport->SetIceRole(cricket::ICEROLE_CONTROLLED); + rport->SetIceTiebreaker(kTiebreaker2); + + lport->PrepareAddress(); + rport->PrepareAddress(); + ASSERT_FALSE(rport->Candidates().empty()); + + // A client socket. + FakeAsyncPacketSocket* socket = new FakeAsyncPacketSocket(); + socket->local_address_ = kLocalAddr1; + socket->remote_address_ = kLocalAddr2; + socket_factory.set_next_client_tcp_socket(socket); + Connection* lconn = + lport->CreateConnection(rport->Candidates()[0], Port::ORIGIN_MESSAGE); + ASSERT_NE(lconn, nullptr); + socket->SignalConnect(socket); + lconn->Ping(0); + + // Now disconnect the client socket... + socket->SignalClose(socket, 1); + + // And prevent new sockets from being created. + socket_factory.set_next_client_tcp_socket(nullptr); + + // Test that Ping() does not cause SEGV. + lconn->Ping(0); +} + void PortTest::TestCrossFamilyPorts(int type) { FakePacketSocketFactory factory; std::unique_ptr ports[4]; diff --git a/p2p/base/tcp_port.cc b/p2p/base/tcp_port.cc index d1fb9b29e9..e07361acf7 100644 --- a/p2p/base/tcp_port.cc +++ b/p2p/base/tcp_port.cc @@ -520,6 +520,9 @@ void TCPConnection::OnMessage(rtc::Message* pmsg) { Destroy(); } break; + case MSG_TCPCONNECTION_FAILED_CREATE_SOCKET: + FailAndPrune(); + break; default: Connection::OnMessage(pmsg); } @@ -576,7 +579,13 @@ void TCPConnection::CreateOutgoingTcpSocket() { } else { RTC_LOG(LS_WARNING) << ToString() << ": Failed to create connection to " << remote_candidate().address().ToSensitiveString(); - FailAndPrune(); + // We can't FailAndPrune directly here. FailAndPrune and deletes all + // the StunRequests from the request_map_. And if this is in the stack + // of Connection::Ping(), we are still using the request. + // Unwind the stack and defer the FailAndPrune. + set_state(IceCandidatePairState::FAILED); + port()->thread()->Post(RTC_FROM_HERE, this, + MSG_TCPCONNECTION_FAILED_CREATE_SOCKET); } } diff --git a/p2p/base/tcp_port.h b/p2p/base/tcp_port.h index f6953c06b3..36257b07ed 100644 --- a/p2p/base/tcp_port.h +++ b/p2p/base/tcp_port.h @@ -139,6 +139,7 @@ class TCPConnection : public Connection { protected: enum { MSG_TCPCONNECTION_DELAYED_ONCLOSE = Connection::MSG_FIRST_AVAILABLE, + MSG_TCPCONNECTION_FAILED_CREATE_SOCKET, }; // Set waiting_for_stun_binding_complete_ to false to allow data packets in