Add support for rlz_brand_code and customization_id in the factory flow.

BUG=chrome-os-partner:28445
TEST=Unit tests, manually tested on device

Change-Id: I3a0ac1b0aa3ab06a914e3015aa93533ad62243c7
Reviewed-on: https://chromium-review.googlesource.com/197475
Reviewed-by: Jon Salz <jsalz@chromium.org>
Commit-Queue: Jon Salz <jsalz@chromium.org>
Tested-by: Jon Salz <jsalz@chromium.org>
diff --git a/py/gooftool/__init__.py b/py/gooftool/__init__.py
index 736a719..ccdf722 100644
--- a/py/gooftool/__init__.py
+++ b/py/gooftool/__init__.py
@@ -28,6 +28,8 @@
 from cros.factory.privacy import FilterDict
 from cros.factory.rule import Context
 from cros.factory.system import vpd
+from cros.factory.test import branding
+from cros.factory.tools.mount_partition import MountPartition
 from cros.factory.utils.process_utils import CheckOutput, GetLines
 from cros.factory.utils.string_utils import ParseDict
 
@@ -404,6 +406,62 @@
     if any(tpm_status[k] != v for k, v in tpm_cleared_status.iteritems()):
       raise Error, 'TPM is not cleared.'
 
+  def VerifyBranding(self):
+    """Verify that branding fields are properly set.
+
+    customization_id is optional, but if it is present in the RO VPD it must
+    match branding.CUSTOMIZATION_ID_REGEXP.  rlz_brand_code is read from the
+    RO VPD if present there, and otherwise from branding.BRAND_CODE_PATH in
+    the release rootfs; it must match branding.RLZ_BRAND_CODE_REGEXP.
+
+    Returns:
+      A dictionary containing rlz_brand_code and customization_id fields,
+      for testing.
+
+    Raises:
+      ValueError: If customization_id or rlz_brand_code is malformed, or if
+        rlz_brand_code is present in neither the RO VPD nor the release
+        rootfs.
+    """
+    ro_vpd = vpd.ro.GetAll()
+
+    customization_id = ro_vpd.get('customization_id')
+    logging.info('RO VPD customization_id: %r', customization_id)
+    if customization_id is not None:
+      if not branding.CUSTOMIZATION_ID_REGEXP.match(customization_id):
+        raise ValueError('Bad format for customization_id %r in RO VPD '
+                         '(expected it to match regexp %r)' % (
+            customization_id, branding.CUSTOMIZATION_ID_REGEXP.pattern))
+
+    rlz_brand_code = ro_vpd.get('rlz_brand_code')
+
+    logging.info('RO VPD rlz_brand_code: %r', rlz_brand_code)
+    if rlz_brand_code is None:
+      # Not in the RO VPD, so it must be present as BRAND_CODE_PATH in the
+      # release rootfs.
+      with MountPartition(
+          self._util.GetReleaseRootPartitionPath()) as mount_path:
+        path = os.path.join(mount_path, branding.BRAND_CODE_PATH.lstrip('/'))
+        if not os.path.exists(path):
+          raise ValueError('rlz_brand_code is not present in RO VPD, and %s '
+                           'does not exist in release rootfs' % (
+              branding.BRAND_CODE_PATH))
+        with open(path) as f:
+          rlz_brand_code = f.read().strip()
+          logging.info('rlz_brand_code from rootfs: %r', rlz_brand_code)
+      rlz_brand_code_source = 'release_rootfs'
+    else:
+      rlz_brand_code_source = 'RO VPD'
+
+    if not branding.RLZ_BRAND_CODE_REGEXP.match(rlz_brand_code):
+      raise ValueError('Bad format for rlz_brand_code %r in %s '
+                       '(expected it to match regexp %r)' % (
+          rlz_brand_code, rlz_brand_code_source,
+          branding.RLZ_BRAND_CODE_REGEXP.pattern))
+
+    return dict(rlz_brand_code=rlz_brand_code,
+                customization_id=customization_id)
+
   def ClearGBBFlags(self):
     """Zero out the GBB flags, in preparation for transition to release state.
 
diff --git a/py/gooftool/gooftool.py b/py/gooftool/gooftool.py
index 8fbf247..971197c 100755
--- a/py/gooftool/gooftool.py
+++ b/py/gooftool/gooftool.py
@@ -544,6 +544,18 @@
     event_log.Log('switch_dev', type='virtual switch')
 
 
+@Command('verify_branding')
+def VerifyBranding(options):
+  """Verify that branding fields are properly set.
+
+  customization_id, if set in the RO VPD, must be of the correct format.
+
+  rlz_brand_code must be set either in the RO VPD or in the release rootfs
+  (at branding.BRAND_CODE_PATH), and must be of the correct format.
+  """
+  return GetGooftool(options).VerifyBranding()
+
+
 @Command('write_protect')
 def EnableFwWp(options):  # pylint: disable=W0613
   """Enable then verify firmware write protection."""
@@ -654,6 +666,7 @@
   VerifyKeys(options)
   VerifyRootFs(options)
   VerifyTPM(options)
+  VerifyBranding(options)
 
 @Command('untar_stateful_files')
 def UntarStatefulFiles(dummy_options):
diff --git a/py/gooftool/gooftool_unittest.py b/py/gooftool/gooftool_unittest.py
index ca850da..22f0bb8 100755
--- a/py/gooftool/gooftool_unittest.py
+++ b/py/gooftool/gooftool_unittest.py
@@ -13,6 +13,7 @@
 import unittest
 
 from collections import namedtuple
+from contextlib import contextmanager
 from tempfile import NamedTemporaryFile
 
 import factory_common  # pylint: disable=W0611
@@ -27,6 +28,9 @@
 from cros.factory.hwdb.hwid_tool import ProbeResults  # pylint: disable=E0611
 from cros.factory.gooftool import Mismatch
 from cros.factory.gooftool import ProbedComponentResult
+from cros.factory.system import vpd
+from cros.factory.test import branding
+from cros.factory.utils import file_utils
 from cros.factory.utils.process_utils import CheckOutput
 
 _TEST_DATA_PATH = os.path.join(os.path.dirname(__file__), 'testdata')
@@ -356,6 +360,80 @@
     self._gooftool.VerifyWPSwitch()
     self.assertRaises(Error, self._gooftool.VerifyWPSwitch)
 
+  def _SetupBrandingMocks(self, ro_vpd, fake_rootfs_path):
+    """Set up mocks for VerifyBranding tests.
+
+    Args:
+      ro_vpd: The dictionary to return from vpd.ro.GetAll().
+      fake_rootfs_path: A path at which we pretend to mount the release rootfs,
+        or None if the test expects the release rootfs not to be mounted.
+    """
+    # Stand-in for the context manager returned by MountPartition: "mounting"
+    # simply yields the given path.
+    @contextmanager
+    def MockPartition(path):
+      yield path
+
+    self.mox.StubOutWithMock(vpd.ro, 'GetAll')
+    self.mox.StubOutWithMock(gooftool, 'MountPartition')
+
+    vpd.ro.GetAll().AndReturn(ro_vpd)
+    if fake_rootfs_path:
+      # Pretend that '/dev/rel' is the release rootfs path, and that mounting
+      # it yields fake_rootfs_path.
+      self._gooftool._util.GetReleaseRootPartitionPath().AndReturn('/dev/rel')
+      gooftool.MountPartition('/dev/rel').AndReturn(
+          MockPartition(fake_rootfs_path))
+
+  def testVerifyBranding_NoBrandCode(self):
+    self._SetupBrandingMocks({}, '/doesntexist')
+    self.mox.ReplayAll()
+    # Should fail, since rlz_brand_code isn't present anywhere.
+    self.assertRaisesRegexp(ValueError, 'rlz_brand_code is not present',
+                            self._gooftool.VerifyBranding)
+
+  def testVerifyBranding_AllInVPD(self):
+    self._SetupBrandingMocks(
+        dict(rlz_brand_code='ABCD', customization_id='FOO'), None)
+    self.mox.ReplayAll()
+    self.assertEquals(dict(rlz_brand_code='ABCD', customization_id='FOO'),
+                      self._gooftool.VerifyBranding())
+
+  def testVerifyBranding_BrandCodeInVPD(self):
+    # customization_id is optional; when absent it is returned as None.
+    self._SetupBrandingMocks(dict(rlz_brand_code='ABCD'), None)
+    self.mox.ReplayAll()
+    self.assertEquals(dict(rlz_brand_code='ABCD', customization_id=None),
+                      self._gooftool.VerifyBranding())
+
+  def testVerifyBranding_BrandCodeInRootFS(self):
+    with file_utils.TempDirectory() as tmp:
+      # Create the BRAND_CODE_PATH file within the fake mounted rootfs.
+      rlz_brand_code_path = os.path.join(
+          tmp, branding.BRAND_CODE_PATH.lstrip('/'))
+      file_utils.TryMakeDirs(os.path.dirname(rlz_brand_code_path))
+      with open(rlz_brand_code_path, 'w') as f:
+        f.write('ABCD')
+
+      self._SetupBrandingMocks({}, tmp)
+      self.mox.ReplayAll()
+      self.assertEquals(dict(rlz_brand_code='ABCD', customization_id=None),
+                        self._gooftool.VerifyBranding())
+
+  def testVerifyBranding_BadBrandCode(self):
+    self._SetupBrandingMocks(dict(rlz_brand_code='ABCDx',
+                                  customization_id='FOO'), None)
+    self.mox.ReplayAll()
+    self.assertRaisesRegexp(ValueError, 'Bad format for rlz_brand_code',
+                            self._gooftool.VerifyBranding)
+
+  def testVerifyBranding_BadCustomizationId(self):
+    self._SetupBrandingMocks(dict(rlz_brand_code='ABCD',
+                                  customization_id='FOOx'), None)
+    self.mox.ReplayAll()
+    self.assertRaisesRegexp(ValueError, 'Bad format for customization_id',
+                            self._gooftool.VerifyBranding)
+
   def testCheckDevSwitchForDisabling(self):
     # 1st call: virtual switch
     self._gooftool._util.GetVBSharedDataFlags().AndReturn(0x400)