TcpSocket: Modify existing code for TcpConnectionWaiter
To initialize TcpConnections on POSIX systems, we need to use accept(), as
documented here:
http://man7.org/linux/man-pages/man2/accept.2.html
Like UdpSocket reads, accept() relies on polling for readiness.
Additionally, according to the above documentation, we can use select() to
check whether a listening socket has a pending connection ready to be
accepted, as documented here:
http://man7.org/linux/man-pages/man2/select.2.html
Because this same mechanism is already used for UdpSocket reads in the
NetworkReader (via the NetworkWaiter class), we can repurpose the
existing NetworkWaiter to serve both its current NetworkReader use
and a new TcpConnectionWaiter class (added in the next CL).
To do so, NetworkWaiter must operate on file descriptors (wrapped in the
new SocketHandle type) instead of UdpSocket objects. This CL makes that
change.
Change-Id: I69cab5ea9e94ece7016da52b2723ec7bf47a6196
Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/1790368
Commit-Queue: Ryan Keane <rwkeane@google.com>
Reviewed-by: Max Yakimakha <yakimakha@chromium.org>
diff --git a/platform/impl/network_reader.cc b/platform/impl/network_reader.cc
index ffd4ee9..d0b7f3f 100644
--- a/platform/impl/network_reader.cc
+++ b/platform/impl/network_reader.cc
@@ -8,6 +8,7 @@
#include <condition_variable>
#include "platform/api/logging.h"
+#include "platform/impl/socket_handle_posix.h"
#include "platform/impl/udp_socket_posix.h"
namespace openscreen {
@@ -20,36 +21,45 @@
NetworkReader::~NetworkReader() = default;
+// TODO(rwkeane): Remove unsafe casts to UdpSocketPosix.
Error NetworkReader::WaitAndRead(Clock::duration timeout) {
// Get the set of all sockets we care about. A different list than the
// existing unordered_set is used to avoid race conditions with the method
// using this new list.
socket_deletion_block_.notify_all();
- std::vector<UdpSocket*> sockets;
+ std::vector<SocketHandle> socket_handles;
+ socket_handles.reserve(sockets_.size());
{
std::lock_guard<std::mutex> lock(mutex_);
- sockets = sockets_;
+ for (const auto& socket : sockets_) {
+ UdpSocketPosix* read_socket = static_cast<UdpSocketPosix*>(socket);
+ socket_handles.emplace_back(read_socket->GetFd());
+ }
}
// Wait for the sockets to find something interesting or for the timeout.
- auto changed_or_error = waiter_->AwaitSocketsReadable(sockets, timeout);
+ auto changed_or_error =
+ waiter_->AwaitSocketsReadable(socket_handles, timeout);
if (changed_or_error.is_error()) {
return changed_or_error.error();
}
// Process the results.
socket_deletion_block_.notify_all();
+ const std::vector<SocketHandle>& changed_handles = changed_or_error.value();
+ if (changed_handles.empty()) {
+ return Error::None();
+ }
+
{
std::lock_guard<std::mutex> lock(mutex_);
- for (UdpSocket* socket : changed_or_error.value()) {
- if (std::find(sockets_.begin(), sockets_.end(), socket) ==
- sockets_.end()) {
- continue;
- }
-
- // TODO(rwkeane): Remove this unsafe cast.
+ for (UdpSocket* socket : sockets_) {
UdpSocketPosix* read_socket = static_cast<UdpSocketPosix*>(socket);
- read_socket->ReceiveMessage();
+ if (std::find(changed_handles.begin(), changed_handles.end(),
+ SocketHandle(read_socket->GetFd())) !=
+ changed_handles.end()) {
+ read_socket->ReceiveMessage();
+ }
}
}
diff --git a/platform/impl/network_reader.h b/platform/impl/network_reader.h
index d6ad87e..0e9f799 100644
--- a/platform/impl/network_reader.h
+++ b/platform/impl/network_reader.h
@@ -21,6 +21,7 @@
// calling the function associated with these sockets once that data is read.
// NOTE: This class will only function as intended while its RunUntilStopped
// method is running.
+// TODO(rwkeane): Rename this class NetworkReaderPosix.
class NetworkReader : public UdpSocket::LifetimeObserver {
public:
// Create a type for readability
diff --git a/platform/impl/network_reader_unittest.cc b/platform/impl/network_reader_unittest.cc
index 82a6bf6..e6a02dc 100644
--- a/platform/impl/network_reader_unittest.cc
+++ b/platform/impl/network_reader_unittest.cc
@@ -7,6 +7,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "platform/api/time.h"
+#include "platform/impl/socket_handle_posix.h"
#include "platform/impl/udp_socket_posix.h"
#include "platform/test/fake_clock.h"
#include "platform/test/fake_task_runner.h"
@@ -23,8 +24,9 @@
public:
explicit MockUdpSocketPosix(TaskRunner* task_runner,
Client* client,
+ int fd,
Version version = Version::kV4)
- : UdpSocketPosix(task_runner, client, 0, IPEndpoint()),
+ : UdpSocketPosix(task_runner, client, fd, IPEndpoint()),
version_(version) {}
~MockUdpSocketPosix() override = default;
@@ -47,9 +49,10 @@
// Mock event waiter
class MockNetworkWaiter final : public NetworkWaiter {
public:
- MOCK_METHOD2(AwaitSocketsReadable,
- ErrorOr<std::vector<UdpSocket*>>(const std::vector<UdpSocket*>&,
- const Clock::duration&));
+ // Mocks the SocketHandle-based AwaitSocketsReadable() that
+ // NetworkReader::WaitAndRead() invokes on its waiter.
+ MOCK_METHOD2(
+ AwaitSocketsReadable,
+ ErrorOr<std::vector<SocketHandle>>(const std::vector<SocketHandle>&,
+ const Clock::duration&));
};
// Mock Task Runner
@@ -99,7 +102,7 @@
std::unique_ptr<TaskRunner>(new MockTaskRunner());
FakeUdpSocket::MockClient client;
std::unique_ptr<MockUdpSocketPosix> socket =
- std::make_unique<MockUdpSocketPosix>(task_runner.get(), &client,
+ std::make_unique<MockUdpSocketPosix>(task_runner.get(), &client, 42,
UdpSocket::Version::kV4);
TestingNetworkWaiter network_waiter(std::move(mock_waiter));
@@ -119,7 +122,7 @@
std::unique_ptr<TaskRunner>(new MockTaskRunner());
FakeUdpSocket::MockClient client;
std::unique_ptr<MockUdpSocketPosix> socket =
- std::make_unique<MockUdpSocketPosix>(task_runner.get(), &client,
+ std::make_unique<MockUdpSocketPosix>(task_runner.get(), &client, 17,
UdpSocket::Version::kV4);
TestingNetworkWaiter network_waiter(std::move(mock_waiter));
@@ -169,7 +172,7 @@
auto timeout = Clock::duration(0);
EXPECT_CALL(*mock_waiter_ptr, AwaitSocketsReadable(_, timeout))
- .WillOnce(Return(ByMove(std::vector<UdpSocket*>{})));
+ .WillOnce(Return(ByMove(std::vector<SocketHandle>{})));
EXPECT_EQ(network_waiter.WaitTesting(timeout), Error::Code::kNone);
}
@@ -181,17 +184,17 @@
std::unique_ptr<TaskRunner>(new MockTaskRunner());
FakeUdpSocket::MockClient client;
std::unique_ptr<MockUdpSocketPosix> socket =
- std::make_unique<MockUdpSocketPosix>(task_runner.get(), &client,
+ std::make_unique<MockUdpSocketPosix>(task_runner.get(), &client, 10,
UdpSocket::Version::kV4);
TestingNetworkWaiter network_waiter(std::move(mock_waiter));
auto timeout = Clock::duration(0);
UdpPacket packet;
network_waiter.OnCreate(socket.get());
- EXPECT_CALL(
- *mock_waiter_ptr,
- AwaitSocketsReadable(ContainerEq<std::vector<UdpSocket*>>({socket.get()}),
- timeout))
+ EXPECT_CALL(*mock_waiter_ptr,
+ AwaitSocketsReadable(ContainerEq<std::vector<SocketHandle>>(
+ {SocketHandle{socket->GetFd()}}),
+ timeout))
.WillOnce(Return(ByMove(std::move(Error::Code::kAgain))));
EXPECT_EQ(network_waiter.WaitTesting(timeout), Error::Code::kAgain);
}
@@ -204,7 +207,7 @@
std::unique_ptr<TaskRunner> task_runner =
std::unique_ptr<TaskRunner>(task_runner_ptr);
FakeUdpSocket::MockClient client;
- MockUdpSocketPosix socket(task_runner.get(), &client,
+ MockUdpSocketPosix socket(task_runner.get(), &client, 42,
UdpSocket::Version::kV4);
TestingNetworkWaiter network_waiter(std::move(mock_waiter));
auto timeout = Clock::duration(0);
@@ -213,7 +216,8 @@
network_waiter.OnCreate(&socket);
EXPECT_CALL(*mock_waiter_ptr, AwaitSocketsReadable(_, timeout))
- .WillOnce(Return(ByMove(std::vector<UdpSocket*>{&socket})));
+ .WillOnce(Return(
+ ByMove(std::vector<SocketHandle>{SocketHandle{socket.GetFd()}})));
EXPECT_CALL(socket, ReceiveMessage()).Times(1);
EXPECT_EQ(network_waiter.WaitTesting(timeout), Error::Code::kNone);
}
diff --git a/platform/impl/network_waiter_posix.cc b/platform/impl/network_waiter_posix.cc
index 84185f9..bbd8106 100644
--- a/platform/impl/network_waiter_posix.cc
+++ b/platform/impl/network_waiter_posix.cc
@@ -11,7 +11,7 @@
#include "platform/api/logging.h"
#include "platform/base/error.h"
-#include "platform/impl/network_reader.h"
+#include "platform/impl/socket_handle_posix.h"
#include "platform/impl/udp_socket_posix.h"
namespace openscreen {
@@ -21,15 +21,14 @@
NetworkWaiterPosix::~NetworkWaiterPosix() = default;
-ErrorOr<std::vector<UdpSocket*>> NetworkWaiterPosix::AwaitSocketsReadable(
- const std::vector<UdpSocket*>& sockets,
+ErrorOr<std::vector<SocketHandle>> NetworkWaiterPosix::AwaitSocketsReadable(
+ const std::vector<SocketHandle>& socket_handles,
const Clock::duration& timeout) {
int max_fd = -1;
FD_ZERO(&read_handles_);
- for (UdpSocket* socket : sockets) {
- UdpSocketPosix* posix_socket = static_cast<UdpSocketPosix*>(socket);
- FD_SET(posix_socket->GetFd(), &read_handles_);
- max_fd = std::max(max_fd, posix_socket->GetFd());
+ for (const SocketHandle& handle : socket_handles) {
+ FD_SET(handle.fd, &read_handles_);
+ max_fd = std::max(max_fd, handle.fd);
}
if (max_fd < 0) {
return Error::Code::kIOFailure;
@@ -49,15 +48,14 @@
return Error::Code::kAgain;
}
- std::vector<UdpSocket*> changed_sockets;
- for (UdpSocket* socket : sockets) {
- UdpSocketPosix* posix_socket = static_cast<UdpSocketPosix*>(socket);
- if (FD_ISSET(posix_socket->GetFd(), &read_handles_)) {
- changed_sockets.push_back(socket);
+ std::vector<SocketHandle> changed_handles;
+ for (const SocketHandle& handle : socket_handles) {
+ if (FD_ISSET(handle.fd, &read_handles_)) {
+ changed_handles.push_back(handle);
}
}
- return changed_sockets;
+ return changed_handles;
}
// static
diff --git a/platform/impl/network_waiter_posix.h b/platform/impl/network_waiter_posix.h
index 747c846..3f10b1a 100644
--- a/platform/impl/network_waiter_posix.h
+++ b/platform/impl/network_waiter_posix.h
@@ -12,7 +12,6 @@
#include <mutex> // NOLINT
#include "platform/api/network_waiter.h"
-#include "platform/impl/udp_socket_posix.h"
namespace openscreen {
namespace platform {
@@ -21,9 +20,9 @@
public:
NetworkWaiterPosix();
~NetworkWaiterPosix();
- ErrorOr<std::vector<UdpSocket*>> AwaitSocketsReadable(
- const std::vector<UdpSocket*>& sockets,
- const Clock::duration& timeout) override;
+ // Waits up to |timeout| for any of |socket_handles| to become readable and
+ // returns the subset that are. Returns Error::Code::kIOFailure if no handle
+ // holds a non-negative fd.
+ ErrorOr<std::vector<SocketHandle>> AwaitSocketsReadable(
+ const std::vector<SocketHandle>& socket_handles,
+ const Clock::duration& timeout) override;
// TODO(rwkeane): Move this to a platform-specific util library.
static struct timeval ToTimeval(const Clock::duration& timeout);
diff --git a/platform/impl/socket_handle_posix.cc b/platform/impl/socket_handle_posix.cc
new file mode 100644
index 0000000..965c9eb
--- /dev/null
+++ b/platform/impl/socket_handle_posix.cc
@@ -0,0 +1,23 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "platform/impl/socket_handle_posix.h"
+
+namespace openscreen {
+namespace platform {
+
+SocketHandle::SocketHandle(int descriptor) : fd{descriptor} {}
+
+// Two handles are equal exactly when they wrap the same descriptor value.
+bool SocketHandle::operator==(const SocketHandle& other) const {
+ return other.fd == fd;
+}
+
+// Expressed via operator== so the two comparisons can never disagree.
+bool SocketHandle::operator!=(const SocketHandle& other) const {
+ return !operator==(other);
+}
+
+} // namespace platform
+} // namespace openscreen
diff --git a/platform/impl/socket_handle_posix.h b/platform/impl/socket_handle_posix.h
new file mode 100644
index 0000000..baec6e5
--- /dev/null
+++ b/platform/impl/socket_handle_posix.h
@@ -0,0 +1,31 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef PLATFORM_IMPL_SOCKET_HANDLE_POSIX_H_
+#define PLATFORM_IMPL_SOCKET_HANDLE_POSIX_H_
+
+#include "platform/api/socket_handle.h"
+
+namespace openscreen {
+namespace platform {
+
+// POSIX definition of SocketHandle: a copyable wrapper around a raw socket
+// file descriptor. It does not own the descriptor; copying or destroying a
+// SocketHandle never closes |fd|.
+struct SocketHandle {
+ // Wraps |descriptor| as-is; no validation is performed.
+ explicit SocketHandle(int descriptor);
+
+ // Handles compare equal exactly when their |fd| values are equal.
+ bool operator==(const SocketHandle& other) const;
+ bool operator!=(const SocketHandle& other) const;
+
+ // The wrapped POSIX file descriptor.
+ int fd;
+};
+
+} // namespace platform
+} // namespace openscreen
+
+#endif // PLATFORM_IMPL_SOCKET_HANDLE_POSIX_H_
diff --git a/platform/impl/stream_socket.h b/platform/impl/stream_socket.h
index 21cd727..83e3586 100644
--- a/platform/impl/stream_socket.h
+++ b/platform/impl/stream_socket.h
@@ -10,6 +10,7 @@
#include <string>
#include "platform/api/network_interface.h"
+#include "platform/api/socket_handle.h"
#include "platform/base/error.h"
#include "platform/base/ip_address.h"
#include "platform/base/macros.h"
@@ -18,8 +19,6 @@
namespace openscreen {
namespace platform {
-struct FileDescriptor;
-
// StreamSocket is an incomplete abstraction of synchronous platform methods for
// creating, initializing, and closing stream sockets. Callers can use this
// class to define complete TCP and TLS socket classes, both synchronous and
@@ -48,7 +47,7 @@
virtual Error Listen(int max_backlog_size) = 0;
- // Returns the file descriptor (e.g. fd or HANDLE pointer) for this socket.
- virtual FileDescriptor file_descriptor() const = 0;
+ // Returns the platform-specific handle (e.g. a POSIX fd or a Windows HANDLE)
+ // that identifies this socket, wrapped in a SocketHandle.
+ virtual SocketHandle socket_handle() const = 0;
// Returns the connected remote address, if socket is connected.
virtual absl::optional<IPEndpoint> remote_address() const = 0;
diff --git a/platform/impl/stream_socket_posix.cc b/platform/impl/stream_socket_posix.cc
index 69ab141..9082245 100644
--- a/platform/impl/stream_socket_posix.cc
+++ b/platform/impl/stream_socket_posix.cc
@@ -11,8 +11,6 @@
#include <sys/types.h>
#include <unistd.h>
-#include "platform/impl/socket_address_posix.h"
-
namespace openscreen {
namespace platform {
@@ -156,8 +154,8 @@
return Error::None();
}
-FileDescriptor StreamSocketPosix::file_descriptor() const {
- return FileDescriptor{.fd = file_descriptor_.load()};
+// Snapshots the current value of the atomic |file_descriptor_| into a
+// SocketHandle.
+SocketHandle StreamSocketPosix::socket_handle() const {
+ const int descriptor = file_descriptor_.load();
+ return SocketHandle(descriptor);
}
absl::optional<IPEndpoint> StreamSocketPosix::remote_address() const {
diff --git a/platform/impl/stream_socket_posix.h b/platform/impl/stream_socket_posix.h
index bebf601..dba1af4 100644
--- a/platform/impl/stream_socket_posix.h
+++ b/platform/impl/stream_socket_posix.h
@@ -13,13 +13,11 @@
#include "platform/base/error.h"
#include "platform/base/ip_address.h"
#include "platform/impl/socket_address_posix.h"
+#include "platform/impl/socket_handle_posix.h"
#include "platform/impl/stream_socket.h"
namespace openscreen {
namespace platform {
-struct FileDescriptor {
- int fd;
-};
class StreamSocketPosix : public StreamSocket {
public:
@@ -42,7 +40,7 @@
Error Listen(int max_backlog_size) override;
// StreamSocket getter overrides.
- FileDescriptor file_descriptor() const override;
+ SocketHandle socket_handle() const override;
absl::optional<IPEndpoint> remote_address() const override;
SocketState state() const override;
IPAddress::Version version() const override;
diff --git a/platform/impl/udp_socket_posix.h b/platform/impl/udp_socket_posix.h
index f00078d..612d389 100644
--- a/platform/impl/udp_socket_posix.h
+++ b/platform/impl/udp_socket_posix.h
@@ -39,6 +39,7 @@
const IPEndpoint& dest) override;
void SetDscp(DscpMode state) override;
+ // TODO(rwkeane): Update to return a SocketHandle object.
int GetFd() const { return fd_; }
private: