api/validate: add each_in validator

Add a new validator that checks that each entry of a repeated field, or
a given subfield of each entry, has one of a set of allowed values.

BUG=None
TEST=./run_pytest

Change-Id: I2c1645c64c2e69253d9a5b11c3dd50809bbae3b9
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/chromite/+/2595815
Tested-by: Alex Klein <saklein@chromium.org>
Commit-Queue: Alex Klein <saklein@chromium.org>
Reviewed-by: Michael Mortensen <mmortensen@google.com>
diff --git a/api/validate_unittest.py b/api/validate_unittest.py
index 00e57ab..fdc6561 100644
--- a/api/validate_unittest.py
+++ b/api/validate_unittest.py
@@ -18,10 +18,16 @@
 from chromite.lib import cros_test_lib
 from chromite.lib import osutils
 
-
 assert sys.version_info >= (3, 6), 'This module requires Python 3.6+'
 
 
+# These tests exercise each validator by decorating a local `impl` function
+# that takes the same parameters as a controller function. The validators
+# don't care that `impl` isn't a real controller; it only has to look like
+# one. Because of that, any message can be passed as the input; it doesn't
+# have to be one of the Request messages a controller would normally
+# receive. It only needs to have the field that the validator under test
+# is checking.
 class ExistsTest(cros_test_lib.TempDirTestCase, api_config.ApiConfigMixin):
   """Tests for the exists validator."""
 
@@ -100,6 +106,186 @@
     impl(common_pb2.Chroot(), None, self.no_validate_config)
 
 
+class EachInTest(cros_test_lib.TestCase, api_config.ApiConfigMixin):
+  """Tests for the each_in validator."""
+
+  # Easier access to the enum values.
+  ENUM_FOO = build_api_test_pb2.TEST_ENUM_FOO
+  ENUM_BAR = build_api_test_pb2.TEST_ENUM_BAR
+  ENUM_BAZ = build_api_test_pb2.TEST_ENUM_BAZ
+
+  # pylint: disable=docstring-misnamed-args
+  def _message_request(self, *messages):
+    """Build a request instance, filling out the messages field.
+
+    Args:
+      messages: Per-message field values as lists, in (id, name, flag)
+        order; only as many fields as there are values get set. For example,
+        _message_request([1], [2]) creates two messages with only ids set,
+        and _message_request([1, 'name']) creates one with the id and name
+        set, but not the flag. Extra values past the flag are ignored.
+
+    Returns:
+      The populated TestRequestMessage.
+    """
+    request = build_api_test_pb2.TestRequestMessage()
+    for message in messages:
+      msg = request.messages.add()
+      # zip() stops at the shorter input, so omitted values stay unset.
+      for field, value in zip(('id', 'name', 'flag'), message):
+        setattr(msg, field, value)
+
+    return request
+
+  def _enums_request(self, *enum_values):
+    """Build a request instance, setting the test_enums field."""
+    request = build_api_test_pb2.TestRequestMessage()
+    request.test_enums.extend(enum_values)
+
+    return request
+
+  def _numbers_request(self, *numbers):
+    """Build a request instance, setting the numbers field."""
+    request = build_api_test_pb2.TestRequestMessage()
+    request.numbers.extend(numbers)
+
+    return request
+
+  def test_message_in(self):
+    """Test valid values."""
+
+    @validate.each_in('messages', 'name', ['foo', 'bar'])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    impl(self._message_request([1, 'foo']), None, self.api_config)
+    impl(self._message_request([1, 'foo'], [2, 'bar']), None, self.api_config)
+
+  def test_enum_in(self):
+    """Test valid enum values."""
+
+    @validate.each_in('test_enums', None, [self.ENUM_FOO, self.ENUM_BAR])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    impl(self._enums_request(self.ENUM_FOO), None, self.api_config)
+    impl(self._enums_request(self.ENUM_FOO, self.ENUM_BAR), None,
+         self.api_config)
+
+  def test_scalar_in(self):
+    """Test valid scalar values."""
+
+    @validate.each_in('numbers', None, [1, 2])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    impl(self._numbers_request(1), None, self.api_config)
+    impl(self._numbers_request(1, 2), None, self.api_config)
+
+  def test_message_not_in(self):
+    """Test an invalid value."""
+
+    @validate.each_in('messages', 'name', ['foo', 'bar'])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    # Should be failing on the invalid value.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._message_request([1, 'invalid']), None, self.api_config)
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._message_request([1, 'invalid'], [2, 'invalid']), None,
+           self.api_config)
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._message_request([1, 'foo'], [2, 'invalid']), None,
+           self.api_config)
+
+  def test_enum_not_in(self):
+    """Test an invalid enum value."""
+
+    @validate.each_in('test_enums', None, [self.ENUM_FOO, self.ENUM_BAR])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    # Only invalid values.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._enums_request(self.ENUM_BAZ), None, self.api_config)
+    # Mixed valid/invalid values.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._enums_request(self.ENUM_FOO, self.ENUM_BAZ), None,
+           self.api_config)
+
+  def test_scalar_not_in(self):
+    """Test invalid scalar value."""
+
+    @validate.each_in('numbers', None, [1, 2])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    # Only invalid values.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._numbers_request(3), None, self.api_config)
+    # Mixed valid/invalid values.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._numbers_request(1, 2, 3), None, self.api_config)
+
+  def test_not_set(self):
+    """Test an unset value."""
+
+    @validate.each_in('messages', 'name', ['foo', 'bar'])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    # Should be failing without a value set.
+    # No entries in the field.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._message_request(), None, self.api_config)
+    # No value set on lone entry.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._message_request([1]), None, self.api_config)
+    # No value set on multiple entries.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._message_request([1], [2]), None, self.api_config)
+    # Some valid and some invalid entries.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._message_request([1, 'foo'], [2]), None, self.api_config)
+
+  def test_optional(self):
+    """Test optional argument."""
+
+    @validate.each_in('messages', 'name', ['foo', 'bar'], optional=True)
+    @validate.each_in('test_enums', None, [self.ENUM_FOO, self.ENUM_BAR],
+                      optional=True)
+    @validate.each_in('numbers', None, [1, 2], optional=True)
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    # No entries in the field succeeds.
+    impl(self._message_request(), None, self.api_config)
+
+    # Still fails when entries exist but the value is unset.
+    # No value set on lone entry.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._message_request([1]), None, self.api_config)
+    # No value set on multiple entries.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._message_request([1], [2]), None, self.api_config)
+    # Some valid and some invalid entries.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._message_request([1, 'foo'], [2]), None, self.api_config)
+
+  def test_skip_validation(self):
+    """Test skipping validation case."""
+
+    @validate.each_in('messages', 'name', ['foo', 'bar'])
+    @validate.each_in('test_enums', None, [self.ENUM_FOO, self.ENUM_BAR])
+    @validate.each_in('numbers', None, [1, 2])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    # This would otherwise raise an error for multiple invalid fields.
+    impl(self._message_request([1, 'invalid']), None, self.no_validate_config)
+
+
 class RequireTest(cros_test_lib.TestCase, api_config.ApiConfigMixin):
   """Tests for the require validator."""