overlord: ghost to target ssh port forwarding negotiation

Update ghost to forward its local SSH port (22) to the Overlord host, on a
remote port number suggested by Overlord. Forwarding can be disabled with
--no-forward-ssh.

Also:
- cleaned up some syntax and fixed typos
- added traceback to exception reporting
- forward SIGTERM to the process group so child ssh processes are killed

BUG=chrome-os-partner:43605
TEST=Manually on local machine

Change-Id: I093dcdc43702b7d8adfbed65f99ac9691a8d074d
Reviewed-on: https://chromium-review.googlesource.com/291152
Commit-Ready: Joel Kitching <kitching@chromium.org>
Tested-by: Joel Kitching <kitching@chromium.org>
Reviewed-by: Wei-Ning Huang <wnhuang@chromium.org>
diff --git a/py/tools/ghost.py b/py/tools/ghost.py
index da72bd6..789cf63 100755
--- a/py/tools/ghost.py
+++ b/py/tools/ghost.py
@@ -5,6 +5,8 @@
 # Use of this source code is governed by a BSD-style license that can be
 # found in the LICENSE file.
 
+from __future__ import print_function
+
 import argparse
 import contextlib
 import fcntl
@@ -23,6 +25,7 @@
 import termios
 import threading
 import time
+import traceback
 import urllib
 import uuid
 
@@ -53,6 +56,7 @@
 RESPONSE_SUCCESS = 'success'
 RESPONSE_FAILED = 'failed'
 
+
 class PingTimeoutError(Exception):
   pass
 
@@ -61,6 +65,258 @@
   pass
 
 
+class SSHPortForwarder(object):
+  """Create and maintain an SSH port forwarding connection.
+
+  This is meant to be a standalone class to maintain an SSH port forwarding
+  connection to a given server.  It provides a fail/retry mechanism, and also
+  can report its current connection status.
+  """
+  _FAILED_STR = 'port forwarding failed'
+  _DEFAULT_CONNECT_TIMEOUT = 10
+  _DEFAULT_ALIVE_INTERVAL = 10
+  _DEFAULT_DISCONNECT_WAIT = 1
+  _DEFAULT_RETRIES = 5
+  _DEFAULT_EXP_FACTOR = 1
+  _DEBUG_INTERVAL = 2
+
+  CONNECTING = 1
+  INITIALIZED = 2
+  FAILED = 4
+
+  REMOTE = 1
+  LOCAL = 2
+
+  @classmethod
+  def ToRemote(cls, *args, **kwargs):
+    """Calls constructor with forward_to=REMOTE."""
+    # forward_to is the first constructor parameter, so pass it positionally;
+    # passing it by keyword would clash with any positional args.
+    return cls(cls.REMOTE, *args, **kwargs)
+
+  @classmethod
+  def ToLocal(cls, *args, **kwargs):
+    """Calls constructor with forward_to=LOCAL."""
+    return cls(cls.LOCAL, *args, **kwargs)
+
+  def __init__(self,
+               forward_to,
+               src_port,
+               dst_port,
+               user,
+               identity_file,
+               host,
+               port=22,
+               connect_timeout=_DEFAULT_CONNECT_TIMEOUT,
+               alive_interval=_DEFAULT_ALIVE_INTERVAL,
+               disconnect_wait=_DEFAULT_DISCONNECT_WAIT,
+               retries=_DEFAULT_RETRIES,
+               exp_factor=_DEFAULT_EXP_FACTOR):
+    """Constructor.
+
+    Starts a daemon thread that connects (and reconnects) in the background.
+    Use GetState() or Wait() to follow its progress.
+
+    Args:
+      forward_to: Which direction to forward traffic: REMOTE or LOCAL.
+      src_port: Source port for forwarding.
+      dst_port: Destination port for forwarding.
+      user: Username on remote server.
+      identity_file: Identity file for passwordless authentication on remote
+          server.
+      host: Host of remote server.
+      port: Port of remote server.
+      connect_timeout: Time in seconds, passed to ssh as ConnectTimeout.  If
+          ssh is still running after this long, the connection is assumed to
+          be working.
+      alive_interval: Seconds between keepalive messages sent to the server,
+          passed to ssh as ServerAliveInterval.
+      disconnect_wait: The number of seconds to wait before reconnecting after
+          the first disconnect.
+      retries: Total number of connection attempts, including the first,
+          before reporting a failed connection.  Must be > 0.
+      exp_factor: After each reconnect, the disconnect wait time is multiplied
+          by 2^exp_factor.
+    """
+    # Internal use.
+    self._ssh_thread = None
+    self._ssh_output = None
+    self._exception = None
+    self._state = self.CONNECTING
+    self._poll = threading.Event()
+
+    # Connection arguments.
+    self._forward_to = forward_to
+    self._src_port = src_port
+    self._dst_port = dst_port
+    self._host = host
+    self._user = user
+    self._identity_file = identity_file
+    self._port = port
+
+    # Configuration arguments.
+    self._connect_timeout = connect_timeout
+    self._alive_interval = alive_interval
+    self._exp_factor = exp_factor
+
+    t = threading.Thread(
+        target=self._Run,
+        args=(disconnect_wait, retries))
+    t.daemon = True
+    t.start()
+
+  def __str__(self):
+    # State representation.
+    if self._state == self.CONNECTING:
+      state_str = 'connecting'
+    elif self._state == self.INITIALIZED:
+      state_str = 'initialized'
+    else:
+      state_str = 'failed'
+
+    # Port forward representation.
+    if self._forward_to == self.REMOTE:
+      fwd_str = '->%d' % self._dst_port
+    else:
+      fwd_str = '%d<-' % self._dst_port
+
+    return 'SSHPortForwarder(%s,%s)' % (state_str, fwd_str)
+
+  def _ForwardArgs(self):
+    """Returns the ssh -R/-L arguments for the configured direction."""
+    if self._forward_to == self.REMOTE:
+      return ['-R', '%d:127.0.0.1:%d' % (self._dst_port, self._src_port)]
+    else:
+      return ['-L', '%d:127.0.0.1:%d' % (self._src_port, self._dst_port)]
+
+  def _RunSSHCmd(self):
+    """Runs the SSH command, storing its output or the failure exception."""
+    cmd = [
+        'ssh',
+        '-o', 'StrictHostKeyChecking=no',
+        '-o', 'GlobalKnownHostsFile=/dev/null',
+        '-o', 'UserKnownHostsFile=/dev/null',
+        '-o', 'ExitOnForwardFailure=yes',
+        '-o', 'ConnectTimeout=%d' % self._connect_timeout,
+        '-o', 'ServerAliveInterval=%d' % self._alive_interval,
+        '-o', 'ServerAliveCountMax=1',
+        '-o', 'TCPKeepAlive=yes',
+        '-o', 'BatchMode=yes',
+        '-i', self._identity_file,
+        '-N',
+        '-p', str(self._port),
+        '%s@%s' % (self._user, self._host),
+        ] + self._ForwardArgs()
+    logging.info(' '.join(cmd))
+    try:
+      self._ssh_output = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
+    except (subprocess.CalledProcessError, OSError) as e:
+      # OSError covers e.g. a missing ssh binary.  Record it so that the
+      # failure is not mistaken for a clean exit by _Run.
+      self._exception = e
+
+  def _Run(self, disconnect_wait, retries):
+    """Runs SSH, reconnecting after disconnects, and tracks the state.
+
+    Args:
+      disconnect_wait: Seconds to sleep before the next connection attempt;
+          multiplied by 2^exp_factor after every attempt.
+      retries: Total number of connection attempts to make.
+    """
+    assert retries > 0, '%s: _Run must be called with retries > 0' % self
+
+    # Set identity file permissions.  Need to only be user-readable for ssh to
+    # use the key.
+    try:
+      os.chmod(self._identity_file, 0o600)
+    except OSError as e:
+      logging.error('%s: Error setting identity file permissions: %s',
+                    self, e)
+      self._state = self.FAILED
+      # Wake up any Wait() caller; otherwise it would block forever.
+      self._poll.set()
+      return
+
+    while True:
+      logging.info('%s: Connecting to %s:%d',
+                   self, self._host, self._port)
+
+      # Forget the outcome of the previous attempt so that it cannot be
+      # mistaken for the outcome of this one.
+      self._ssh_output = None
+      self._exception = None
+
+      # Start a thread.  If it fails, deal with the failure.  If it is still
+      # running after connect_timeout seconds, assume everything's working
+      # great, and tell the caller.  Then, continue waiting for it to end.
+      self._ssh_thread = threading.Thread(target=self._RunSSHCmd)
+      self._ssh_thread.daemon = True
+      self._ssh_thread.start()
+
+      # See if the SSH thread is still working after connect_timeout.
+      self._ssh_thread.join(self._connect_timeout)
+      if self._ssh_thread.is_alive():
+        # Assumed to be working.  Tell our caller that we are connected.
+        if self._state != self.INITIALIZED:
+          self._state = self.INITIALIZED
+          self._poll.set()
+        logging.info('%s: Still connected after timeout=%ds',
+                     self, self._connect_timeout)
+
+      # Only for debug purposes.  Keep showing connection status.
+      while self._ssh_thread.is_alive():
+        logging.debug('%s: Still connected', self)
+        self._ssh_thread.join(self._DEBUG_INTERVAL)
+
+      # Figure out what went wrong.
+      if self._exception is None:
+        logging.info('%s: SSH unexpectedly exited: %s',
+                     self, (self._ssh_output or '').rstrip())
+      # Only CalledProcessError carries ssh's output; OSError has none.
+      output = getattr(self._exception, 'output', None) or ''
+      if self._FAILED_STR in output:
+        self._state = self.FAILED
+        self._poll.set()
+        logging.info('%s: Port forwarding failed', self)
+        return
+
+      retries -= 1
+      if retries == 0:
+        self._state = self.FAILED
+        self._poll.set()
+        logging.info('%s: Disconnected (0 retries left)', self)
+        return
+
+      logging.info('%s: Disconnected, retrying (sleep %1ds, %d retries left)',
+                   self, disconnect_wait, retries)
+      time.sleep(disconnect_wait)
+      disconnect_wait *= 2 ** self._exp_factor
+
+  def GetState(self):
+    """Returns the current connection state.
+
+    State may be one of:
+
+      CONNECTING: Still attempting to make the first successful connection.
+      INITIALIZED: Is either connected or is trying to make subsequent
+          connection.
+      FAILED: Has completed all connection attempts, port forwarding was
+          refused by the server (e.g. target port in use), or the identity
+          file permissions could not be set.
+    """
+    return self._state
+
+  def GetDstPort(self):
+    """Returns the current target port."""
+    return self._dst_port
+
+  def Wait(self):
+    """Waits for a state change, and returns the new state."""
+    self._poll.wait()
+    self._poll.clear()
+    return self.GetState()
+
+
 class Ghost(object):
   """Ghost implements the client protocol of Overlord.
 
@@ -120,6 +349,11 @@
     self._reset = threading.Event()
     self._last_ping = 0
     self._queue = Queue.Queue()
+    self._forward_ssh = False
+    self._ssh_port_forwarder = None
+    # Private key used to log in to the Overlord host for SSH port forwarding.
+    self._target_identity_file = os.path.join(os.path.dirname(
+        os.path.realpath(__file__)), 'ghost_rsa')
     self._download_queue = Queue.Queue()
     self._ttyname_to_sid = {}
     self._terminal_sid_to_pid = {}
@@ -207,7 +440,7 @@
     Returns:
       The spawned child process pid.
     """
-    # Restore the default signal hanlder, so our child won't have problems.
+    # Restore the default signal handler, so our child won't have problems.
     self.SetIgnoreChild(False)
 
     pid = os.fork()
@@ -452,7 +685,6 @@
           if len(ret) == 0:
             raise RuntimeError('socket closed')
           p.stdin.write(ret)
-
         p.poll()
         if p.returncode != None:
           break
@@ -562,7 +794,7 @@
         handler(response)
     else:
       print(response, self._requests.keys())
-      logging.warning('Recvied unsolicited response, ignored')
+      logging.warning('Received unsolicited response, ignored')
 
   def ParseMessage(self):
     msgs_json = self._buf.split(_SEPARATOR)
@@ -583,6 +815,11 @@
         pass
 
   def ScanForTimeoutRequests(self):
+    """Scans for pending requests which have timed out.
+
+    If any timed-out requests are discovered, their handler is called with the
+    special response value of None.
+    """
     for rid in self._requests.keys()[:]:
       request_time, timeout, handler = self._requests[rid]
       if self.Timestamp() - request_time > timeout:
@@ -639,6 +876,9 @@
           self._reset.set()
           raise RuntimeError('Register request timeout')
         logging.info('Registered with Overlord at %s:%d', *non_local['addr'])
+        if self._forward_ssh:
+          logging.info('Starting target SSH port negotiation')
+          self.NegotiateTargetSSHPort()
         self._queue.put('pause', True)
 
       try:
@@ -730,7 +970,7 @@
     t.start()
 
   def StartRPCServer(self):
-    logging.info("RPC Server: started")
+    logging.info('RPC Server: started')
     rpc_server = SimpleJSONRPCServer((_DEFAULT_BIND_ADDRESS, _GHOST_RPC_PORT),
                                      logRequests=False)
     rpc_server.register_function(self.Reconnect, 'Reconnect')
@@ -747,7 +987,106 @@
         if addr not in self._overlord_addrs:
           self._overlord_addrs.append(addr)
 
-  def Start(self, lan_disc=False, rpc_server=False):
+  def NegotiateTargetSSHPort(self):
+    """Request-receive target SSH port forwarding loop.
+
+    Repeatedly attempts to forward this machine's SSH port to target.  It
+    bounces back and forth between RequestPort and ReceivePort when a new port
+    is required.  ReceivePort starts a new thread so that the main ghost thread
+    may continue running.
+    """
+    # Sanity check for identity file.
+    if not os.path.isfile(self._target_identity_file):
+      logging.info('No target host identity file: not negotiating '
+                   'target SSH port')
+      return
+
+    def PollSSHPortForwarder():
+      def ThreadFunc():
+        while True:
+          state = self._ssh_port_forwarder.GetState()
+
+          # Connected successfully.
+          if state == SSHPortForwarder.INITIALIZED:
+            # The SSH port forward has succeeded!  Let's tell Overlord.
+            port = self._ssh_port_forwarder.GetDstPort()
+            RegisterPort(port)
+
+          # We've given up... continue to the next port.
+          elif state == SSHPortForwarder.FAILED:
+            break
+
+          # Either CONNECTING or INITIALIZED.
+          self._ssh_port_forwarder.Wait()
+
+        # Only request a new port if we are still registered to Overlord.
+        # Otherwise, a new call to NegotiateTargetSSHPort will be made,
+        # which will take care of it.
+        try:
+          RequestPort()
+        except Exception:
+          logging.info('Failed to request port, will wait for next connection')
+          self._ssh_port_forwarder = None
+
+      t = threading.Thread(target=ThreadFunc)
+      t.daemon = True
+      t.start()
+
+    def ReceivePort(response):
+      # If the response times out, this version of Overlord may not support SSH
+      # port negotiation.  Give up on port negotiation process.
+      if response is None:
+        return
+
+      port = int(response['params']['port'])
+      logging.info('Received target SSH port: %d', port)
+
+      if (self._ssh_port_forwarder and
+          self._ssh_port_forwarder.GetState() != SSHPortForwarder.FAILED):
+        logging.info('Unexpectedly received a target SSH port')
+        return
+
+      # Try forwarding SSH port to target.
+      self._ssh_port_forwarder = SSHPortForwarder.ToRemote(
+          src_port=22,
+          dst_port=port,
+          user='ghost',
+          identity_file=self._target_identity_file,
+          host=self._connected_addr[0])  # Use Overlord host as target.
+
+      # Creates a new thread.
+      PollSSHPortForwarder()
+
+    def RequestPort():
+      logging.info('Requesting new target SSH port')
+      self.SendRequest('request_target_ssh_port', {}, ReceivePort, 5)
+
+    def RegisterPort(port):
+      logging.info('Registering target SSH port %d', port)
+      self.SendRequest(
+          'register_target_ssh_port',
+          {'port': port}, RegisterPortResponse, 5)
+
+    def RegisterPortResponse(response):
+      # Overlord responded to request_port already.  If register_port fails,
+      # something might be in an inconsistent state, so trigger a reconnect
+      # via PingTimeoutError.
+      if response is None:
+        raise PingTimeoutError
+      logging.info('Registering target SSH port acknowledged')
+
+    forwarder = self._ssh_port_forwarder
+    state = forwarder.GetState() if forwarder else None
+    if state == SSHPortForwarder.INITIALIZED:
+      # The polling thread is blocked in SSHPortForwarder.Wait() and will
+      # not report the port again, so report it to Overlord manually.
+      RegisterPort(forwarder.GetDstPort())
+    elif state in (None, SSHPortForwarder.FAILED):
+      # No usable forwarder.  A FAILED one lingers when its follow-up port
+      # request went unanswered; without this, negotiation never resumes.
+      RequestPort()
+
+  def Start(self, lan_disc=False, rpc_server=False, forward_ssh=False):
     logging.info('%s started', self.MODE_NAME[self._mode])
     logging.info('MID: %s', self._machine_id)
     logging.info('SID: %s', self._session_id)
@@ -762,6 +1099,8 @@
     if rpc_server:
       self.StartRPCServer()
 
+    self._forward_ssh = forward_ssh
+
     try:
       while True:
         try:
@@ -776,8 +1115,17 @@
         try:
           self.ScanServer()
           self.Register()
+        # Don't show stack trace for RuntimeError, which we use in this file for
+        # plausible and expected errors (such as can't connect to server).
+        except RuntimeError as e:
+          logging.info('%s: %s, retrying in %ds',
+                       e.__class__.__name__, e, _RETRY_INTERVAL)
+          time.sleep(_RETRY_INTERVAL)
         except Exception as e:
-          logging.info(str(e) + ', retrying in %ds' % _RETRY_INTERVAL)
+          _, _, exc_traceback = sys.exc_info()
+          traceback.print_tb(exc_traceback)
+          logging.info('%s: %s, retrying in %ds',
+                       e.__class__.__name__, e, _RETRY_INTERVAL)
           time.sleep(_RETRY_INTERVAL)
 
         self.Reset()
@@ -793,12 +1141,12 @@
 def DownloadFile(filename):
   filepath = os.path.abspath(filename)
   if not os.path.exists(filepath):
-    logging.error("file `%s' does not exist", filename)
+    logging.error('file `%s\' does not exist', filename)
     sys.exit(1)
 
   # Check if we actually have permission to read the file
   if not os.access(filepath, os.R_OK):
-    logging.error("can not open %s for reading", filepath)
+    logging.error('can not open %s for reading', filepath)
     sys.exit(1)
 
   server = GhostRPCServer()
@@ -820,7 +1168,10 @@
   parser.add_argument('--no-rpc-server', dest='rpc_server',
                       action='store_false', default=True,
                       help='disable RPC server')
-  parser.add_argument('--prop-file', metavar='PROP_FILE', dest="prop_file",
+  parser.add_argument('--no-forward-ssh', dest='forward_ssh',
+                      action='store_false', default=True,
+                      help='disable target SSH port forwarding')
+  parser.add_argument('--prop-file', metavar='PROP_FILE', dest='prop_file',
                       type=str, default=None,
                       help='file containing the JSON representation of client '
                            'properties')
@@ -839,8 +1190,27 @@
   g = Ghost(addrs, Ghost.AGENT, args.mid)
   if args.prop_file:
     g.LoadPropertiesFromFile(args.prop_file)
-  g.Start(args.lan_disc, args.rpc_server)
+  g.Start(args.lan_disc, args.rpc_server, args.forward_ssh)
+
+
+def _SigtermHandler(*_):
+  """Forwards SIGTERM to our process group so child SSH processes die too.
+
+  os.killpg(0, ...) signals every process in our process group: child SSH
+  processes, this process itself (hence the SIGTERM_SENT guard against
+  re-signalling), and any other process sharing the group.  Then exits.
+  Source:
+  http://www.tsheffler.com/blog/2010/11/21/python-multithreaded-daemon-with-sigterm-support-a-recipe/
+  """
+  logging.info('SIGTERM handler: shutting down')
+  if not _SigtermHandler.SIGTERM_SENT:
+    _SigtermHandler.SIGTERM_SENT = True
+    logging.info('Sending TERM to process group')
+    os.killpg(0, signal.SIGTERM)
+  sys.exit()
+_SigtermHandler.SIGTERM_SENT = False
 
 
 if __name__ == '__main__':
+  signal.signal(signal.SIGTERM, _SigtermHandler)
   main()
diff --git a/py/tools/ghost_rsa b/py/tools/ghost_rsa
new file mode 100644
index 0000000..fd37f4f
--- /dev/null
+++ b/py/tools/ghost_rsa
@@ -0,0 +1,27 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIEowIBAAKCAQEAwbjPPNLRfrQC7vWMcnAaO291hR4PtYee4a+H2nqZXjqAbEUg
+asdWvyvev9BJYwR9Y1f46jvklVaqXgWcI+C6PA60ZFA0hAfzxljMF1UpqtO3jgZp
+k6j6as5V570QJw5T+/ri3YXJmlgpW1fRf4HAuC0U/EfZivJaH8lURZSZkGHq+m2R
+ZuOYS/IncXHUTKYY3JGwDOhAIuxR5ht3JcExoOf1zKSElg3NZLlovVI0aXmHiwKv
+nIKdhg9tnHMLdIXLuUFzTSFrRviWFYwdyzIVoL3xztoJp/Bw/FWQaqIKXGzwmw8a
+rAxXSJVdZexd1L9ifQK7y+jqEX/XJG69jP73ZQIDAQABAoIBAFUr+gGV9wGsB3Yk
+g4F1BDOJh7PlSabYX+R+Fk7ahD/HnNr9cYlA50TZ9u+CliFwwehBr3DcsF1wYys/
+cCzeC0OIe4t3L/+0t1tHg0Pm75Dp0NQiwZxoOOFooqBmoYlqZUZuQfx/a+nuRRCi
+Bbv3wlG3kHhy5pSOhU1gaSrGcNXnUvhR+TUdFHdGCWicVTl5fX73eKAHW21zoKlI
+hRYzpTWGhEOifzvHdu/8QMJFXA929o6/V0YFfdSra9JiiE9hqYRROQhjq1ttg6jg
+O7SUyV7aOYnbc6sSANEav5uG9c6bGK4AedQXKTqhCmhHHzs+CB5ZPM1ZZCiltpbn
+UcJ4WOkCgYEA7cZf3kPx6/9JelKLe0IjNspoArAzUnk0Sg58ZqmjXYGZZMrcpSWc
+7xmgGNi0j0G3LRTbLXqbm428ry+PQQXzKFfqvWaXC4oZi9BRQ+SF2Jt5ojhNN14G
+ZZ1gYO48+bzRMcXE/gLmX9dTMVV1SfpOCYUaHWm0HrpjIm8p0PiQ+7cCgYEA0JIG
+r2Uz3ehiF1bAILt8TMVEY0l7DyIBrsZ2sIr3CzbISXMER7lOUZopNXc1zKGKL1ah
+/abBZEkE/97t0vrKEeeK3mZdj/0wE1RIA1BOVDHaa84EO/mVNy4GgQs31KOCctMl
+nZs57WebgZNSPxmi6U3B0YSbBke5plPOanVQncMCgYAtqrsA6lXNeLN5Dd+CJdqz
+jD5bvdGtll/HlW6pHQ2mSNzYMeocwdOZTHemLgDHvtxaiTXrTzARuTAzCVRfLbBc
+4D3ScKCz86siYjkpa/uU9Y9v65ZQ+vsJiydWlosZf/1BrPU/v/jVEXsF757ePXe9
+dlXkrkeM20ls9KK4YvUdkwKBgQDPTcTtQkaqMpaEogn2vsLOP2g400lIAkHv6H0B
+/i2L7NhoALTpYSqR+wsohCNqD8mcQZxi1AL2XYlllLuHbxO3dg9V/CLUwg+ttqCZ
+ApHIJ4D0k+Erh2ejX9DBJFhKtnYrEOkbXLTX3Zn30Wj3JNEC2PFjAU1gkZvZ0QSi
+VZZaJwKBgBy5jRcKFhAb8okoY+FwopzVrBS0emz/VIV3ygVifU+XTE8DdcqZK3UR
+DF1EqbVPqEAfCioscPM1bo6oEN/IHat7+M59vp4prOdJ27L9xKoPDxoQnp+Ut4S0
+VXbXLuwIB5APoiipk/cHDMwz3mpaX9+BfSh1VtDLnJ6ePwP9wvJf
+-----END RSA PRIVATE KEY-----
diff --git a/py/tools/ghost_rsa.pub b/py/tools/ghost_rsa.pub
new file mode 100644
index 0000000..61664c1
--- /dev/null
+++ b/py/tools/ghost_rsa.pub
@@ -0,0 +1 @@
+ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDBuM880tF+tALu9YxycBo7b3WFHg+1h57hr4faepleOoBsRSBqx1a/K96/0EljBH1jV/jqO+SVVqpeBZwj4Lo8DrRkUDSEB/PGWMwXVSmq07eOBmmTqPpqzlXnvRAnDlP7+uLdhcmaWClbV9F/gcC4LRT8R9mK8lofyVRFlJmQYer6bZFm45hL8idxcdRMphjckbAM6EAi7FHmG3clwTGg5/XMpISWDc1kuWi9UjRpeYeLAq+cgp2GD22ccwt0hcu5QXNNIWtG+JYVjB3LMhWgvfHO2gmn8HD8VZBqogpcbPCbDxqsDFdIlV1l7F3Uv2J9ArvL6OoRf9ckbr2M/vdl ghost@SEL-SERVER01