UdpSocket: Allow IPADDR_ANY and port 0 when Create()'ing.
Removes the restriction that all UdpSockets must be created with a known
local endpoint (addr+port), and provides a GetLocalEndpoint() method to
query which port the operating system auto-assigned after Bind(). This
enables two socket use cases:
1. Letting the operating system auto-assign a free port, which is
useful for things like Cast Streaming senders, unit testing, etc.
2. Direct point-to-point Send(), where the client code does not care
what the local endpoint is (and is fine with the operating system
auto-assigning whatever it wants as the "return address").
Change-Id: I623d3294f1dfdf7c9cc11bbf90acd2c3503fbfec
Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/1740680
Reviewed-by: Yuri Wiitala <miu@chromium.org>
Reviewed-by: Max Yakimakha <yakimakha@chromium.org>
Reviewed-by: Ryan Keane <rwkeane@google.com>
Commit-Queue: Yuri Wiitala <miu@chromium.org>
diff --git a/platform/api/udp_socket.h b/platform/api/udp_socket.h
index 3639891..aeee300 100644
--- a/platform/api/udp_socket.h
+++ b/platform/api/udp_socket.h
@@ -56,15 +56,25 @@
using Version = IPAddress::Version;
- // Creates a new, scoped UdpSocket within the IPv4 or IPv6 family. This method
- // must be defined in the platform-level implementation.
- static ErrorOr<UdpSocketUniquePtr> Create(const IPEndpoint& endpoint);
+ // Creates a new, scoped UdpSocket within the IPv4 or IPv6 family.
+ // The address and/or port of |local_endpoint| may be zero (see Bind()). This
+ // method must be defined in the platform-level implementation.
+ static ErrorOr<UdpSocketUniquePtr> Create(const IPEndpoint& local_endpoint);
// Returns true if |socket| belongs to the IPv4/IPv6 address family.
virtual bool IsIPv4() const = 0;
virtual bool IsIPv6() const = 0;
- // Binds to the address specified in the constructor.
+ // Returns the current local endpoint's address and port. Initially, this will
+ // be the same as the value that was passed into Create(). However, it can
+ // later change after certain operations, such as Bind(), are executed.
+ virtual IPEndpoint GetLocalEndpoint() const = 0;
+
+ // Binds to the local endpoint specified at creation (see Create()). If the
+ // local endpoint's address is zero, the operating system will bind to all
+ // interfaces. If its port is zero, the operating system will automatically
+ // find a free local port and bind to it. Afterwards, GetLocalEndpoint() will
+ // reflect the resolved port.
virtual Error Bind() = 0;
// Sets the device to use for outgoing multicast packets on the socket.
diff --git a/platform/api/udp_socket_unittest.cc b/platform/api/udp_socket_unittest.cc
index 7814032..f50ab72 100644
--- a/platform/api/udp_socket_unittest.cc
+++ b/platform/api/udp_socket_unittest.cc
@@ -33,5 +33,35 @@
EXPECT_EQ(call_count, 1);
}
+// Tests that an IPv4 UdpSocket created with the any-address (all-zero bytes)
+// and port 0 successfully Bind()s, and that GetLocalEndpoint() then reports
+// the OS-assigned port (i.e., the local endpoint's port is no longer zero).
+TEST(UdpSocketTest, ResolvesLocalEndpoint_IPv4) {
+ const uint8_t kIpV4AddrAny[4] = {};  // 0.0.0.0
+ ErrorOr<UdpSocketUniquePtr> create_result =
+ UdpSocket::Create(IPEndpoint{IPAddress(kIpV4AddrAny), 0});
+ ASSERT_TRUE(create_result) << create_result.error();
+ const auto socket = create_result.MoveValue();
+ const Error bind_result = socket->Bind();
+ ASSERT_TRUE(bind_result.ok()) << bind_result;
+ const IPEndpoint local_endpoint = socket->GetLocalEndpoint();
+ EXPECT_NE(local_endpoint.port, 0) << local_endpoint;
+}
+
+// Tests that an IPv6 UdpSocket created with the any-address (all-zero bytes)
+// and port 0 successfully Bind()s, and that GetLocalEndpoint() then reports
+// the OS-assigned port (i.e., the local endpoint's port is no longer zero).
+TEST(UdpSocketTest, ResolvesLocalEndpoint_IPv6) {
+ const uint8_t kIpV6AddrAny[16] = {};  // ::
+ ErrorOr<UdpSocketUniquePtr> create_result =
+ UdpSocket::Create(IPEndpoint{IPAddress(kIpV6AddrAny), 0});
+ ASSERT_TRUE(create_result) << create_result.error();
+ const auto socket = create_result.MoveValue();
+ const Error bind_result = socket->Bind();
+ ASSERT_TRUE(bind_result.ok()) << bind_result;
+ const IPEndpoint local_endpoint = socket->GetLocalEndpoint();
+ EXPECT_NE(local_endpoint.port, 0) << local_endpoint;
+}
+
} // namespace platform
} // namespace openscreen