overlord: ghost to target ssh port forwarding negotiation
Update ghost to forward its local SSH port (22) to the Overlord host, on a
remote port number negotiated with Overlord.
Also:
- cleaned up string quoting, typos, and other minor style issues
- print a traceback for unexpected exceptions in the connect/register loop
BUG=chrome-os-partner:43605
TEST=Manually on local machine
Change-Id: I093dcdc43702b7d8adfbed65f99ac9691a8d074d
Reviewed-on: https://chromium-review.googlesource.com/291152
Commit-Ready: Joel Kitching <kitching@chromium.org>
Tested-by: Joel Kitching <kitching@chromium.org>
Reviewed-by: Wei-Ning Huang <wnhuang@chromium.org>
diff --git a/py/tools/ghost.py b/py/tools/ghost.py
index da72bd6..789cf63 100755
--- a/py/tools/ghost.py
+++ b/py/tools/ghost.py
@@ -5,6 +5,8 @@
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
+from __future__ import print_function
+
import argparse
import contextlib
import fcntl
@@ -23,6 +25,7 @@
import termios
import threading
import time
+import traceback
import urllib
import uuid
@@ -53,6 +56,7 @@
RESPONSE_SUCCESS = 'success'
RESPONSE_FAILED = 'failed'
+
class PingTimeoutError(Exception):
pass
@@ -61,6 +65,258 @@
pass
+class SSHPortForwarder(object):
+  """Create and maintain an SSH port forwarding connection.
+
+  This is meant to be a standalone class to maintain an SSH port forwarding
+  connection to a given server. It provides a fail/retry mechanism, and also
+  can report its current connection status.
+  """
+  _FAILED_STR = 'port forwarding failed'
+  _DEFAULT_CONNECT_TIMEOUT = 10
+  _DEFAULT_ALIVE_INTERVAL = 10
+  _DEFAULT_DISCONNECT_WAIT = 1
+  _DEFAULT_RETRIES = 5
+  _DEFAULT_EXP_FACTOR = 1
+  _DEBUG_INTERVAL = 2
+
+  CONNECTING = 1
+  INITIALIZED = 2
+  FAILED = 4
+
+  REMOTE = 1
+  LOCAL = 2
+
+  @classmethod
+  def ToRemote(cls, *args, **kwargs):
+    """Calls constructor with forward_to=REMOTE."""
+    return cls(cls.REMOTE, *args, **kwargs)
+
+  @classmethod
+  def ToLocal(cls, *args, **kwargs):
+    """Calls constructor with forward_to=LOCAL."""
+    return cls(cls.LOCAL, *args, **kwargs)
+
+  def __init__(self,
+               forward_to,
+               src_port,
+               dst_port,
+               user,
+               identity_file,
+               host,
+               port=22,
+               connect_timeout=_DEFAULT_CONNECT_TIMEOUT,
+               alive_interval=_DEFAULT_ALIVE_INTERVAL,
+               disconnect_wait=_DEFAULT_DISCONNECT_WAIT,
+               retries=_DEFAULT_RETRIES,
+               exp_factor=_DEFAULT_EXP_FACTOR):
+    """Constructor.
+
+    Starts a daemon thread which immediately begins connecting.
+
+    Args:
+      forward_to: Which direction to forward traffic: REMOTE or LOCAL.
+      src_port: Source port for forwarding.
+      dst_port: Destination port for forwarding.
+      user: Username on remote server.
+      identity_file: Identity file for passwordless authentication on remote
+        server.
+      host: Host of remote server.
+      port: Port of remote server.
+      connect_timeout: Seconds passed to ssh as ConnectTimeout; an SSH process
+        still running after this long is considered connected.
+      alive_interval: Seconds passed to ssh as ServerAliveInterval.
+      disconnect_wait: The number of seconds to wait before reconnecting after
+        the first disconnect.
+      retries: The total number of SSH attempts before reporting a failed
+        connection.
+      exp_factor: After each reconnect, the disconnect wait time is multiplied
+        by 2^exp_factor.
+    """
+    # Internal use.
+    self._ssh_thread = None
+    self._ssh_output = None
+    self._exception = None
+    self._state = self.CONNECTING
+    self._poll = threading.Event()
+
+    # Connection arguments.
+    self._forward_to = forward_to
+    self._src_port = src_port
+    self._dst_port = dst_port
+    self._host = host
+    self._user = user
+    self._identity_file = identity_file
+    self._port = port
+
+    # Configuration arguments.
+    self._connect_timeout = connect_timeout
+    self._alive_interval = alive_interval
+    self._exp_factor = exp_factor
+
+    t = threading.Thread(
+        target=self._Run,
+        args=(disconnect_wait, retries))
+    t.daemon = True
+    t.start()
+
+  def __str__(self):
+    # State representation.
+    if self._state == self.CONNECTING:
+      state_str = 'connecting'
+    elif self._state == self.INITIALIZED:
+      state_str = 'initialized'
+    else:
+      state_str = 'failed'
+
+    # Port forward representation.
+    if self._forward_to == self.REMOTE:
+      fwd_str = '->%d' % self._dst_port
+    else:
+      fwd_str = '%d<-' % self._dst_port
+
+    return 'SSHPortForwarder(%s,%s)' % (state_str, fwd_str)
+
+  def _ForwardArgs(self):
+    """Returns the ssh -R/-L port forwarding arguments."""
+    if self._forward_to == self.REMOTE:
+      return ['-R', '%d:127.0.0.1:%d' % (self._dst_port, self._src_port)]
+    else:
+      return ['-L', '%d:127.0.0.1:%d' % (self._src_port, self._dst_port)]
+
+  def _RunSSHCmd(self):
+    """Runs the SSH command, storing its output or exception.
+
+    Both are cleared first so that _Run never sees results left over from a
+    previous attempt.
+    """
+    self._ssh_output = None
+    self._exception = None
+    cmd = [
+        'ssh',
+        '-o', 'StrictHostKeyChecking=no',
+        '-o', 'GlobalKnownHostsFile=/dev/null',
+        '-o', 'UserKnownHostsFile=/dev/null',
+        '-o', 'ExitOnForwardFailure=yes',
+        '-o', 'ConnectTimeout=%d' % self._connect_timeout,
+        '-o', 'ServerAliveInterval=%d' % self._alive_interval,
+        '-o', 'ServerAliveCountMax=1',
+        '-o', 'TCPKeepAlive=yes',
+        '-o', 'BatchMode=yes',
+        '-i', self._identity_file,
+        '-N',
+        '-p', str(self._port),
+        '%s@%s' % (self._user, self._host),
+    ] + self._ForwardArgs()
+    logging.info(' '.join(cmd))
+    try:
+      self._ssh_output = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
+    except subprocess.CalledProcessError as e:
+      self._exception = e
+    except OSError as e:
+      # ssh could not be started (e.g. not installed). Record the error as
+      # output so that _Run logs it and retries instead of crashing.
+      self._ssh_output = str(e)
+
+  def _Run(self, disconnect_wait, retries):
+    """Runs SSH with retries, tracking the connection state.
+
+    The state becomes INITIALIZED once an SSH process survives connect_timeout
+    seconds, and FAILED if the identity file can't be chmod-ed, the server
+    reports a forwarding failure, or all attempts are used up. Each of these
+    transitions wakes up Wait().
+
+    Args:
+      disconnect_wait: Seconds to sleep before the next attempt; multiplied by
+        2^exp_factor after every retry.
+      retries: The number of SSH attempts remaining; must be > 0.
+    """
+    assert retries > 0, '%s: _Run must be called with retries > 0' % self
+
+    # Set identity file permissions. Need to only be user-readable for ssh to
+    # use the key.
+    try:
+      os.chmod(self._identity_file, 0600)
+    except OSError as e:
+      logging.error('%s: Error setting identity file permissions: %s',
+                    self, e)
+      self._state = self.FAILED
+      # Wake up Wait() callers, which would otherwise block forever.
+      self._poll.set()
+      return
+
+    while True:
+      logging.info('%s: Connecting to %s:%d',
+                   self, self._host, self._port)
+
+      # Start a thread. If it fails, deal with the failure. If it is still
+      # running after connect_timeout seconds, assume everything's working
+      # great, and tell the caller. Then, continue waiting for it to end.
+      self._ssh_thread = threading.Thread(target=self._RunSSHCmd)
+      self._ssh_thread.daemon = True
+      self._ssh_thread.start()
+
+      # See if the SSH thread is still working after connect_timeout.
+      self._ssh_thread.join(self._connect_timeout)
+      if self._ssh_thread.is_alive():
+        # Assumed to be working. Tell our caller that we are connected.
+        if self._state != self.INITIALIZED:
+          self._state = self.INITIALIZED
+          self._poll.set()
+        logging.info('%s: Still connected after timeout=%ds',
+                     self, self._connect_timeout)
+
+      # Only for debug purposes. Keep showing connection status.
+      while self._ssh_thread.is_alive():
+        logging.debug('%s: Still connected', self)
+        self._ssh_thread.join(self._DEBUG_INTERVAL)
+
+      # Figure out what went wrong.
+      if not self._exception:
+        logging.info('%s: SSH unexpectedly exited: %s',
+                     self, (self._ssh_output or '').rstrip())
+      if (self._exception and
+          self._FAILED_STR in (self._exception.output or '')):
+        self._state = self.FAILED
+        self._poll.set()
+        logging.info('%s: Port forwarding failed', self)
+        return
+      elif retries == 1:
+        self._state = self.FAILED
+        self._poll.set()
+        logging.info('%s: Disconnected (0 retries left)', self)
+        return
+
+      retries -= 1
+      logging.info('%s: Disconnected, retrying (sleep %1ds, %d retries left)',
+                   self, disconnect_wait, retries)
+      time.sleep(disconnect_wait)
+      disconnect_wait *= 2 ** self._exp_factor
+
+  def GetState(self):
+    """Returns the current connection state.
+
+    State may be one of:
+
+      CONNECTING: Still attempting to make the first successful connection.
+      INITIALIZED: Is either connected or is trying to make subsequent
+        connection.
+      FAILED: Has completed all connection attempts, or server has reported that
+        target port is in use.
+    """
+    return self._state
+
+  def GetDstPort(self):
+    """Returns the current target port."""
+    return self._dst_port
+
+  def Wait(self):
+    """Waits for a state change, and returns the new state."""
+    self._poll.wait()
+    self._poll.clear()
+    return self.GetState()
+
+
class Ghost(object):
"""Ghost implements the client protocol of Overlord.
@@ -120,6 +349,10 @@
self._reset = threading.Event()
self._last_ping = 0
self._queue = Queue.Queue()
+ self._forward_ssh = False
+ self._ssh_port_forwarder = None
+ self._target_identity_file = os.path.join(os.path.dirname(
+ os.path.abspath(os.path.realpath(__file__))), 'ghost_rsa')
self._download_queue = Queue.Queue()
self._ttyname_to_sid = {}
self._terminal_sid_to_pid = {}
@@ -207,7 +440,7 @@
Returns:
The spawned child process pid.
"""
- # Restore the default signal hanlder, so our child won't have problems.
+ # Restore the default signal handler, so our child won't have problems.
self.SetIgnoreChild(False)
pid = os.fork()
@@ -452,7 +685,6 @@
if len(ret) == 0:
raise RuntimeError('socket closed')
p.stdin.write(ret)
-
p.poll()
if p.returncode != None:
break
@@ -562,7 +794,7 @@
handler(response)
else:
print(response, self._requests.keys())
- logging.warning('Recvied unsolicited response, ignored')
+ logging.warning('Received unsolicited response, ignored')
def ParseMessage(self):
msgs_json = self._buf.split(_SEPARATOR)
@@ -583,6 +815,11 @@
pass
def ScanForTimeoutRequests(self):
+ """Scans for pending requests which have timed out.
+
+ If any timed-out requests are discovered, their handler is called with the
+ special response value of None.
+ """
for rid in self._requests.keys()[:]:
request_time, timeout, handler = self._requests[rid]
if self.Timestamp() - request_time > timeout:
@@ -639,6 +876,9 @@
self._reset.set()
raise RuntimeError('Register request timeout')
logging.info('Registered with Overlord at %s:%d', *non_local['addr'])
+ if self._forward_ssh:
+ logging.info('Starting target SSH port negotiation')
+ self.NegotiateTargetSSHPort()
self._queue.put('pause', True)
try:
@@ -730,7 +970,7 @@
t.start()
def StartRPCServer(self):
- logging.info("RPC Server: started")
+ logging.info('RPC Server: started')
rpc_server = SimpleJSONRPCServer((_DEFAULT_BIND_ADDRESS, _GHOST_RPC_PORT),
logRequests=False)
rpc_server.register_function(self.Reconnect, 'Reconnect')
@@ -747,7 +987,104 @@
if addr not in self._overlord_addrs:
self._overlord_addrs.append(addr)
- def Start(self, lan_disc=False, rpc_server=False):
+  def NegotiateTargetSSHPort(self):
+    """Request-receive target SSH port forwarding loop.
+
+    Repeatedly attempts to forward this machine's SSH port to target. It
+    bounces back and forth between RequestPort and ReceivePort when a new port
+    is required. ReceivePort starts a new thread so that the main ghost thread
+    may continue running.
+    """
+    # Sanity check for identity file.
+    if not os.path.isfile(self._target_identity_file):
+      logging.info('No target host identity file: not negotiating '
+                   'target SSH port')
+      return
+
+    def PollSSHPortForwarder():
+      def ThreadFunc():
+        while True:
+          state = self._ssh_port_forwarder.GetState()
+
+          # Connected successfully.
+          if state == SSHPortForwarder.INITIALIZED:
+            # The SSH port forward has succeeded! Let's tell Overlord.
+            port = self._ssh_port_forwarder.GetDstPort()
+            RegisterPort(port)
+
+          # We've given up... continue to the next port.
+          elif state == SSHPortForwarder.FAILED:
+            break
+
+          # Either CONNECTING or INITIALIZED.
+          self._ssh_port_forwarder.Wait()
+
+        # Only request a new port if we are still registered to Overlord;
+        # otherwise, a new NegotiateTargetSSHPort call will take care of it.
+        # Clear the failed forwarder first so ReceivePort's new one survives.
+        self._ssh_port_forwarder = None
+        try:
+          RequestPort()
+        except Exception:
+          logging.info('Failed to request port, will wait for next connection')
+
+      t = threading.Thread(target=ThreadFunc)
+      t.daemon = True
+      t.start()
+
+    def ReceivePort(response):
+      # If the response times out, this version of Overlord may not support SSH
+      # port negotiation. Give up on port negotiation process.
+      if response is None:
+        return
+
+      port = int(response['params']['port'])
+      logging.info('Received target SSH port: %d', port)
+
+      if (self._ssh_port_forwarder and
+          self._ssh_port_forwarder.GetState() != SSHPortForwarder.FAILED):
+        logging.info('Unexpectedly received a target SSH port')
+        return
+
+      # Try forwarding SSH port to target.
+      self._ssh_port_forwarder = SSHPortForwarder.ToRemote(
+          src_port=22,
+          dst_port=port,
+          user='ghost',
+          identity_file=self._target_identity_file,
+          host=self._connected_addr[0]) # Use Overlord host as target.
+
+      # Creates a new thread.
+      PollSSHPortForwarder()
+
+    def RequestPort():
+      logging.info('Requesting new target SSH port')
+      self.SendRequest('request_target_ssh_port', {}, ReceivePort, 5)
+
+    def RegisterPort(port):
+      logging.info('Registering target SSH port %d', port)
+      self.SendRequest(
+          'register_target_ssh_port',
+          {'port': port}, RegisterPortResponse, 5)
+
+    def RegisterPortResponse(response):
+      # Overlord responded to request_port already. If register_port fails,
+      # something might be in an inconsistent state, so trigger a reconnect
+      # via PingTimeoutError.
+      if response is None:
+        raise PingTimeoutError
+      logging.info('Registering target SSH port acknowledged')
+
+    # If the SSHPortForwarder is already INITIALIZED (e.g. we re-registered
+    # with Overlord), report its port to Overlord again manually, since its
+    # poll thread is blocked in Wait() and will not do so itself.
+    if (self._ssh_port_forwarder and
+        self._ssh_port_forwarder.GetState() == SSHPortForwarder.INITIALIZED):
+      RegisterPort(self._ssh_port_forwarder.GetDstPort())
+    if not self._ssh_port_forwarder:
+      RequestPort()
+
+ def Start(self, lan_disc=False, rpc_server=False, forward_ssh=False):
logging.info('%s started', self.MODE_NAME[self._mode])
logging.info('MID: %s', self._machine_id)
logging.info('SID: %s', self._session_id)
@@ -762,6 +1099,8 @@
if rpc_server:
self.StartRPCServer()
+ self._forward_ssh = forward_ssh
+
try:
while True:
try:
@@ -776,8 +1115,17 @@
try:
self.ScanServer()
self.Register()
+ # Don't show stack trace for RuntimeError, which we use in this file for
+ # plausible and expected errors (such as can't connect to server).
+ except RuntimeError as e:
+ logging.info('%s: %s, retrying in %ds',
+ e.__class__.__name__, e.message, _RETRY_INTERVAL)
+ time.sleep(_RETRY_INTERVAL)
except Exception as e:
- logging.info(str(e) + ', retrying in %ds' % _RETRY_INTERVAL)
+ _, _, exc_traceback = sys.exc_info()
+ traceback.print_tb(exc_traceback)
+ logging.info('%s: %s, retrying in %ds',
+ e.__class__.__name__, e.message, _RETRY_INTERVAL)
time.sleep(_RETRY_INTERVAL)
self.Reset()
@@ -793,12 +1141,12 @@
def DownloadFile(filename):
filepath = os.path.abspath(filename)
if not os.path.exists(filepath):
- logging.error("file `%s' does not exist", filename)
+ logging.error('file `%s\' does not exist', filename)
sys.exit(1)
# Check if we actually have permission to read the file
if not os.access(filepath, os.R_OK):
- logging.error("can not open %s for reading", filepath)
+ logging.error('can not open %s for reading', filepath)
sys.exit(1)
server = GhostRPCServer()
@@ -820,7 +1168,10 @@
parser.add_argument('--no-rpc-server', dest='rpc_server',
action='store_false', default=True,
help='disable RPC server')
- parser.add_argument('--prop-file', metavar='PROP_FILE', dest="prop_file",
+ parser.add_argument('--no-forward-ssh', dest='forward_ssh',
+ action='store_false', default=True,
+ help='disable target SSH port forwarding')
+ parser.add_argument('--prop-file', metavar='PROP_FILE', dest='prop_file',
type=str, default=None,
help='file containing the JSON representation of client '
'properties')
@@ -839,8 +1190,27 @@
g = Ghost(addrs, Ghost.AGENT, args.mid)
if args.prop_file:
g.LoadPropertiesFromFile(args.prop_file)
- g.Start(args.lan_disc, args.rpc_server)
+ g.Start(args.lan_disc, args.rpc_server, args.forward_ssh)
+
+
+def _SigtermHandler(*_):
+  """Ensure that SSH processes also get killed on a sigterm signal.
+
+  SIGTERM is forwarded to our whole process group via os.killpg(0, ...), so
+  child SSH processes die too. This re-delivers SIGTERM to us as well, and
+  SIGTERM_SENT stops the re-entered handler from forwarding it again.
+  NOTE(review): any other process sharing our process group also receives it.
+  Source: http://www.tsheffler.com/blog/2010/11/21/python-multithreaded-daemon-with-sigterm-support-a-recipe/
+  """
+  logging.info('SIGTERM handler: shutting down')
+  if not _SigtermHandler.SIGTERM_SENT:
+    _SigtermHandler.SIGTERM_SENT = True
+    logging.info('Sending TERM to process group')
+    os.killpg(0, signal.SIGTERM)
+  sys.exit()
+_SigtermHandler.SIGTERM_SENT = False  # True once SIGTERM has been forwarded.
if __name__ == '__main__':
+ signal.signal(signal.SIGTERM, _SigtermHandler)
main()
diff --git a/py/tools/ghost_rsa b/py/tools/ghost_rsa
new file mode 100644
index 0000000..fd37f4f
--- /dev/null
+++ b/py/tools/ghost_rsa
@@ -0,0 +1,27 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIEowIBAAKCAQEAwbjPPNLRfrQC7vWMcnAaO291hR4PtYee4a+H2nqZXjqAbEUg
+asdWvyvev9BJYwR9Y1f46jvklVaqXgWcI+C6PA60ZFA0hAfzxljMF1UpqtO3jgZp
+k6j6as5V570QJw5T+/ri3YXJmlgpW1fRf4HAuC0U/EfZivJaH8lURZSZkGHq+m2R
+ZuOYS/IncXHUTKYY3JGwDOhAIuxR5ht3JcExoOf1zKSElg3NZLlovVI0aXmHiwKv
+nIKdhg9tnHMLdIXLuUFzTSFrRviWFYwdyzIVoL3xztoJp/Bw/FWQaqIKXGzwmw8a
+rAxXSJVdZexd1L9ifQK7y+jqEX/XJG69jP73ZQIDAQABAoIBAFUr+gGV9wGsB3Yk
+g4F1BDOJh7PlSabYX+R+Fk7ahD/HnNr9cYlA50TZ9u+CliFwwehBr3DcsF1wYys/
+cCzeC0OIe4t3L/+0t1tHg0Pm75Dp0NQiwZxoOOFooqBmoYlqZUZuQfx/a+nuRRCi
+Bbv3wlG3kHhy5pSOhU1gaSrGcNXnUvhR+TUdFHdGCWicVTl5fX73eKAHW21zoKlI
+hRYzpTWGhEOifzvHdu/8QMJFXA929o6/V0YFfdSra9JiiE9hqYRROQhjq1ttg6jg
+O7SUyV7aOYnbc6sSANEav5uG9c6bGK4AedQXKTqhCmhHHzs+CB5ZPM1ZZCiltpbn
+UcJ4WOkCgYEA7cZf3kPx6/9JelKLe0IjNspoArAzUnk0Sg58ZqmjXYGZZMrcpSWc
+7xmgGNi0j0G3LRTbLXqbm428ry+PQQXzKFfqvWaXC4oZi9BRQ+SF2Jt5ojhNN14G
+ZZ1gYO48+bzRMcXE/gLmX9dTMVV1SfpOCYUaHWm0HrpjIm8p0PiQ+7cCgYEA0JIG
+r2Uz3ehiF1bAILt8TMVEY0l7DyIBrsZ2sIr3CzbISXMER7lOUZopNXc1zKGKL1ah
+/abBZEkE/97t0vrKEeeK3mZdj/0wE1RIA1BOVDHaa84EO/mVNy4GgQs31KOCctMl
+nZs57WebgZNSPxmi6U3B0YSbBke5plPOanVQncMCgYAtqrsA6lXNeLN5Dd+CJdqz
+jD5bvdGtll/HlW6pHQ2mSNzYMeocwdOZTHemLgDHvtxaiTXrTzARuTAzCVRfLbBc
+4D3ScKCz86siYjkpa/uU9Y9v65ZQ+vsJiydWlosZf/1BrPU/v/jVEXsF757ePXe9
+dlXkrkeM20ls9KK4YvUdkwKBgQDPTcTtQkaqMpaEogn2vsLOP2g400lIAkHv6H0B
+/i2L7NhoALTpYSqR+wsohCNqD8mcQZxi1AL2XYlllLuHbxO3dg9V/CLUwg+ttqCZ
+ApHIJ4D0k+Erh2ejX9DBJFhKtnYrEOkbXLTX3Zn30Wj3JNEC2PFjAU1gkZvZ0QSi
+VZZaJwKBgBy5jRcKFhAb8okoY+FwopzVrBS0emz/VIV3ygVifU+XTE8DdcqZK3UR
+DF1EqbVPqEAfCioscPM1bo6oEN/IHat7+M59vp4prOdJ27L9xKoPDxoQnp+Ut4S0
+VXbXLuwIB5APoiipk/cHDMwz3mpaX9+BfSh1VtDLnJ6ePwP9wvJf
+-----END RSA PRIVATE KEY-----
diff --git a/py/tools/ghost_rsa.pub b/py/tools/ghost_rsa.pub
new file mode 100644
index 0000000..61664c1
--- /dev/null
+++ b/py/tools/ghost_rsa.pub
@@ -0,0 +1 @@
+ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDBuM880tF+tALu9YxycBo7b3WFHg+1h57hr4faepleOoBsRSBqx1a/K96/0EljBH1jV/jqO+SVVqpeBZwj4Lo8DrRkUDSEB/PGWMwXVSmq07eOBmmTqPpqzlXnvRAnDlP7+uLdhcmaWClbV9F/gcC4LRT8R9mK8lofyVRFlJmQYer6bZFm45hL8idxcdRMphjckbAM6EAi7FHmG3clwTGg5/XMpISWDc1kuWi9UjRpeYeLAq+cgp2GD22ccwt0hcu5QXNNIWtG+JYVjB3LMhWgvfHO2gmn8HD8VZBqogpcbPCbDxqsDFdIlV1l7F3Uv2J9ArvL6OoRf9ckbr2M/vdl ghost@SEL-SERVER01