api/validate: Add require_each validator.
The require_each validator allows specifying fields of a repeated
message field that must be set in every element.
BUG=chromium:1130818
TEST=run_pytest
Cq-Depend: chromium:2426691
Change-Id: I42da9fa19fd13b700715734cf9c92d7c03b49556
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/chromite/+/2424785
Commit-Queue: Alex Klein <saklein@chromium.org>
Tested-by: Alex Klein <saklein@chromium.org>
Reviewed-by: Michael Mortensen <mmortensen@google.com>
diff --git a/api/validate_unittest.py b/api/validate_unittest.py
index cce8674..e42b740 100644
--- a/api/validate_unittest.py
+++ b/api/validate_unittest.py
@@ -12,6 +12,7 @@
from chromite.api import api_config
from chromite.api import validate
+from chromite.api.gen.chromite.api import build_api_test_pb2
from chromite.api.gen.chromiumos import common_pb2
from chromite.lib import cros_build_lib
from chromite.lib import cros_test_lib
@@ -153,6 +154,127 @@
impl(common_pb2.Chroot(), None, self.no_validate_config)
+class RequireEachTest(cros_test_lib.TestCase, api_config.ApiConfigMixin):
+  """Tests for the require_each validator."""
+
+  def _multi_field_message(self, msg_id=None, name=None, flag=None):
+    """Build a MultiFieldMessage with only the given fields set."""
+    msg = build_api_test_pb2.MultiFieldMessage()
+    if msg_id is not None:
+      msg.id = int(msg_id)
+    if name is not None:
+      msg.name = str(name)
+    if flag is not None:
+      msg.flag = bool(flag)
+    return msg
+
+  def _request(self, messages=None, count=0):
+    """Build a request with |messages|, or |count| empty messages if None."""
+    if messages is None:
+      messages = [self._multi_field_message() for _ in range(count)]
+
+    request = build_api_test_pb2.TestRequestMessage()
+    # extend() copies each message into the repeated field (add + CopyFrom).
+    request.messages.extend(messages)
+
+    return request
+
+  def test_invalid_field(self):
+    """Test validator fails when given an invalid field."""
+
+    @validate.require_each('does.not', ['exist'])
+    def impl(_input_proto, _output_proto, _config):
+      self.fail('Incorrectly allowed method to execute.')
+
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._request(), None, self.api_config)
+
+  def test_invalid_call_no_subfields(self):
+    """Test validator fails when given no subfields."""
+
+    with self.assertRaises(AssertionError):
+      @validate.require_each('does.not', [])
+      def _(_input_proto, _output_proto, _config):
+        pass
+
+  def test_invalid_call_invalid_subfields(self):
+    """Test validator fails when given subfields incorrectly."""
+
+    with self.assertRaises(AssertionError):
+      @validate.require_each('does.not', 'exist')
+      def _(_input_proto, _output_proto, _config):
+        pass
+
+  def test_not_set(self):
+    """Test validator fails when given an unset value."""
+
+    @validate.require_each('messages', ['id'])
+    def impl(_input_proto, _output_proto, _config):
+      self.fail('Incorrectly allowed method to execute.')
+
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._request(count=2), None, self.api_config)
+
+  def test_no_elements_success(self):
+    """Test validator passes when given no messages in the repeated field."""
+
+    @validate.require_each('messages', ['id'])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    impl(self._request(), None, self.api_config)
+
+  def test_no_elements_failure(self):
+    """Test validator fails when given no messages in the repeated field."""
+
+    @validate.require_each('messages', ['id'], allow_empty=False)
+    def impl(_input_proto, _output_proto, _config):
+      self.fail('Incorrectly allowed method to execute.')
+
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._request(), None, self.api_config)
+
+  def test_set(self):
+    """Test validator passes when given set values."""
+
+    @validate.require_each('messages', ['id'])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    messages = [self._multi_field_message(msg_id=i) for i in range(1, 5)]
+    impl(self._request(messages=messages), None, self.api_config)
+
+  def test_one_set_fails(self):
+    """Test validator fails when only some of the required fields are set."""
+
+    @validate.require_each('messages', ['id', 'name'])
+    def impl(_input_proto, _output_proto, _config):
+      self.fail('Incorrectly allowed method to execute.')
+
+    messages = [self._multi_field_message(msg_id=i) for i in range(1, 5)]
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._request(messages=messages), None, self.api_config)
+
+  def test_multi_set(self):
+    """Test validator passes when all required fields are set."""
+
+    @validate.require_each('messages', ['id', 'name'])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    messages = [self._multi_field_message(msg_id=i, name=i)
+                for i in range(1, 5)]
+    impl(self._request(messages=messages), None, self.api_config)
+
+  def test_skip_validation(self):
+    """Test skipping validation case."""
+    @validate.require_each('messages', ['id'], allow_empty=False)
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    impl(self._request(), None, self.no_validate_config)
+
+
class ValidateOnlyTest(cros_test_lib.TestCase, api_config.ApiConfigMixin):
"""validate_only decorator tests."""