patchpanel: Change SocketForwarder exit condition
This CL replaces the exit condition of the SocketForwarder from quitting
as soon as one of the sockets receives a EPOLLRDHUP signal to waiting
for both sockets to receive EOF.
EPOLLRDHUP means that the peer will no longer write to the socket, but
they can still receive data on that socket. In practice, this is
encountered when downloading PlayStore apps through a SocketForwarder.
BUG=chromium:1125177
TEST=downloading PlayStore apps on DUT works
Change-Id: I387cdf413e4a0ececd4453d0783424388de145d1
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform2/+/2398698
Reviewed-by: Garrick Evans <garrick@chromium.org>
Reviewed-by: Hugo Benichi <hugobenichi@google.com>
Tested-by: Andreea-Elena Costinas <acostinas@google.com>
Commit-Queue: Hugo Benichi <hugobenichi@google.com>
diff --git a/patchpanel/socket_forwarder_test.cc b/patchpanel/socket_forwarder_test.cc
new file mode 100644
index 0000000..547c2ae
--- /dev/null
+++ b/patchpanel/socket_forwarder_test.cc
@@ -0,0 +1,108 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "patchpanel/socket_forwarder.h"
+
+#include <netinet/in.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include <base/callback.h>
+#include <base/run_loop.h>
+#include <base/task/single_thread_task_executor.h>
+#include <brillo/message_loops/base_message_loop.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using testing::Each;
+
+namespace patchpanel {
+namespace {
+// SocketForwarder reads blocks of 4096 bytes.
+constexpr int kDataSize = 5000;
+
+// Does a blocking read on |socket| until it receives |expected_byte_count|
+// bytes which will be written into |buf|.
+bool Read(Socket* socket, char* buf, int expected_byte_count) {
+ int read_byte_count = 0;
+ int bytes = 0;
+ while (read_byte_count < expected_byte_count) {
+ bytes = socket->RecvFrom(buf + read_byte_count, kDataSize);
+ if (bytes <= 0)
+ return false;
+ read_byte_count += bytes;
+ }
+ if (read_byte_count != expected_byte_count)
+ return false;
+ return true;
+}
+} // namespace
+
+class SocketForwarderTest : public ::testing::Test {
+ void SetUp() override {
+ int fds0[2], fds1[2];
+ ASSERT_NE(-1, socketpair(AF_UNIX, SOCK_STREAM, 0 /* protocol */, fds0));
+ ASSERT_NE(-1, socketpair(AF_UNIX, SOCK_STREAM, 0 /* protocol */, fds1));
+ peer0_ = std::make_unique<Socket>(base::ScopedFD(fds0[0]));
+ peer1_ = std::make_unique<Socket>(base::ScopedFD(fds1[0]));
+ forwarder_ = std::make_unique<SocketForwarder>(
+ "test", std::make_unique<Socket>(base::ScopedFD(fds0[1])),
+ std::make_unique<Socket>(base::ScopedFD(fds1[1])));
+ }
+
+ protected:
+ std::unique_ptr<Socket> peer0_;
+ std::unique_ptr<Socket> peer1_;
+ // Forwards data betweeok |peer0_| and |peer1_|.
+ std::unique_ptr<SocketForwarder> forwarder_;
+
+ base::SingleThreadTaskExecutor task_executor_{base::MessagePumpType::IO};
+ brillo::BaseMessageLoop brillo_loop_{task_executor_.task_runner()};
+};
+
+TEST_F(SocketForwarderTest, ForwardDataAndClose) {
+ base::RunLoop loop;
+ forwarder_->SetStopQuitClosureForTesting(loop.QuitClosure());
+ forwarder_->Start();
+
+ std::vector<char> msg(kDataSize, 1);
+
+ EXPECT_EQ(peer0_->SendTo(msg.data(), msg.size()), kDataSize);
+ EXPECT_EQ(peer1_->SendTo(msg.data(), msg.size()), kDataSize);
+ // Close both sockets for writing.
+ EXPECT_NE(shutdown(peer0_->fd(), SHUT_WR), -1);
+ EXPECT_NE(shutdown(peer1_->fd(), SHUT_WR), -1);
+
+ loop.Run();
+
+ EXPECT_FALSE(forwarder_->IsRunning());
+
+ // Verify that all the data has been forwarded to the peers.
+ std::vector<char> expected_data_peer0(kDataSize);
+ std::vector<char> expected_data_peer1(kDataSize);
+ EXPECT_TRUE(Read(peer1_.get(), expected_data_peer1.data(), kDataSize));
+ EXPECT_TRUE(Read(peer0_.get(), expected_data_peer0.data(), kDataSize));
+
+ EXPECT_THAT(expected_data_peer0, Each(1));
+ EXPECT_THAT(expected_data_peer1, Each(1));
+}
+
+TEST_F(SocketForwarderTest, PeerSignalEPOLLHUP) {
+ base::RunLoop loop;
+ forwarder_->SetStopQuitClosureForTesting(loop.QuitClosure());
+ forwarder_->Start();
+
+ // Close the destination peer.
+ peer1_.reset();
+
+ loop.Run();
+
+ EXPECT_FALSE(forwarder_->IsRunning());
+}
+
+} // namespace patchpanel