overlord: add TLS support for overlord-ghost link

The overlord server may be set up outside the factory, so the connection
between overlord and ghost must be encrypted to ensure security.

BUG=chromium:585732
TEST=`go/src/overlord/test/overlord_e2e_unittest.py`

Change-Id: I29f9153339774c2b64f3e8f8dcf5ad4fa572e5d1
Reviewed-on: https://chromium-review.googlesource.com/326100
Commit-Ready: Wei-Ning Huang <wnhuang@chromium.org>
Tested-by: Wei-Ning Huang <wnhuang@chromium.org>
Reviewed-by: Wei-Han Chen <stimim@chromium.org>
diff --git a/py/tools/ghost.py b/py/tools/ghost.py
index 8f7de48..c375d17 100755
--- a/py/tools/ghost.py
+++ b/py/tools/ghost.py
@@ -17,6 +17,7 @@
 import select
 import signal
 import socket
+import ssl
 import struct
 import subprocess
 import sys
@@ -56,6 +57,10 @@
 # Stream control
 _STDIN_CLOSED = '##STDIN_CLOSED##'
 
+# A string that will always be included in the response of
+# GET http://OVERLORD_SERVER:_OVERLORD_HTTP_PORT
+_OVERLORD_RESPONSE_KEYWORD = 'HTTP'
+
 SUCCESS = 'success'
 FAILED = 'failed'
 DISCONNECTED = 'disconnected'
@@ -69,15 +74,18 @@
   pass
 
 
-class BufferedSocket(socket.socket):
+class BufferedSocket(object):
   """A buffered socket that supports unrecv.
 
   Allow putting back data back to the socket for the next recv() call.
   """
-  def __init__(self, *args, **kwargs):
-    super(BufferedSocket, self).__init__(*args, **kwargs)
+  def __init__(self, sock):
+    self.sock = sock
     self._buf = ''
 
+  def fileno(self):
+    return self.sock.fileno()
+
   def Recv(self, bufsize, flags=0):
     if self._buf:
       if len(self._buf) >= bufsize:
@@ -87,15 +95,15 @@
       else:
         ret = self._buf
         self._buf = ''
-        return ret + super(BufferedSocket, self).recv(bufsize - len(ret), flags)
+        return ret + self.sock.recv(bufsize - len(ret), flags)
     else:
-      return super(BufferedSocket, self).recv(bufsize, flags)
+      return self.sock.recv(bufsize, flags)
 
   def UnRecv(self, buf):
     self._buf = buf + self._buf
 
   def Send(self, *args, **kwargs):
-    return super(BufferedSocket, self).send(*args, **kwargs)
+    return self.sock.send(*args, **kwargs)
 
   def RecvBuf(self):
     """Only recive from buffer."""
@@ -103,6 +111,39 @@
     self._buf = ''
     return ret
 
+  def Close(self):
+    self.sock.close()
+
+
+class TLSSettings(object):
+  def __init__(self, tls_cert_file, enable_tls_without_verify):
+    """Constructor.
+
+    Args:
+      tls_cert_file: TLS certificate in PEM format.
+      enable_tls_without_verify: enable TLS but don't verify certificate.
+    """
+    self._tls_cert_file = tls_cert_file
+    self._tls_context = None
+
+    if self._tls_cert_file is not None or enable_tls_without_verify:
+      self._tls_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
+      self._tls_context.verify_mode = ssl.CERT_NONE
+      if self._tls_cert_file:
+        self._tls_context.verify_mode = ssl.CERT_REQUIRED
+        self._tls_context.check_hostname = True
+        try:
+          self._tls_context.load_verify_locations(self._tls_cert_file)
+        except IOError as e:
+          logging.error('TLSSettings: %s: %s', self._tls_cert_file, e)
+          sys.exit(1)
+
+  def Enabled(self):
+    return self._tls_context is not None
+
+  def Context(self):
+    return self._tls_context
+
 
 class Ghost(object):
   """Ghost implements the client protocol of Overlord.
@@ -124,13 +165,14 @@
 
   RANDOM_MID = '##random_mid##'
 
-  def __init__(self, overlord_addrs, mode=AGENT, mid=None, sid=None,
-               prop_file=None, terminal_sid=None, tty_device=None,
+  def __init__(self, overlord_addrs, tls_settings=None, mode=AGENT, mid=None,
+               sid=None, prop_file=None, terminal_sid=None, tty_device=None,
                command=None, file_op=None, port=None):
     """Constructor.
 
     Args:
       overlord_addrs: a list of possible address of overlord.
+      tls_settings: a TLSSettings object.
       mode: client mode, either AGENT, SHELL or LOGCAT
       mid: a str to set for machine ID. If mid equals Ghost.RANDOM_MID, machine
         id is randomly generated.
@@ -155,6 +197,7 @@
 
     self._overlord_addrs = overlord_addrs
     self._connected_addr = None
+    self._tls_settings = tls_settings
     self._mid = mid
     self._sock = None
     self._mode = mode
@@ -190,8 +233,8 @@
     with open(filename, 'r') as f:
       return hashlib.sha1(f.read()).hexdigest()
 
-  def UseSSL(self):
-    """Determine if SSL is enabled on the Overlord server."""
+  def OverlordHTTPSEnabled(self):
+    """Determine if SSL is enabled on the Overlord HTTP server."""
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     try:
       sock.settimeout(_CONNECT_TIMEOUT)
@@ -199,25 +242,36 @@
       sock.send('GET\r\n')
 
       data = sock.recv(16)
-      return 'HTTP' not in data
+      return _OVERLORD_RESPONSE_KEYWORD not in data
     except Exception:
       return False  # For whatever reason above failed, assume HTTP
 
   def Upgrade(self):
     logging.info('Upgrade: initiating upgrade sequence...')
 
+    server_tls_enabled = self.OverlordHTTPSEnabled()
+    if self._tls_settings.Enabled() and not server_tls_enabled:
+      logging.error('Upgrade: TLS enforced but found Overlord HTTP server '
+                    'without TLS enabled! Possible mis-configuration or '
+                    'DNS/IP spoofing detected, abort')
+      return
+
     scriptpath = os.path.abspath(sys.argv[0])
     url = 'http%s://%s:%d/upgrade/ghost.py' % (
-        's' if self.UseSSL() else '', self._connected_addr[0],
+        's' if server_tls_enabled else '', self._connected_addr[0],
         _OVERLORD_HTTP_PORT)
 
     # Download sha1sum for ghost.py for verification
     try:
       with contextlib.closing(
-          urllib2.urlopen(url + '.sha1', timeout=_CONNECT_TIMEOUT)) as f:
+          urllib2.urlopen(url + '.sha1', timeout=_CONNECT_TIMEOUT,
+                          context=self._tls_settings.Context())) as f:
         if f.getcode() != 200:
           raise RuntimeError('HTTP status %d' % f.getcode())
         sha1sum = f.read().strip()
+    except (ssl.SSLError, ssl.CertificateError) as e:
+      logging.error('Upgrade: %s: %s', e.__class__.__name__, e)
+      return
     except Exception:
       logging.error('Upgrade: failed to download sha1sum file, abort')
       return
@@ -229,10 +283,14 @@
     # Download upgrade version of ghost.py
     try:
       with contextlib.closing(
-          urllib2.urlopen(url, timeout=_CONNECT_TIMEOUT)) as f:
+          urllib2.urlopen(url, timeout=_CONNECT_TIMEOUT,
+                          context=self._tls_settings.Context())) as f:
         if f.getcode() != 200:
           raise RuntimeError('HTTP status %d' % f.getcode())
         data = f.read()
+    except (ssl.SSLError, ssl.CertificateError) as e:
+      logging.error('Upgrade: %s: %s', e.__class__.__name__, e)
+      return
     except Exception:
       logging.error('Upgrade: failed to download upgrade, abort')
       return
@@ -260,7 +318,7 @@
         with open(self._prop_file, 'r') as f:
           self._properties = json.loads(f.read())
     except Exception as e:
-      logging.exception('LoadProperties: ' + str(e))
+      logging.error('LoadProperties: ' + str(e))
 
   def CloseSockets(self):
     # Close sockets opened by parent process, since we don't use it anymore.
@@ -285,7 +343,8 @@
     pid = os.fork()
     if pid == 0:
       self.CloseSockets()
-      g = Ghost([self._connected_addr], mode, Ghost.RANDOM_MID, sid,
+      g = Ghost([self._connected_addr], tls_settings=self._tls_settings,
+                mode=mode, mid=Ghost.RANDOM_MID, sid=sid,
                 terminal_sid=terminal_sid, tty_device=tty_device,
                 command=command, file_op=file_op, port=port)
       g.Start()
@@ -390,6 +449,9 @@
 
   def Reset(self):
     """Reset state and clear request handlers."""
+    if self._sock is not None:
+      self._sock.Close()
+      self._sock = None
     self._reset.clear()
     self._last_ping = 0
     self._requests = {}
@@ -512,7 +574,7 @@
     except Exception as e:
       logging.error('SpawnTTYServer: %s', e)
     finally:
-      self._sock.close()
+      self._sock.Close()
 
     logging.info('SpawnTTYServer: terminated')
     sys.exit(0)
@@ -545,7 +607,7 @@
 
       while True:
         rd, unused_wd, unused_xd = select.select(
-            [p.stdout, p.stderr, self._sock], [], [self._sock])
+            [p.stdout, p.stderr, self._sock], [], [])
         if p.stdout in rd:
           self._sock.Send(p.stdout.read(_BUFSIZE))
 
@@ -582,7 +644,7 @@
           pass
 
       p.wait()
-      self._sock.close()
+      self._sock.Close()
 
     logging.info('SpawnShellServer: terminated')
     sys.exit(0)
@@ -618,7 +680,7 @@
     except Exception as e:
       logging.error('StartDownloadServer: %s', e)
     finally:
-      self._sock.close()
+      self._sock.Close()
 
     logging.info('StartDownloadServer: terminated')
     sys.exit(0)
@@ -634,7 +696,6 @@
         except Exception:
           pass
 
-      self._sock.setblocking(False)
       with open(filepath, 'wb') as f:
         if self._file_op[2]:
           os.fchmod(f.fileno(), self._file_op[2])
@@ -653,7 +714,7 @@
     except Exception as e:
       logging.error('StartUploadServer: %s', e)
     finally:
-      self._sock.close()
+      self._sock.Close()
 
     logging.info('StartUploadServer: terminated')
     sys.exit(0)
@@ -667,7 +728,6 @@
       src_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
       src_sock.settimeout(_CONNECT_TIMEOUT)
       src_sock.connect(('localhost', self._port))
-      src_sock.setblocking(False)
 
       src_sock.send(self._sock.RecvBuf())
 
@@ -690,7 +750,7 @@
     finally:
       if src_sock:
         src_sock.close()
-      self._sock.close()
+      self._sock.Close()
 
     logging.info('SpawnPortForwardServer: terminated')
     sys.exit(0)
@@ -797,7 +857,12 @@
 
   def ParseMessage(self, buf, single=True):
     if single:
-      index = buf.index(_SEPARATOR)
+      try:
+        index = buf.index(_SEPARATOR)
+      except ValueError:
+        self._sock.UnRecv(buf)
+        return
+
       msgs_json = [buf[:index]]
       self._sock.UnRecv(buf[index + 2:])
     else:
@@ -851,7 +916,6 @@
 
           # Socket is closed
           if len(data) == 0:
-            self.Reset()
             break
 
           self.ParseMessage(data, self._register_status != SUCCESS)
@@ -865,14 +929,13 @@
           self.InitiateDownload()
 
         if self._reset.is_set():
-          self.Reset()
           break
     except socket.error:
       raise RuntimeError('Connection dropped')
     except PingTimeoutError:
       raise RuntimeError('Connection timeout')
     finally:
-      self._sock.close()
+      self.Reset()
 
     self._queue.put('resume')
 
@@ -901,9 +964,26 @@
       try:
         logging.info('Trying %s:%d ...', *addr)
         self.Reset()
-        self._sock = BufferedSocket(socket.AF_INET, socket.SOCK_STREAM)
-        self._sock.settimeout(_CONNECT_TIMEOUT)
-        self._sock.connect(addr)
+
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.settimeout(_CONNECT_TIMEOUT)
+
+        try:
+          if self._tls_settings.Enabled():
+            tls_context = self._tls_settings.Context()
+            sock = tls_context.wrap_socket(sock, server_hostname=addr[0])
+
+          sock.connect(addr)
+        except (ssl.SSLError, ssl.CertificateError) as e:
+          logging.error('%s: %s', e.__class__.__name__, e)
+          continue
+        except IOError as e:
+          if e.errno == 2:  # No such file or directory
+            logging.error('%s: %s', e.__class__.__name__, e)
+            continue
+          raise
+
+        self._sock = BufferedSocket(sock)
 
         logging.info('Connection established, registering...')
         handler = {
@@ -924,7 +1004,7 @@
       except socket.error:
         pass
       else:
-        self._sock.settimeout(None)
+        sock.settimeout(None)
         self.Listen()
 
     raise RuntimeError('Cannot connect to any server')
@@ -1106,6 +1186,14 @@
   parser.add_argument('--no-rpc-server', dest='rpc_server',
                       action='store_false', default=True,
                       help='disable RPC server')
+  parser.add_argument('--tls-cert-file', metavar='TLS_CERT_FILE',
+                      dest='tls_cert_file', type=str, default=None,
+                      help='file containing the server TLS certificate in PEM '
+                           'format')
+  parser.add_argument('--enable-tls-without-verify',
+                      dest='enable_tls_without_verify', action='store_true',
+                      default=False,
+                      help='Enable TLS but don\'t verify certificate')
   parser.add_argument('--prop-file', metavar='PROP_FILE', dest='prop_file',
                       type=str, default=None,
                       help='file containing the JSON representation of client '
@@ -1132,9 +1220,16 @@
   addrs = [('localhost', _OVERLORD_PORT)]
   addrs += [(x, _OVERLORD_PORT) for x in args.overlord_ip]
 
-  g = Ghost(addrs, Ghost.AGENT, args.mid, prop_file=args.prop_file)
+  prop_file = os.path.abspath(args.prop_file) if args.prop_file else None
+
+  tls_settings = TLSSettings(args.tls_cert_file, args.enable_tls_without_verify)
+  g = Ghost(addrs, tls_settings, Ghost.AGENT, args.mid,
+            prop_file=prop_file)
   g.Start(args.lan_disc, args.rpc_server)
 
 
 if __name__ == '__main__':
-  main()
+  try:
+    main()
+  except Exception as e:
+    logging.error(e)