overlord: add TLS support for overlord-ghost link
The Overlord server may be set up outside the factory, so the connection
between Overlord and Ghost must be encrypted to ensure security.
BUG=chromium:585732
TEST=`go/src/overlord/test/overlord_e2e_unittest.py`
Change-Id: I29f9153339774c2b64f3e8f8dcf5ad4fa572e5d1
Reviewed-on: https://chromium-review.googlesource.com/326100
Commit-Ready: Wei-Ning Huang <wnhuang@chromium.org>
Tested-by: Wei-Ning Huang <wnhuang@chromium.org>
Reviewed-by: Wei-Han Chen <stimim@chromium.org>
diff --git a/py/tools/ghost.py b/py/tools/ghost.py
index 8f7de48..c375d17 100755
--- a/py/tools/ghost.py
+++ b/py/tools/ghost.py
@@ -17,6 +17,7 @@
import select
import signal
import socket
+import ssl
import struct
import subprocess
import sys
@@ -56,6 +57,10 @@
# Stream control
_STDIN_CLOSED = '##STDIN_CLOSED##'
+# A string that will always be included in the response of
+# GET http://OVERLORD_SERVER:_OVERLORD_HTTP_PORT
+_OVERLORD_RESPONSE_KEYWORD = 'HTTP'
+
SUCCESS = 'success'
FAILED = 'failed'
DISCONNECTED = 'disconnected'
@@ -69,15 +74,18 @@
pass
-class BufferedSocket(socket.socket):
+class BufferedSocket(object):
"""A buffered socket that supports unrecv.
Allow putting back data back to the socket for the next recv() call.
"""
- def __init__(self, *args, **kwargs):
- super(BufferedSocket, self).__init__(*args, **kwargs)
+ def __init__(self, sock):
+ self.sock = sock
self._buf = ''
+ def fileno(self):
+ return self.sock.fileno()
+
def Recv(self, bufsize, flags=0):
if self._buf:
if len(self._buf) >= bufsize:
@@ -87,15 +95,15 @@
else:
ret = self._buf
self._buf = ''
- return ret + super(BufferedSocket, self).recv(bufsize - len(ret), flags)
+ return ret + self.sock.recv(bufsize - len(ret), flags)
else:
- return super(BufferedSocket, self).recv(bufsize, flags)
+ return self.sock.recv(bufsize, flags)
def UnRecv(self, buf):
self._buf = buf + self._buf
def Send(self, *args, **kwargs):
- return super(BufferedSocket, self).send(*args, **kwargs)
+ return self.sock.send(*args, **kwargs)
def RecvBuf(self):
"""Only recive from buffer."""
@@ -103,6 +111,39 @@
self._buf = ''
return ret
+ def Close(self):
+ self.sock.close()
+
+
+class TLSSettings(object):
+ def __init__(self, tls_cert_file, enable_tls_without_verify):
+ """Constructor.
+
+ Args:
+ tls_cert_file: TLS certificate in PEM format.
+ enable_tls_without_verify: enable TLS but don't verify certificate.
+ """
+ self._tls_cert_file = tls_cert_file
+ self._tls_context = None
+
+ if self._tls_cert_file is not None or enable_tls_without_verify:
+ self._tls_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
+ self._tls_context.verify_mode = ssl.CERT_NONE
+ if self._tls_cert_file:
+ self._tls_context.verify_mode = ssl.CERT_REQUIRED
+ self._tls_context.check_hostname = True
+ try:
+ self._tls_context.load_verify_locations(self._tls_cert_file)
+ except IOError as e:
+ logging.error('TLSSettings: %s: %s', self._tls_cert_file, e)
+ sys.exit(1)
+
+ def Enabled(self):
+ return self._tls_context is not None
+
+ def Context(self):
+ return self._tls_context
+
class Ghost(object):
"""Ghost implements the client protocol of Overlord.
@@ -124,13 +165,14 @@
RANDOM_MID = '##random_mid##'
- def __init__(self, overlord_addrs, mode=AGENT, mid=None, sid=None,
- prop_file=None, terminal_sid=None, tty_device=None,
+ def __init__(self, overlord_addrs, tls_settings=None, mode=AGENT, mid=None,
+ sid=None, prop_file=None, terminal_sid=None, tty_device=None,
command=None, file_op=None, port=None):
"""Constructor.
Args:
overlord_addrs: a list of possible address of overlord.
+ tls_settings: a TLSSetting object.
mode: client mode, either AGENT, SHELL or LOGCAT
mid: a str to set for machine ID. If mid equals Ghost.RANDOM_MID, machine
id is randomly generated.
@@ -155,6 +197,7 @@
self._overlord_addrs = overlord_addrs
self._connected_addr = None
+ self._tls_settings = tls_settings
self._mid = mid
self._sock = None
self._mode = mode
@@ -190,8 +233,8 @@
with open(filename, 'r') as f:
return hashlib.sha1(f.read()).hexdigest()
- def UseSSL(self):
- """Determine if SSL is enabled on the Overlord server."""
+ def OverlordHTTPSEnabled(self):
+ """Determine if SSL is enabled on the Overlord HTTP server."""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
sock.settimeout(_CONNECT_TIMEOUT)
@@ -199,25 +242,36 @@
sock.send('GET\r\n')
data = sock.recv(16)
- return 'HTTP' not in data
+ return _OVERLORD_RESPONSE_KEYWORD not in data
except Exception:
return False # For whatever reason above failed, assume HTTP
def Upgrade(self):
logging.info('Upgrade: initiating upgrade sequence...')
+ server_tls_enabled = self.OverlordHTTPSEnabled()
+ if self._tls_settings.Enabled() and not server_tls_enabled:
+ logging.error('Upgrade: TLS enforced but found Overlord HTTP server '
+ 'without TLS enabled! Possible mis-configuration or '
+ 'DNS/IP spoofing detected, abort')
+ return
+
scriptpath = os.path.abspath(sys.argv[0])
url = 'http%s://%s:%d/upgrade/ghost.py' % (
- 's' if self.UseSSL() else '', self._connected_addr[0],
+ 's' if server_tls_enabled else '', self._connected_addr[0],
_OVERLORD_HTTP_PORT)
# Download sha1sum for ghost.py for verification
try:
with contextlib.closing(
- urllib2.urlopen(url + '.sha1', timeout=_CONNECT_TIMEOUT)) as f:
+ urllib2.urlopen(url + '.sha1', timeout=_CONNECT_TIMEOUT,
+ context=self._tls_settings.Context())) as f:
if f.getcode() != 200:
raise RuntimeError('HTTP status %d' % f.getcode())
sha1sum = f.read().strip()
+ except (ssl.SSLError, ssl.CertificateError) as e:
+ logging.error('Upgrade: %s: %s', e.__class__.__name__, e)
+ return
except Exception:
logging.error('Upgrade: failed to download sha1sum file, abort')
return
@@ -229,10 +283,14 @@
# Download upgrade version of ghost.py
try:
with contextlib.closing(
- urllib2.urlopen(url, timeout=_CONNECT_TIMEOUT)) as f:
+ urllib2.urlopen(url, timeout=_CONNECT_TIMEOUT,
+ context=self._tls_settings.Context())) as f:
if f.getcode() != 200:
raise RuntimeError('HTTP status %d' % f.getcode())
data = f.read()
+ except (ssl.SSLError, ssl.CertificateError) as e:
+ logging.error('Upgrade: %s: %s', e.__class__.__name__, e)
+ return
except Exception:
logging.error('Upgrade: failed to download upgrade, abort')
return
@@ -260,7 +318,7 @@
with open(self._prop_file, 'r') as f:
self._properties = json.loads(f.read())
except Exception as e:
- logging.exception('LoadProperties: ' + str(e))
+ logging.error('LoadProperties: ' + str(e))
def CloseSockets(self):
# Close sockets opened by parent process, since we don't use it anymore.
@@ -285,7 +343,8 @@
pid = os.fork()
if pid == 0:
self.CloseSockets()
- g = Ghost([self._connected_addr], mode, Ghost.RANDOM_MID, sid,
+ g = Ghost([self._connected_addr], tls_settings=self._tls_settings,
+ mode=mode, mid=Ghost.RANDOM_MID, sid=sid,
terminal_sid=terminal_sid, tty_device=tty_device,
command=command, file_op=file_op, port=port)
g.Start()
@@ -390,6 +449,9 @@
def Reset(self):
"""Reset state and clear request handlers."""
+ if self._sock is not None:
+ self._sock.Close()
+ self._sock = None
self._reset.clear()
self._last_ping = 0
self._requests = {}
@@ -512,7 +574,7 @@
except Exception as e:
logging.error('SpawnTTYServer: %s', e)
finally:
- self._sock.close()
+ self._sock.Close()
logging.info('SpawnTTYServer: terminated')
sys.exit(0)
@@ -545,7 +607,7 @@
while True:
rd, unused_wd, unused_xd = select.select(
- [p.stdout, p.stderr, self._sock], [], [self._sock])
+ [p.stdout, p.stderr, self._sock], [], [])
if p.stdout in rd:
self._sock.Send(p.stdout.read(_BUFSIZE))
@@ -582,7 +644,7 @@
pass
p.wait()
- self._sock.close()
+ self._sock.Close()
logging.info('SpawnShellServer: terminated')
sys.exit(0)
@@ -618,7 +680,7 @@
except Exception as e:
logging.error('StartDownloadServer: %s', e)
finally:
- self._sock.close()
+ self._sock.Close()
logging.info('StartDownloadServer: terminated')
sys.exit(0)
@@ -634,7 +696,6 @@
except Exception:
pass
- self._sock.setblocking(False)
with open(filepath, 'wb') as f:
if self._file_op[2]:
os.fchmod(f.fileno(), self._file_op[2])
@@ -653,7 +714,7 @@
except Exception as e:
logging.error('StartUploadServer: %s', e)
finally:
- self._sock.close()
+ self._sock.Close()
logging.info('StartUploadServer: terminated')
sys.exit(0)
@@ -667,7 +728,6 @@
src_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
src_sock.settimeout(_CONNECT_TIMEOUT)
src_sock.connect(('localhost', self._port))
- src_sock.setblocking(False)
src_sock.send(self._sock.RecvBuf())
@@ -690,7 +750,7 @@
finally:
if src_sock:
src_sock.close()
- self._sock.close()
+ self._sock.Close()
logging.info('SpawnPortForwardServer: terminated')
sys.exit(0)
@@ -797,7 +857,12 @@
def ParseMessage(self, buf, single=True):
if single:
- index = buf.index(_SEPARATOR)
+ try:
+ index = buf.index(_SEPARATOR)
+ except ValueError:
+ self._sock.UnRecv(buf)
+ return
+
msgs_json = [buf[:index]]
self._sock.UnRecv(buf[index + 2:])
else:
@@ -851,7 +916,6 @@
# Socket is closed
if len(data) == 0:
- self.Reset()
break
self.ParseMessage(data, self._register_status != SUCCESS)
@@ -865,14 +929,13 @@
self.InitiateDownload()
if self._reset.is_set():
- self.Reset()
break
except socket.error:
raise RuntimeError('Connection dropped')
except PingTimeoutError:
raise RuntimeError('Connection timeout')
finally:
- self._sock.close()
+ self.Reset()
self._queue.put('resume')
@@ -901,9 +964,26 @@
try:
logging.info('Trying %s:%d ...', *addr)
self.Reset()
- self._sock = BufferedSocket(socket.AF_INET, socket.SOCK_STREAM)
- self._sock.settimeout(_CONNECT_TIMEOUT)
- self._sock.connect(addr)
+
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.settimeout(_CONNECT_TIMEOUT)
+
+ try:
+ if self._tls_settings.Enabled():
+ tls_context = self._tls_settings.Context()
+ sock = tls_context.wrap_socket(sock, server_hostname=addr[0])
+
+ sock.connect(addr)
+ except (ssl.SSLError, ssl.CertificateError) as e:
+ logging.error('%s: %s', e.__class__.__name__, e)
+ continue
+ except IOError as e:
+ if e.errno == 2: # No such file or directory
+ logging.error('%s: %s', e.__class__.__name__, e)
+ continue
+ raise
+
+ self._sock = BufferedSocket(sock)
logging.info('Connection established, registering...')
handler = {
@@ -924,7 +1004,7 @@
except socket.error:
pass
else:
- self._sock.settimeout(None)
+ sock.settimeout(None)
self.Listen()
raise RuntimeError('Cannot connect to any server')
@@ -1106,6 +1186,14 @@
parser.add_argument('--no-rpc-server', dest='rpc_server',
action='store_false', default=True,
help='disable RPC server')
+ parser.add_argument('--tls-cert-file', metavar='TLS_CERT_FILE',
+ dest='tls_cert_file', type=str, default=None,
+ help='file containing the server TLS certificate in PEM '
+ 'format')
+ parser.add_argument('--enable-tls-without-verify',
+ dest='enable_tls_without_verify', action='store_true',
+ default=False,
+ help='Enable TLS but don\'t verify certificate')
parser.add_argument('--prop-file', metavar='PROP_FILE', dest='prop_file',
type=str, default=None,
help='file containing the JSON representation of client '
@@ -1132,9 +1220,16 @@
addrs = [('localhost', _OVERLORD_PORT)]
addrs += [(x, _OVERLORD_PORT) for x in args.overlord_ip]
- g = Ghost(addrs, Ghost.AGENT, args.mid, prop_file=args.prop_file)
+ prop_file = os.path.abspath(args.prop_file) if args.prop_file else None
+
+ tls_settings = TLSSettings(args.tls_cert_file, args.enable_tls_without_verify)
+ g = Ghost(addrs, tls_settings, Ghost.AGENT, args.mid,
+ prop_file=prop_file)
g.Start(args.lan_disc, args.rpc_server)
if __name__ == '__main__':
- main()
+ try:
+ main()
+ except Exception as e:
+ logging.error(e)