api/validate: Add require_each validator.

The require_each validator allows specifying fields that must be set on
every element of a repeated message field. An empty repeated field passes
by default; pass allow_empty=False to reject it.

BUG=chromium:1130818
TEST=run_pytest

Cq-Depend: chromium:2426691
Change-Id: I42da9fa19fd13b700715734cf9c92d7c03b49556
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/chromite/+/2424785
Commit-Queue: Alex Klein <saklein@chromium.org>
Tested-by: Alex Klein <saklein@chromium.org>
Reviewed-by: Michael Mortensen <mmortensen@google.com>
diff --git a/api/validate_unittest.py b/api/validate_unittest.py
index cce8674..e42b740 100644
--- a/api/validate_unittest.py
+++ b/api/validate_unittest.py
@@ -12,6 +12,7 @@
 
 from chromite.api import api_config
 from chromite.api import validate
+from chromite.api.gen.chromite.api import build_api_test_pb2
 from chromite.api.gen.chromiumos import common_pb2
 from chromite.lib import cros_build_lib
 from chromite.lib import cros_test_lib
@@ -153,6 +154,127 @@
     impl(common_pb2.Chroot(), None, self.no_validate_config)
 
 
+class RequireEachTest(cros_test_lib.TestCase, api_config.ApiConfigMixin):
+  """Tests for the require_each validator."""
+
+  def _multi_field_message(self, msg_id=None, name=None, flag=None):
+    msg = build_api_test_pb2.MultiFieldMessage()
+    if msg_id is not None:
+      msg.id = int(msg_id)
+    if name is not None:
+      msg.name = str(name)
+    if flag is not None:
+      msg.flag = bool(flag)
+    return msg
+
+  def _request(self, messages=None, count=0):
+    """Build a request from |messages|, or |count| empty messages."""
+    if messages is None:
+      messages = [self._multi_field_message() for _ in range(count)]
+
+    request = build_api_test_pb2.TestRequestMessage()
+    for message in messages:
+      msg = request.messages.add()
+      msg.CopyFrom(message)
+
+    return request
+
+  def test_invalid_field(self):
+    """Test validator fails when given an invalid field."""
+
+    @validate.require_each('does.not', ['exist'])
+    def impl(_input_proto, _output_proto, _config):
+      self.fail('Incorrectly allowed method to execute.')
+
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._request(), None, self.api_config)
+
+  def test_invalid_call_no_subfields(self):
+    """Test validator fails when given no subfields."""
+
+    with self.assertRaises(AssertionError):
+      @validate.require_each('does.not', [])
+      def _(_input_proto, _output_proto, _config):
+        pass
+
+  def test_invalid_call_invalid_subfields(self):
+    """Test validator fails when given subfields incorrectly."""
+
+    with self.assertRaises(AssertionError):
+      @validate.require_each('does.not', 'exist')
+      def _(_input_proto, _output_proto, _config):
+        pass
+
+  def test_not_set(self):
+    """Test validator fails when given an unset value."""
+
+    @validate.require_each('messages', ['id'])
+    def impl(_input_proto, _output_proto, _config):
+      self.fail('Incorrectly allowed method to execute.')
+
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._request(count=2), None, self.api_config)
+
+  def test_no_elements_success(self):
+    """Test validator passes when given no messages in the repeated field."""
+
+    @validate.require_each('messages', ['id'])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    impl(self._request(), None, self.api_config)
+
+  def test_no_elements_failure(self):
+    """Test validator fails when given no messages in the repeated field."""
+
+    @validate.require_each('messages', ['id'], allow_empty=False)
+    def impl(_input_proto, _output_proto, _config):
+      self.fail('Incorrectly allowed method to execute.')
+
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._request(), None, self.api_config)
+
+  def test_set(self):
+    """Test validator passes when given set values."""
+
+    @validate.require_each('messages', ['id'])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    messages = [self._multi_field_message(msg_id=i) for i in range(1, 5)]
+    impl(self._request(messages=messages), None, self.api_config)
+
+  def test_one_set_fails(self):
+    """Test validator fails when only some required fields are set."""
+
+    @validate.require_each('messages', ['id', 'name'])
+    def impl(_input_proto, _output_proto, _config):
+      self.fail('Incorrectly allowed method to execute.')
+
+    messages = [self._multi_field_message(msg_id=i) for i in range(1, 5)]
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._request(messages=messages), None, self.api_config)
+
+  def test_multi_set(self):
+    """Test validator passes when all required fields are set."""
+
+    @validate.require_each('messages', ['id', 'name'])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    messages = [self._multi_field_message(msg_id=i, name=i)
+                for i in range(1, 5)]
+    impl(self._request(messages=messages), None, self.api_config)
+
+  def test_skip_validation(self):
+    """Test validation is skipped with the no-validate config."""
+    @validate.require_each('messages', ['id'], allow_empty=False)
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    impl(self._request(), None, self.no_validate_config)
+
+
 class ValidateOnlyTest(cros_test_lib.TestCase, api_config.ApiConfigMixin):
   """validate_only decorator tests."""