overlord: ovl: implement SSL certificate check support

Instead of skipping SSL certificate verification, prompt the user to
accept the host's certificate fingerprint. The accepted certificate is
stored for future connections. If a later connection fails to verify
against the stored certificate, warn the user.

BUG=chromium:517520,chromium:585732
TEST=manually

Change-Id: I380abc962313c1f9e5cea9ddee00e27fede927ff
Reviewed-on: https://chromium-review.googlesource.com/326930
Commit-Ready: Wei-Ning Huang <wnhuang@chromium.org>
Tested-by: Wei-Ning Huang <wnhuang@chromium.org>
Reviewed-by: Wei-Han Chen <stimim@chromium.org>
diff --git a/py/tools/ovl.py b/py/tools/ovl.py
index 2174e1c..f5bb5db 100755
--- a/py/tools/ovl.py
+++ b/py/tools/ovl.py
@@ -1,4 +1,4 @@
-#!/usr/bin/python -u
+#!/usr/bin/env python
 # -*- coding: utf-8 -*-
 #
 # Copyright 2015 The Chromium OS Authors. All rights reserved.
@@ -22,6 +22,7 @@
 import select
 import signal
 import socket
+import ssl
 import StringIO
 import struct
 import subprocess
@@ -38,14 +39,8 @@
 from jsonrpclib.config import Config
 from ws4py.client import WebSocketBaseClient
 
-# Python version >= 2.7.9 enables SSL check by default, bypass it.
-try:
-  import ssl
-  # pylint: disable=W0212
-  ssl._create_default_https_context = ssl._create_unverified_context
-except Exception:
-  pass
 
+_CERT_DIR = os.path.expanduser('~/.config/ovl')
 
 _ESCAPE = '~'
 _BUFSIZ = 8192
@@ -54,6 +49,7 @@
 _OVERLORD_CLIENT_DAEMON_PORT = 4488
 _OVERLORD_CLIENT_DAEMON_RPC_ADDR = ('127.0.0.1', _OVERLORD_CLIENT_DAEMON_PORT)
 
+_CONNECT_TIMEOUT = 3
 _DEFAULT_HTTP_TIMEOUT = 30
 _LIST_CACHE_TIMEOUT = 2
 _DEFAULT_TERMINAL_WIDTH = 80
@@ -76,6 +72,20 @@
 # GET http://OVERLORD_SERVER:_OVERLORD_HTTP_PORT
 _OVERLORD_RESPONSE_KEYWORD = 'HTTP'
 
+_TLS_CERT_CHANGED_WARNING = """
+@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
+@ WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! @
+@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
+IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY!
+Someone could be eavesdropping on you right now (man-in-the-middle attack)!
+It is also possible that the SSL host certificate has just been changed.
+The fingerprint for the SSL host certificate sent by the remote host is
+
+%s
+
+Remove '%s' if you still want to proceed.
+SSL Certificate verification failed."""
+
 
 def GetVersionDigest():
   """Return the sha1sum of the current executing script."""
@@ -83,6 +93,32 @@
     return hashlib.sha1(f.read()).hexdigest()
 
 
+def GetTLSCertPath(host):  # Path of the PEM cert accepted for `host`.
+  return os.path.join(_CERT_DIR, '%s.cert' % host)
+
+
+def UrlOpen(state, url):
+  """Wrapper for urllib2.urlopen.
+
+  Builds the URL via MakeRequestUrl (scheme chosen from state.ssl), adds
+  HTTP basic auth headers, and passes state.ssl_context as the SSL context.
+  """
+  url = MakeRequestUrl(state, url)
+  request = urllib2.Request(url)
+  if state.username is not None and state.password is not None:
+    request.add_header(*BasicAuthHeader(state.username, state.password))
+  return urllib2.urlopen(request, timeout=_DEFAULT_HTTP_TIMEOUT,
+                         context=state.ssl_context)
+
+
+def GetTLSCertificateSHA1Fingerprint(cert_pem):
+  """Return the hex SHA-1 digest of a PEM certificate's DER encoding."""
+  # The stdlib helper validates the BEGIN/END boundaries (raising ValueError)
+  # instead of relying on hand-computed newline offsets.
+  cert_der = ssl.PEM_cert_to_DER_cert(cert_pem)
+  return hashlib.sha1(cert_der).hexdigest()
+
+
 def KillGraceful(pid, wait_secs=1):
   """Kill a process gracefully by first sending SIGTERM, wait for some time,
   then send SIGKILL to make sure it's killed."""
@@ -226,6 +262,7 @@
     self.host = None
     self.port = None
     self.ssl = False
+    self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
     self.ssh = False
     self.orig_host = None
     self.ssh_pid = None
@@ -287,22 +324,52 @@
   def GetPid(self):
     return os.getpid()
 
-  def _UrlOpen(self, url):
-    """Wrapper for urllib2.urlopen.
-
-    It selects correct HTTP scheme according to self._stat.ssl and add HTTP
-    basic auth headers.
-    """
-    url = MakeRequestUrl(self._state, url)
-    request = urllib2.Request(url)
-    if self._state.username is not None and self._state.password is not None:
-      request.add_header(*BasicAuthHeader(self._state.username,
-                                          self._state.password))
-    return urllib2.urlopen(request, timeout=_DEFAULT_HTTP_TIMEOUT)
-
   def _GetJSON(self, path):
     url = '%s:%d%s' % (self._state.host, self._state.port, path)
-    return json.loads(self._UrlOpen(url).read())
+    return json.loads(UrlOpen(self._state, url).read())
+
+  def _OverlordHTTPSEnabled(self):
+    """Determine if SSL is enabled on the Overlord HTTP server."""
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    try:
+      sock.settimeout(_CONNECT_TIMEOUT)
+      sock.connect((self._state.host, self._state.port))
+      sock.sendall('GET\r\n')  # An HTTP server's reply contains 'HTTP'.
+      return _OVERLORD_RESPONSE_KEYWORD not in sock.recv(16)
+    except Exception:
+      return False  # For whatever reason above failed, assume HTTP
+    finally:
+      sock.close()  # Release the probe socket explicitly, not via GC.
+
+  def _CheckTLSCertificate(self):
+    """Verify the server certificate against the stored one, if any.
+
+    Returns:
+      A tuple (verified, cert_loaded); a hostname mismatch counts as failure.
+    """
+    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
+    tls_context.verify_mode = ssl.CERT_REQUIRED
+    tls_context.check_hostname = True
+    cert_loaded = False
+
+    tls_cert_path = GetTLSCertPath(self._state.host)
+    if os.path.exists(tls_cert_path):
+      tls_context.load_verify_locations(tls_cert_path)
+      cert_loaded = True
+
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    try:
+      sock.settimeout(_CONNECT_TIMEOUT)
+      sock = tls_context.wrap_socket(sock, server_hostname=self._state.host)
+      sock.connect((self._state.host, self._state.port))
+    except (ssl.SSLError, ssl.CertificateError):  # py2.7: not an SSLError
+      return False, cert_loaded
+    finally:
+      sock.close()
+
+    # Save SSLContext for future use.
+    self._state.ssl_context = tls_context
+    return True, cert_loaded
 
   def Connect(self, host, port=_OVERLORD_HTTP_PORT, ssh_pid=None,
               username=None, password=None, orig_host=None):
@@ -315,14 +382,22 @@
     self._state.ssh_pid = ssh_pid
     self._state.selected_mid = None
 
+    tls_enabled = self._OverlordHTTPSEnabled()
+    if tls_enabled:
+      result, cert_loaded = self._CheckTLSCertificate()
+      if not result:
+        if cert_loaded:
+          return ('SSLCertificateChanged', ssl.get_server_certificate(
+              (self._state.host, self._state.port)))
+        else:
+          return ('SSLVerifyFailed', ssl.get_server_certificate(
+              (self._state.host, self._state.port)))
+
     try:
-      h = self._UrlOpen('%s:%d' % (host, port))
-      # Probably not an HTTP server, try HTTPS
-      if _OVERLORD_RESPONSE_KEYWORD not in h.read():
-        self._state.ssl = True
-        self._UrlOpen('%s:%d' % (host, port))
+      self._state.ssl = tls_enabled
+      UrlOpen(self._state, '%s:%d' % (host, port))
     except urllib2.HTTPError as e:
-      return (e.getcode(), str(e), e.read().strip())
+      return ('HTTPError', e.getcode(), str(e), e.read().strip())
     except Exception as e:
       return str(e)
     else:
@@ -360,9 +435,19 @@
     self._state.forwards = {}
 
 
-class TerminalWebSocketClient(WebSocketBaseClient):
-  def __init__(self, mid, *args, **kwargs):
-    super(TerminalWebSocketClient, self).__init__(*args, **kwargs)
+class SSLEnabledWebSocketBaseClient(WebSocketBaseClient):
+  """Requires peers to verify against the cert stored for `host`."""
+
+  def __init__(self, host, *args, **kwargs):
+    options = {'cert_reqs': ssl.CERT_REQUIRED,
+               'ca_certs': GetTLSCertPath(host)}
+    super(SSLEnabledWebSocketBaseClient, self).__init__(
+        ssl_options=options, *args, **kwargs)
+
+
+class TerminalWebSocketClient(SSLEnabledWebSocketBaseClient):
+  def __init__(self, host, mid, *args, **kwargs):
+    super(TerminalWebSocketClient, self).__init__(host, *args, **kwargs)
     self._mid = mid
     self._stdin_fd = sys.stdin.fileno()
     self._old_termios = None
@@ -428,15 +513,15 @@
       sys.stdout.flush()
 
 
-class ShellWebSocketClient(WebSocketBaseClient):
-  def __init__(self, output, *args, **kwargs):
+class ShellWebSocketClient(SSLEnabledWebSocketBaseClient):
+  def __init__(self, host, output, *args, **kwargs):
     """Constructor.
 
     Args:
       output: output file object.
     """
     self.output = output
-    super(ShellWebSocketClient, self).__init__(*args, **kwargs)
+    super(ShellWebSocketClient, self).__init__(host, *args, **kwargs)
 
   def handshake_ok(self):
     pass
@@ -467,9 +552,9 @@
       self.output.flush()
 
 
-class ForwarderWebSocketClient(WebSocketBaseClient):
-  def __init__(self, sock, *args, **kwargs):
-    super(ForwarderWebSocketClient, self).__init__(*args, **kwargs)
+class ForwarderWebSocketClient(SSLEnabledWebSocketBaseClient):
+  def __init__(self, host, sock, *args, **kwargs):
+    super(ForwarderWebSocketClient, self).__init__(host, *args, **kwargs)
     self._sock = sock
     self._stop = threading.Event()
 
@@ -616,13 +701,13 @@
     elif command == 'forward':
       self.Forward(args)
 
-  def _UrlOpen(self, url):
-    url = MakeRequestUrl(self._state, url)
-    request = urllib2.Request(url)
-    if self._state.username is not None and self._state.password is not None:
-      request.add_header(*BasicAuthHeader(self._state.username,
-                                          self._state.password))
-    return urllib2.urlopen(request, timeout=_DEFAULT_HTTP_TIMEOUT)
+  def _SaveTLSCertificate(self, host, cert_pem):
+    """Persist the accepted PEM certificate for `host` under _CERT_DIR."""
+    # Create the directory only if missing; don't swallow e.g. EACCES.
+    if not os.path.isdir(_CERT_DIR):
+      os.makedirs(_CERT_DIR)
+    with open(GetTLSCertPath(host), 'w') as f:
+      f.write(cert_pem)
 
   def _HTTPPostFile(self, url, filename, progress=None, user=None, passwd=None):
     """Perform HTTP POST and upload file to Overlord.
@@ -650,7 +735,7 @@
     if parse.scheme == 'http':
       h = httplib.HTTP(parse.netloc)
     else:
-      h = httplib.HTTPS(parse.netloc)
+      h = httplib.HTTPS(parse.netloc, context=self._state.ssl_context)
 
     post_path = url[url.index(parse.netloc) + len(parse.netloc):]
     h.putrequest('POST', post_path)
@@ -768,7 +853,7 @@
 
     scheme = 'ws%s://' % ('s' if self._state.ssl else '')
     sio = StringIO.StringIO()
-    ws = ShellWebSocketClient(sio,
+    ws = ShellWebSocketClient(self._state.host, sio,
                               scheme + '%s:%d/api/agent/shell/%s?command=%s' %
                               (self._state.host, self._state.port,
                                self._selected_mid, urllib2.quote(command)),
@@ -837,23 +922,42 @@
 
         ret = self._server.Connect(host, args.port, ssh_pid, args.user,
                                    args.passwd, orig_host)
-        # HTTPError
         if isinstance(ret, list):
-          code, except_str, body = ret
-          if code == 401:
-            print('connect: %s' % body)
-            prompt = True
-            if not username_provided or not password_provided:
+          if ret[0].startswith('SSL'):
+            cert_pem = ret[1]
+            fp = GetTLSCertificateSHA1Fingerprint(cert_pem)
+            fp_text = ':'.join([fp[i:i+2] for i in range(0, len(fp), 2)])
+
+          if ret[0] == 'SSLCertificateChanged':
+            print(_TLS_CERT_CHANGED_WARNING % (fp_text, GetTLSCertPath(host)))
+            return
+          elif ret[0] == 'SSLVerifyFailed':
+            print('Server fingerprint: %s' % fp_text)
+            response = raw_input('Do you want to continue? [Y/n] ')
+            if response.lower() in ['y', 'ye', 'yes']:
+              self._SaveTLSCertificate(host, cert_pem)
               continue
             else:
-              break
-          else:
-            logging.error('%s; %s', except_str, body)
+              print('connection aborted.')
+              return
+          elif ret[0] == 'HTTPError':
+            code, except_str, body = ret[1:]
+            if code == 401:
+              print('connect: %s' % body)
+              prompt = True
+              if not username_provided or not password_provided:
+                continue
+              else:
+                break
+            else:
+              logging.error('%s; %s', except_str, body)
 
         if ret is not True:
           print('can not connect to %s: %s' % (host, ret))
+        else:
+          print('connection to %s:%d established.' % (host, args.port))
       except Exception as e:
-        logging.exception(e)
+        logging.error(e)
       else:
         break
 
@@ -865,7 +969,7 @@
       time.sleep(1)
       self._server = OverlordClientDaemon.GetRPCServer()
       if self._server is not None:
-        print('* daemon started successfully *')
+        print('* daemon started successfully *\n')
 
   @Command('kill-server', 'kill overlord CLI client server')
   def KillServer(self):
@@ -932,13 +1036,13 @@
 
     scheme = 'ws%s://' % ('s' if self._state.ssl else '')
     if len(command) == 0:
-      ws = TerminalWebSocketClient(self._selected_mid,
+      ws = TerminalWebSocketClient(self._state.host, self._selected_mid,
                                    scheme + '%s:%d/api/agent/tty/%s' %
                                    (self._state.host, self._state.port,
                                     self._selected_mid), headers=headers)
     else:
       cmd = ' '.join(command)
-      ws = ShellWebSocketClient(sys.stdout,
+      ws = ShellWebSocketClient(self._state.host, sys.stdout,
                                 scheme + '%s:%d/api/agent/shell/%s?command=%s' %
                                 (self._state.host, self._state.port,
                                  self._selected_mid, urllib2.quote(cmd)),
@@ -981,7 +1085,7 @@
              (self._state.host, self._state.port, self._selected_mid, dst,
               mode))
       try:
-        self._UrlOpen(url + '&filename=%s' % src_base)
+        UrlOpen(self._state, url + '&filename=%s' % src_base)
       except urllib2.HTTPError as e:
         msg = json.loads(e.read()).get('error', None)
         raise RuntimeError('push: %s' % msg)
@@ -1058,7 +1162,7 @@
              (self._state.host, self._state.port, self._selected_mid,
               urllib2.quote(src)))
       try:
-        h = self._UrlOpen(url)
+        h = UrlOpen(self._state, url)
       except urllib2.HTTPError as e:
         msg = json.loads(e.read()).get('error', 'unkown error')
         raise RuntimeError('pull: %s' % msg)
@@ -1164,7 +1268,7 @@
 
       scheme = 'ws%s://' % ('s' if self._state.ssl else '')
       ws = ForwarderWebSocketClient(
-          conn,
+          self._state.host, conn,
           scheme + '%s:%d/api/agent/forward/%s?port=%d' %
           (self._state.host, self._state.port, self._selected_mid, remote),
           headers=headers)