Andreea Costinas | 456ee5b | 2020-09-08 15:11:43 +0200 | [diff] [blame] | 1 | // Copyright 2020 The Chromium OS Authors. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style license that can be |
| 3 | // found in the LICENSE file. |
| 4 | |
| 5 | #include "patchpanel/socket_forwarder.h" |
| 6 | |
| 7 | #include <netinet/in.h> |
| 8 | #include <sys/socket.h> |
| 9 | #include <sys/types.h> |
| 10 | |
| 11 | #include <memory> |
| 12 | #include <utility> |
| 13 | #include <vector> |
| 14 | |
| 15 | #include <base/callback.h> |
| 16 | #include <base/run_loop.h> |
| 17 | #include <base/task/single_thread_task_executor.h> |
| 18 | #include <brillo/message_loops/base_message_loop.h> |
| 19 | #include <gmock/gmock.h> |
| 20 | #include <gtest/gtest.h> |
| 21 | |
| 22 | using testing::Each; |
| 23 | |
| 24 | namespace patchpanel { |
| 25 | namespace { |
// Number of bytes each peer sends in the tests below. SocketForwarder reads
// blocks of 4096 bytes, so this payload is larger than a single block.
constexpr int kDataSize{5000};
| 28 | |
| 29 | // Does a blocking read on |socket| until it receives |expected_byte_count| |
| 30 | // bytes which will be written into |buf|. |
| 31 | bool Read(Socket* socket, char* buf, int expected_byte_count) { |
| 32 | int read_byte_count = 0; |
| 33 | int bytes = 0; |
| 34 | while (read_byte_count < expected_byte_count) { |
| 35 | bytes = socket->RecvFrom(buf + read_byte_count, kDataSize); |
| 36 | if (bytes <= 0) |
| 37 | return false; |
| 38 | read_byte_count += bytes; |
| 39 | } |
| 40 | if (read_byte_count != expected_byte_count) |
| 41 | return false; |
| 42 | return true; |
| 43 | } |
| 44 | } // namespace |
| 45 | |
| 46 | class SocketForwarderTest : public ::testing::Test { |
| 47 | void SetUp() override { |
| 48 | int fds0[2], fds1[2]; |
| 49 | ASSERT_NE(-1, socketpair(AF_UNIX, SOCK_STREAM, 0 /* protocol */, fds0)); |
| 50 | ASSERT_NE(-1, socketpair(AF_UNIX, SOCK_STREAM, 0 /* protocol */, fds1)); |
| 51 | peer0_ = std::make_unique<Socket>(base::ScopedFD(fds0[0])); |
| 52 | peer1_ = std::make_unique<Socket>(base::ScopedFD(fds1[0])); |
| 53 | forwarder_ = std::make_unique<SocketForwarder>( |
| 54 | "test", std::make_unique<Socket>(base::ScopedFD(fds0[1])), |
| 55 | std::make_unique<Socket>(base::ScopedFD(fds1[1]))); |
| 56 | } |
| 57 | |
| 58 | protected: |
| 59 | std::unique_ptr<Socket> peer0_; |
| 60 | std::unique_ptr<Socket> peer1_; |
| 61 | // Forwards data betweeok |peer0_| and |peer1_|. |
| 62 | std::unique_ptr<SocketForwarder> forwarder_; |
| 63 | |
| 64 | base::SingleThreadTaskExecutor task_executor_{base::MessagePumpType::IO}; |
| 65 | brillo::BaseMessageLoop brillo_loop_{task_executor_.task_runner()}; |
| 66 | }; |
| 67 | |
| 68 | TEST_F(SocketForwarderTest, ForwardDataAndClose) { |
| 69 | base::RunLoop loop; |
| 70 | forwarder_->SetStopQuitClosureForTesting(loop.QuitClosure()); |
| 71 | forwarder_->Start(); |
| 72 | |
| 73 | std::vector<char> msg(kDataSize, 1); |
| 74 | |
| 75 | EXPECT_EQ(peer0_->SendTo(msg.data(), msg.size()), kDataSize); |
| 76 | EXPECT_EQ(peer1_->SendTo(msg.data(), msg.size()), kDataSize); |
| 77 | // Close both sockets for writing. |
| 78 | EXPECT_NE(shutdown(peer0_->fd(), SHUT_WR), -1); |
| 79 | EXPECT_NE(shutdown(peer1_->fd(), SHUT_WR), -1); |
| 80 | |
| 81 | loop.Run(); |
| 82 | |
| 83 | EXPECT_FALSE(forwarder_->IsRunning()); |
| 84 | |
| 85 | // Verify that all the data has been forwarded to the peers. |
| 86 | std::vector<char> expected_data_peer0(kDataSize); |
| 87 | std::vector<char> expected_data_peer1(kDataSize); |
| 88 | EXPECT_TRUE(Read(peer1_.get(), expected_data_peer1.data(), kDataSize)); |
| 89 | EXPECT_TRUE(Read(peer0_.get(), expected_data_peer0.data(), kDataSize)); |
| 90 | |
| 91 | EXPECT_THAT(expected_data_peer0, Each(1)); |
| 92 | EXPECT_THAT(expected_data_peer1, Each(1)); |
| 93 | } |
| 94 | |
| 95 | TEST_F(SocketForwarderTest, PeerSignalEPOLLHUP) { |
| 96 | base::RunLoop loop; |
| 97 | forwarder_->SetStopQuitClosureForTesting(loop.QuitClosure()); |
| 98 | forwarder_->Start(); |
| 99 | |
| 100 | // Close the destination peer. |
| 101 | peer1_.reset(); |
| 102 | |
| 103 | loop.Run(); |
| 104 | |
| 105 | EXPECT_FALSE(forwarder_->IsRunning()); |
| 106 | } |
| 107 | |
| 108 | } // namespace patchpanel |