ovl: Support specifying certificate path

In the past, Overlord used a so-called self-signed certificate. When ovl
connected to it, ovl could download the certificate and use it for
verification, because the certificate was signed by itself.

However, since https://crrev.com/c/2467736, we create our own root CA and
use it to sign the certificate for Overlord. Verification can then never
pass unless we can specify the correct certificate. Hence, `ovl connect`
should accept the path to the correct root CA certificate and use it to
verify the connection. The interactive prompt that offered to trust an
unverified certificate is removed.

BUG=b:170172074
TEST=Specify the correct root CA; `ovl connect` succeeds
TEST=Specify a wrong certificate; `ovl connect` fails

Change-Id: I271354141514473b05e55e7ec19ff2453ef78cf0
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/factory/+/2497638
Reviewed-by: Cheng Yueh <cyueh@chromium.org>
Commit-Queue: Yilin Yang (kerker) <kerker@chromium.org>
Tested-by: Yilin Yang (kerker) <kerker@chromium.org>
diff --git a/py/tools/ovl.py b/py/tools/ovl.py
index bf02d48..717de43 100755
--- a/py/tools/ovl.py
+++ b/py/tools/ovl.py
@@ -17,6 +17,7 @@
 import os
 import re
 import select
+import shutil
 import signal
 import socket
 import ssl
@@ -74,28 +75,11 @@
 @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
 @ WARNING: REMOTE HOST VERIFICATION HAS FAILED! @
 @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
-IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY!
-Someone could be eavesdropping on you right now (man-in-the-middle attack)!
-It is also possible that the server is using a self-signed certificate.
-The fingerprint for the TLS host certificate sent by the remote host is
+Failed Reason: %s.
 
-%s
-
-Do you want to trust this certificate and proceed? [Y/n] """
-
-_TLS_CERT_CHANGED_WARNING = """
-@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
-@ WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! @
-@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
-IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY!
-Someone could be eavesdropping on you right now (man-in-the-middle attack)!
-It is also possible that the TLS host certificate has just been changed.
-The fingerprint for the TLS host certificate sent by the remote host is
-
-%s
-
-Remove '%s' if you still want to proceed.
-SSL Certificate verification failed."""
+Please use -c option to specify path of root CA certificate.
+This root CA certificate should be the one that signed the certificate used by
+overlord server."""
 
 
 def GetVersionDigest():
@@ -131,14 +115,6 @@
                                 context=state.ssl_context)
 
 
-def GetTLSCertificateSHA1Fingerprint(cert_pem):
-  beg = cert_pem.index('\n')
-  end = cert_pem.rindex('\n', 0, len(cert_pem) - 2)
-  cert_pem = cert_pem[beg:end]  # Remove BEGIN/END CERTIFICATE boundary
-  cert_der = base64.b64decode(cert_pem)
-  return hashlib.sha1(cert_der).hexdigest()
-
-
 def KillGraceful(pid, wait_secs=1):
   """Kill a process gracefully by first sending SIGTERM, wait for some time,
   then send SIGKILL to make sure it's killed."""
@@ -427,19 +403,18 @@
 
     tls_enabled = self._TLSEnabled()
     if tls_enabled:
-      result = self._CheckTLSCertificate(check_hostname)
-      if not result:
-        if self._state.ssl_self_signed:
-          return ('SSLCertificateChanged', ssl.get_server_certificate(
-              (self._state.host, self._state.port)))
-        return ('SSLVerifyFailed', ssl.get_server_certificate(
-            (self._state.host, self._state.port)))
+      if not os.path.exists(os.path.join(_CERT_DIR, '%s.cert' % host)):
+        return 'SSLCertificateNotExisted'
+
+      if not self._CheckTLSCertificate(check_hostname):
+        return 'SSLVerifyFailed'
 
     try:
       self._state.ssl = tls_enabled
       UrlOpen(self._state, '%s:%d' % (host, port))
     except urllib.error.HTTPError as e:
-      return ('HTTPError', e.getcode(), str(e), e.read().strip())
+      return ('HTTPError', e.getcode(), str(e),
+              e.read().strip().decode('utf-8'))
     except Exception as e:
       return str(e)
     else:
@@ -767,14 +742,6 @@
     elif command == 'forward':
       self.Forward(args)
 
-  def _SaveTLSCertificate(self, host, cert_pem):
-    try:
-      os.makedirs(_CERT_DIR)
-    except Exception:
-      pass
-    with open(GetTLSCertPath(host), 'w') as f:
-      f.write(cert_pem)
-
   def _HTTPPostFile(self, url, filename, progress=None, user=None, passwd=None):
     """Perform HTTP POST and upload file to Overlord.
 
@@ -951,27 +918,32 @@
   @Command('connect', 'connect to Overlord server', [
       Arg('host', metavar='HOST', type=str, default='localhost',
           help='Overlord hostname/IP'),
-      Arg('port', metavar='PORT', type=int,
-          default=_OVERLORD_HTTP_PORT, help='Overlord port'),
+      Arg('port', metavar='PORT', type=int, default=_OVERLORD_HTTP_PORT,
+          help='Overlord port'),
       Arg('-f', '--forward', dest='ssh_forward', default=False,
-          action='store_true',
-          help='connect with SSH forwarding to the host'),
-      Arg('-p', '--ssh-port', dest='ssh_port', default=22,
-          type=int, help='SSH server port for SSH forwarding'),
-      Arg('-l', '--ssh-login', dest='ssh_login', default='',
-          type=str, help='SSH server login name for SSH forwarding'),
-      Arg('-u', '--user', dest='user', default=None,
-          type=str, help='Overlord HTTP auth username'),
+          action='store_true', help='connect with SSH forwarding to the host'),
+      Arg('-p', '--ssh-port', dest='ssh_port', default=22, type=int,
+          help='SSH server port for SSH forwarding'),
+      Arg('-l', '--ssh-login', dest='ssh_login', default='', type=str,
+          help='SSH server login name for SSH forwarding'),
+      Arg('-u', '--user', dest='user', default=None, type=str,
+          help='Overlord HTTP auth username'),
       Arg('-w', '--passwd', dest='passwd', default=None, type=str,
           help='Overlord HTTP auth password'),
-      Arg('-i', '--no-check-hostname', dest='check_hostname',
-          default=True, action='store_false',
-          help='Ignore SSL cert hostname check')])
+      Arg('-c', '--root-CA', dest='cert', default=None, type=str,
+          help='Path to root CA certificate, only assign at the first time'),
+      Arg('-i', '--no-check-hostname', dest='check_hostname', default=True,
+          action='store_false', help='Ignore SSL cert hostname check')
+  ])
   def Connect(self, args):
     ssh_pid = None
     host = args.host
     orig_host = args.host
 
+    if args.cert and os.path.exists(args.cert):
+      os.makedirs(_CERT_DIR, exist_ok=True)
+      shutil.copy(args.cert, os.path.join(_CERT_DIR, '%s.cert' % host))
+
     if args.ssh_forward:
       # Kill previous SSH tunnel
       self.KillSSHTunnel()
@@ -995,24 +967,6 @@
                                    args.passwd, orig_host,
                                    args.check_hostname)
         if isinstance(ret, list):
-          if ret[0].startswith('SSL'):
-            cert_pem = ret[1]
-            fp = GetTLSCertificateSHA1Fingerprint(cert_pem)
-            fp_text = ':'.join([fp[i:i+2] for i in range(0, len(fp), 2)])
-
-          if ret[0] == 'SSLCertificateChanged':
-            print(_TLS_CERT_CHANGED_WARNING % (fp_text, GetTLSCertPath(host)))
-            return
-          if ret[0] == 'SSLVerifyFailed':
-            print(_TLS_CERT_FAILED_WARNING % (fp_text), end='')
-            response = input()
-            if response.lower() in ['y', 'ye', 'yes']:
-              self._SaveTLSCertificate(host, cert_pem)
-              print('TLS host Certificate trusted, you will not be prompted '
-                    'next time.\n')
-              continue
-            print('connection aborted.')
-            return
           if ret[0] == 'HTTPError':
             code, except_str, body = ret[1:]
             if code == 401:
@@ -1023,12 +977,15 @@
               break
             logging.error('%s; %s', except_str, body)
 
+        if ret in ('SSLCertificateNotExisted', 'SSLVerifyFailed'):
+          print(_TLS_CERT_FAILED_WARNING % ret)
+          return
         if ret is not True:
           print('can not connect to %s: %s' % (host, ret))
         else:
           print('connection to %s:%d established.' % (host, args.port))
       except Exception as e:
-        logging.error(e)
+        logging.exception(e)
       else:
         break