patchpanel: datapath: Add generic IP forwarding functions

This patch adds a single ModifyIpForwarding function and bases all
FORWARD ACCEPT rule commands on it. There is no functional change in
this patch.

BUG=b:161507671
BUG=b:161508179
TEST=Unit tests.

Change-Id: I9dcadb601524cf8e582a937eb83bc84774453476
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform2/+/2359936
Tested-by: Hugo Benichi <hugobenichi@google.com>
Commit-Queue: Hugo Benichi <hugobenichi@google.com>
Reviewed-by: Taoyu Li <taoyl@chromium.org>
diff --git a/patchpanel/datapath.cc b/patchpanel/datapath.cc
index b60e18c..8e3b7a0 100644
--- a/patchpanel/datapath.cc
+++ b/patchpanel/datapath.cc
@@ -15,6 +15,8 @@
 #include <sys/ioctl.h>
 #include <sys/socket.h>
 
+#include <vector>
+
 #include <base/files/scoped_file.h>
 #include <base/logging.h>
 #include <base/posix/eintr_wrapper.h>
@@ -406,13 +408,11 @@
 // TODO(hugobenichi) The name incorrectly refers to egress traffic, but this
 // FORWARD rule actually enables forwarding for ingress traffic. Fix the name.
 bool Datapath::AddOutboundIPv4(const std::string& ifname) {
-  return process_runner_->iptables("filter", {"-A", "FORWARD", "-o", ifname,
-                                              "-j", "ACCEPT", "-w"}) == 0;
+  return StartIpForwarding(IpFamily::IPv4, "", ifname);  // Match any input.
 }
 
 void Datapath::RemoveOutboundIPv4(const std::string& ifname) {
-  process_runner_->iptables(
-      "filter", {"-D", "FORWARD", "-o", ifname, "-j", "ACCEPT", "-w"});
+  StopIpForwarding(IpFamily::IPv4, "", ifname);  // Match any input.
 }
 
 bool Datapath::AddSNATMarkRules() {
@@ -534,25 +534,75 @@
   process_runner_->ip6("neigh", "del", {"proxy", ipv6_addr, "dev", ifname});
 }
 
-bool Datapath::AddIPv6Forwarding(const std::string& ifname1,
-                                 const std::string& ifname2) {
-  if (process_runner_->ip6tables(
-          "filter",
-          {"-C", "FORWARD", "-i", ifname1, "-o", ifname2, "-j", "ACCEPT", "-w"},
-          false /*log_failures*/) != 0 &&
-      process_runner_->ip6tables(
-          "filter", {"-A", "FORWARD", "-i", ifname1, "-o", ifname2, "-j",
-                     "ACCEPT", "-w"}) != 0) {
+bool Datapath::ModifyIpForwarding(IpFamily family,
+                                  const std::string& op,
+                                  const std::string& iif,
+                                  const std::string& oif,
+                                  bool log_failures) {
+  if (iif.empty() && oif.empty()) {  // Would match all forwarded traffic.
+    LOG(ERROR) << "Cannot change IP forwarding with no input or output "
+                  "interface specified";
     return false;
   }
 
-  if (process_runner_->ip6tables(
-          "filter",
-          {"-C", "FORWARD", "-i", ifname2, "-o", ifname1, "-j", "ACCEPT", "-w"},
-          false /*log_failures*/) != 0 &&
-      process_runner_->ip6tables(
-          "filter", {"-A", "FORWARD", "-i", ifname2, "-o", ifname1, "-j",
-                     "ACCEPT", "-w"}) != 0) {
+  switch (family) {
+    case IPv4:
+    case IPv6:
+    case Dual:
+      break;
+    default:
+      // Rejects NONE and any value outside the IpFamily enumerators.
+      LOG(ERROR) << "Cannot change IP forwarding from \"" << iif << "\" to \""
+                 << oif << "\": incorrect IP family " << family;
+      return false;
+  }
+
+  std::vector<std::string> args = {op, "FORWARD"};
+  if (!iif.empty()) {
+    args.push_back("-i");
+    args.push_back(iif);
+  }
+  if (!oif.empty()) {
+    args.push_back("-o");
+    args.push_back(oif);
+  }
+  args.push_back("-j");
+  args.push_back("ACCEPT");
+  args.push_back("-w");
+
+  bool success = true;  // Try every selected family even if one fails.
+  if (family & IpFamily::IPv4)
+    success &= process_runner_->iptables("filter", args, log_failures) == 0;
+  if (family & IpFamily::IPv6)
+    success &= process_runner_->ip6tables("filter", args, log_failures) == 0;
+  return success;
+}
+
+bool Datapath::StartIpForwarding(IpFamily family,
+                                 const std::string& iif,
+                                 const std::string& oif) {
+  return ModifyIpForwarding(family, "-A", iif, oif);  // Append rule(s).
+}
+
+bool Datapath::StopIpForwarding(IpFamily family,
+                                const std::string& iif,
+                                const std::string& oif) {
+  return ModifyIpForwarding(family, "-D", iif, oif);  // Delete rule(s).
+}
+
+bool Datapath::AddIPv6Forwarding(const std::string& ifname1,
+                                 const std::string& ifname2) {
+  // For each direction, append the ACCEPT rule only if "-C" reports that it
+  // is not already present, so repeated calls do not add duplicate rules.
+  if (!ModifyIpForwarding(IpFamily::IPv6, "-C", ifname1, ifname2,
+                          false /*log_failures*/) &&
+      !StartIpForwarding(IpFamily::IPv6, ifname1, ifname2)) {
+    return false;
+  }
+
+  if (!ModifyIpForwarding(IpFamily::IPv6, "-C", ifname2, ifname1,
+                          false /*log_failures*/) &&
+      !StartIpForwarding(IpFamily::IPv6, ifname2, ifname1)) {
     RemoveIPv6Forwarding(ifname1, ifname2);
     return false;
   }
@@ -562,11 +612,8 @@
 
 void Datapath::RemoveIPv6Forwarding(const std::string& ifname1,
                                     const std::string& ifname2) {
-  process_runner_->ip6tables("filter", {"-D", "FORWARD", "-i", ifname1, "-o",
-                                        ifname2, "-j", "ACCEPT", "-w"});
-
-  process_runner_->ip6tables("filter", {"-D", "FORWARD", "-i", ifname2, "-o",
-                                        ifname1, "-j", "ACCEPT", "-w"});
+  StopIpForwarding(IpFamily::IPv6, ifname1, ifname2);  // Best effort.
+  StopIpForwarding(IpFamily::IPv6, ifname2, ifname1);  // Best effort.
 }
 
 bool Datapath::AddIPv4Route(uint32_t gateway_addr,
diff --git a/patchpanel/datapath.h b/patchpanel/datapath.h
index 72afb25..03b13b6 100644
--- a/patchpanel/datapath.h
+++ b/patchpanel/datapath.h
@@ -20,6 +20,14 @@
 
 namespace patchpanel {
 
+// Bitmask of IP families: values combine with | and are tested with &.
+enum IpFamily {
+  NONE = 0,
+  IPv4 = 1 << 0,
+  IPv6 = 1 << 1,
+  Dual = IPv4 | IPv6,  // Both IPv4 and IPv6.
+};
+
 // cros lint will yell to force using int16/int64 instead of long here, however
 // note that unsigned long IS the correct signature for ioctl in Linux kernel -
 // it's 32 bits on 32-bit platform and 64 bits on 64-bit one.
@@ -162,6 +170,20 @@
                                   uint16_t on,
                                   uint16_t off = 0);
 
+  // Starts or stops accepting IP traffic forwarded from |iif| to |oif| by
+  // adding or removing ACCEPT rules in the filter FORWARD chain of iptables
+  // and/or ip6tables. If |iif| is empty, rules only match output interface
+  // |oif|. If |oif| is empty, rules only match input interface |iif|.
+  // |iif| and |oif| cannot both be empty. Returns false on any failure.
+  virtual bool StartIpForwarding(IpFamily family,
+                                 const std::string& iif,
+                                 const std::string& oif);
+  virtual bool StopIpForwarding(IpFamily family,
+                                const std::string& iif,
+                                const std::string& oif);
+
+  // Convenience functions for enabling or disabling IPv6 forwarding in both
+  // directions between a pair of interfaces.
   virtual bool AddIPv6Forwarding(const std::string& ifname1,
                                  const std::string& ifname2);
   virtual void RemoveIPv6Forwarding(const std::string& ifname1,
@@ -207,6 +229,12 @@
   MinijailedProcessRunner& runner() const;
 
  private:
+  bool ModifyIpForwarding(IpFamily family,
+                          const std::string& op,  // "-A", "-D" or "-C".
+                          const std::string& iif,
+                          const std::string& oif,
+                          bool log_failures = true);
+
   MinijailedProcessRunner* process_runner_;
   Firewall* firewall_;
   ioctl_t ioctl_;
diff --git a/patchpanel/datapath_test.cc b/patchpanel/datapath_test.cc
index 14a9f05..f850b78 100644
--- a/patchpanel/datapath_test.cc
+++ b/patchpanel/datapath_test.cc
@@ -105,6 +105,15 @@
                int(const std::string& netns_name, bool log_failures));
 };
 
+TEST(DatapathTest, IpFamily) {  // Checks the IpFamily bitmask values.
+  EXPECT_EQ(IpFamily::Dual, IpFamily::IPv4 | IpFamily::IPv6);
+  EXPECT_EQ(IpFamily::Dual & IpFamily::IPv4, IpFamily::IPv4);
+  EXPECT_EQ(IpFamily::Dual & IpFamily::IPv6, IpFamily::IPv6);
+  EXPECT_NE(IpFamily::Dual, IpFamily::IPv4);
+  EXPECT_NE(IpFamily::Dual, IpFamily::IPv6);
+  EXPECT_NE(IpFamily::IPv4, IpFamily::IPv6);
+}
+
 TEST(DatapathTest, AddTAP) {
   MockProcessRunner runner;
   MockFirewall firewall;
@@ -367,6 +376,101 @@
                              TrafficSource::CROSVM);
 }
 
+TEST(DatapathTest, StartStopIpForwarding) {
+  struct {
+    IpFamily family;
+    std::string iif;
+    std::string oif;
+    std::vector<std::string> start_args;
+    std::vector<std::string> stop_args;
+    bool result;
+  } testcases[] = {
+      {IpFamily::IPv4, "", "", {}, {}, false},            // No iif or oif.
+      {IpFamily::NONE, "foo", "bar", {}, {}, false},      // No IP family.
+      {IpFamily::IPv4,
+       "foo",
+       "bar",
+       {"-A", "FORWARD", "-i", "foo", "-o", "bar", "-j", "ACCEPT", "-w"},
+       {"-D", "FORWARD", "-i", "foo", "-o", "bar", "-j", "ACCEPT", "-w"},
+       true},
+      {IpFamily::IPv4,
+       "",
+       "bar",
+       {"-A", "FORWARD", "-o", "bar", "-j", "ACCEPT", "-w"},
+       {"-D", "FORWARD", "-o", "bar", "-j", "ACCEPT", "-w"},
+       true},
+      {IpFamily::IPv4,
+       "foo",
+       "",
+       {"-A", "FORWARD", "-i", "foo", "-j", "ACCEPT", "-w"},
+       {"-D", "FORWARD", "-i", "foo", "-j", "ACCEPT", "-w"},
+       true},
+      {IpFamily::IPv6,
+       "foo",
+       "bar",
+       {"-A", "FORWARD", "-i", "foo", "-o", "bar", "-j", "ACCEPT", "-w"},
+       {"-D", "FORWARD", "-i", "foo", "-o", "bar", "-j", "ACCEPT", "-w"},
+       true},
+      {IpFamily::IPv6,
+       "",
+       "bar",
+       {"-A", "FORWARD", "-o", "bar", "-j", "ACCEPT", "-w"},
+       {"-D", "FORWARD", "-o", "bar", "-j", "ACCEPT", "-w"},
+       true},
+      {IpFamily::IPv6,
+       "foo",
+       "",
+       {"-A", "FORWARD", "-i", "foo", "-j", "ACCEPT", "-w"},
+       {"-D", "FORWARD", "-i", "foo", "-j", "ACCEPT", "-w"},
+       true},
+      {IpFamily::Dual,
+       "foo",
+       "bar",
+       {"-A", "FORWARD", "-i", "foo", "-o", "bar", "-j", "ACCEPT", "-w"},
+       {"-D", "FORWARD", "-i", "foo", "-o", "bar", "-j", "ACCEPT", "-w"},
+       true},
+      {IpFamily::Dual,
+       "",
+       "bar",
+       {"-A", "FORWARD", "-o", "bar", "-j", "ACCEPT", "-w"},
+       {"-D", "FORWARD", "-o", "bar", "-j", "ACCEPT", "-w"},
+       true},
+      {IpFamily::Dual,
+       "foo",
+       "",
+       {"-A", "FORWARD", "-i", "foo", "-j", "ACCEPT", "-w"},
+       {"-D", "FORWARD", "-i", "foo", "-j", "ACCEPT", "-w"},
+       true},
+  };
+
+  for (const auto& tt : testcases) {
+    MockProcessRunner runner;
+    MockFirewall firewall;
+    if (tt.result) {
+      if (tt.family & IpFamily::IPv4) {
+        EXPECT_CALL(runner,
+                    iptables(StrEq("filter"), tt.start_args, true, nullptr))
+            .WillOnce(Return(0));
+        EXPECT_CALL(runner,
+                    iptables(StrEq("filter"), tt.stop_args, true, nullptr))
+            .WillOnce(Return(0));
+      }
+      if (tt.family & IpFamily::IPv6) {
+        EXPECT_CALL(runner,
+                    ip6tables(StrEq("filter"), tt.start_args, true, nullptr))
+            .WillOnce(Return(0));
+        EXPECT_CALL(runner,
+                    ip6tables(StrEq("filter"), tt.stop_args, true, nullptr))
+            .WillOnce(Return(0));
+      }
+    }
+    Datapath datapath(&runner, &firewall);
+
+    EXPECT_EQ(tt.result, datapath.StartIpForwarding(tt.family, tt.iif, tt.oif));
+    EXPECT_EQ(tt.result, datapath.StopIpForwarding(tt.family, tt.iif, tt.oif));
+  }
+}
+
 TEST(DatapathTest, AddInboundIPv4DNAT) {
   MockProcessRunner runner;
   MockFirewall firewall;