overlord: implement ADB-like command line interface
Implement an ADB-like CLI client for Overlord. The client supports all
Overlord functions, including shell, file upload/download, port
forwarding, etc.
BUG=chromium:517520
TEST=manually test
Change-Id: I7211ead781ecf8f59acc3e8f2e4c07c898a858b8
Reviewed-on: https://chromium-review.googlesource.com/291940
Commit-Ready: Wei-Ning Huang <wnhuang@chromium.org>
Tested-by: Wei-Ning Huang <wnhuang@chromium.org>
Reviewed-by: Hung-Te Lin <hungte@chromium.org>
diff --git a/py/tools/ovl.py b/py/tools/ovl.py
new file mode 100755
index 0000000..c44620c
--- /dev/null
+++ b/py/tools/ovl.py
@@ -0,0 +1,1098 @@
+#!/usr/bin/python -u
+# -*- coding: utf-8 -*-
+#
+# Copyright 2015 The Chromium OS Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+from __future__ import print_function
+
+import argparse
+import ast
+import base64
+import fcntl
+import hashlib
+import httplib
+import json
+import jsonrpclib
+import logging
+import os
+import re
+import select
+import signal
+import socket
+import StringIO
+import struct
+import subprocess
+import sys
+import tempfile
+import termios
+import threading
+import time
+import tty
+import urllib2
+import urlparse
+
+from jsonrpclib.SimpleJSONRPCServer import SimpleJSONRPCServer
+from jsonrpclib.config import Config
+from ws4py.client import WebSocketBaseClient
+
# Python version >= 2.7.9 enables SSL certificate verification by default;
# bypass it (presumably so Overlord servers with self-signed certificates
# keep working).
# NOTE(review): this disables certificate verification for *every* HTTPS
# connection made by this process, which permits man-in-the-middle attacks.
try:
  import ssl
  # pylint: disable=W0212
  ssl._create_default_https_context = ssl._create_unverified_context
except (ImportError, AttributeError):
  # No ssl module, or Python < 2.7.9 (no _create_unverified_context), which
  # does not verify certificates in the first place.
  pass


# Three-character sequence Enter, _ESCAPE, '.' ends an interactive session.
_ESCAPE = '~'
# Chunk size for socket and file reads.
_BUFSIZ = 8192
_OVERLORD_PORT = 4455
_OVERLORD_HTTP_PORT = 9000
# The client daemon's JSON-RPC server listens on localhost only.
_OVERLORD_CLIENT_DAEMON_PORT = 4488
_OVERLORD_CLIENT_DAEMON_RPC_ADDR = ('127.0.0.1', _OVERLORD_CLIENT_DAEMON_PORT)

# Seconds during which the daemon serves a cached client listing.
_LIST_CACHE_TIMEOUT = 2
_DEFAULT_TERMINAL_WIDTH = 80

# echo -n overlord | md5sum
_HTTP_BOUNDARY_MAGIC = '9246f080c855a69012707ab53489b921'

# Byte values framing an in-band JSON control message (e.g. terminal resize)
# in the terminal data stream.
_CONTROL_START = 128
_CONTROL_END = 129
_SSH_CONTROL_SOCKET_PREFIX = os.path.join(tempfile.gettempdir(),
                                          'ovl-ssh-control-')

# A string that will always be included in the response of
# GET http://OVERLORD_SERVER:_OVERLORD_HTTP_PORT
_OVERLORD_RESPONSE_KEYWORD = '<html>'
+
+
def GetVersionDigest():
  """Return the sha1sum (hex) of the current executing script.

  The file is read in binary mode so the digest covers the exact bytes on
  disk. When this module is imported rather than executed, __file__ may point
  at a compiled file instead of the source.
  """
  with open(__file__, 'rb') as f:
    return hashlib.sha1(f.read()).hexdigest()
+
+
def KillGraceful(pid, wait_secs=1):
  """Terminate process `pid`: SIGTERM, pause `wait_secs`, then SIGKILL.

  OSError from os.kill (e.g. the process no longer exists) is ignored. If the
  initial SIGTERM fails, no pause or SIGKILL follows.
  """
  try:
    os.kill(pid, signal.SIGTERM)
  except OSError:
    return

  time.sleep(wait_secs)

  try:
    os.kill(pid, signal.SIGKILL)
  except OSError:
    pass
+
+
def BasicAuthHeader(user, password):
  """Return an HTTP basic auth header as a (name, value) tuple.

  The 'user:password' credential is UTF-8 encoded before base64 encoding.
  This matters because credentials round-tripped through JSON-RPC arrive as
  unicode, which plain b64encode cannot handle when non-ASCII.
  """
  credential = '%s:%s' % (user, password)
  if not isinstance(credential, bytes):
    credential = credential.encode('utf-8')
  encoded = base64.b64encode(credential)
  if not isinstance(encoded, str):
    # Python 3 returns bytes; the header value must be text.
    encoded = encoded.decode('ascii')
  return ('Authorization', 'Basic %s' % encoded)
+
+
def GetTerminalSize():
  """Retrieve the terminal window size of standard input (fd 0).

  Returns:
    A (lines, columns) tuple. (0, 0) is returned when fd 0 is not a terminal
    (e.g. input piped or redirected), so callers can fall back to a default
    width instead of crashing.
  """
  packed = struct.pack('HHHH', 0, 0, 0, 0)
  try:
    packed = fcntl.ioctl(0, termios.TIOCGWINSZ, packed)
  except (IOError, OSError):
    return 0, 0
  lines, columns, unused_x, unused_y = struct.unpack('HHHH', packed)
  return lines, columns
+
+
def MakeRequestUrl(state, url):
  """Prefix a scheme-less `url` with http:// or https://.

  Args:
    state: object whose truthy `ssl` attribute selects https.
    url: URL without scheme, e.g. 'host:port/path'.
  """
  scheme = 'https' if state.ssl else 'http'
  return '%s://%s' % (scheme, url)
+
+
class ProgressBar(object):
  """Single-line console progress bar for file transfers.

  Each update rewrites the current line (it ends with a carriage return)
  showing the name, transferred size, average speed, elapsed time and, when
  the terminal is wide enough, a '#' bar with the percentage.
  """
  SIZE_WIDTH = 11
  SPEED_WIDTH = 10
  DURATION_WIDTH = 6
  PERCENTAGE_WIDTH = 8

  def __init__(self, name):
    """Create the bar and draw it at 0%.

    Args:
      name: label shown in the left column, truncated to fit.
    """
    self._start_time = time.time()
    self._name = name
    self._size = 0
    self._width = 0
    self._name_width = 0
    self._name_max = 0
    self._stat_width = 0
    self._max = 0
    self.CalculateSize()
    self.SetProgress(0)

  def CalculateSize(self):
    """Recompute the column layout from the current terminal width."""
    self._width = GetTerminalSize()[1] or _DEFAULT_TERMINAL_WIDTH
    # The name column takes 30% of the line.
    self._name_width = int(self._width * 0.3)
    self._name_max = self._name_width
    self._stat_width = self.SIZE_WIDTH + self.SPEED_WIDTH + self.DURATION_WIDTH
    # Whatever remains is the width of the '#' bar itself.
    self._max = (self._width - self._name_width - self._stat_width -
                 self.PERCENTAGE_WIDTH)

  def _Humanize(self, value, units):
    """Scale value down by powers of 1024.

    Args:
      value: non-negative number.
      units: unit names for 1024**0, 1024**1, ...; values beyond the last
        unit are expressed in the last unit.

    Returns:
      (scaled_value, unit). The value is returned unscaled for units[0].
    """
    exponent = 0
    while exponent < len(units) - 1 and value >= 1024 ** (exponent + 1):
      exponent += 1
    if exponent == 0:
      return value, units[0]
    return value / (1024.0 ** exponent), units[exponent]

  def SizeToHuman(self, size_in_bytes):
    """Format a byte count, e.g. '    1.5 KiB'."""
    value, unit = self._Humanize(size_in_bytes,
                                 ('B', 'KiB', 'MiB', 'GiB', 'TiB'))
    return ' %6.1f %3s' % (value, unit)

  def SpeedToHuman(self, speed_in_bs):
    """Format a speed in bytes per second, e.g. '    2.0K/s'."""
    value, unit = self._Humanize(speed_in_bs, ('B', 'K', 'M', 'G', 'T'))
    return ' %6.1f%s/s' % (value, unit)

  def DurationToClock(self, duration):
    """Format a duration in seconds as ' MM:SS'."""
    return ' %02d:%02d' % (duration / 60, duration % 60)

  def SetProgress(self, percentage, size=None):
    """Redraw the bar.

    Args:
      percentage: completion in percent. The drawn bar is clamped to the bar
        width, but the number is printed as given.
      size: bytes transferred so far; the previous value is kept if None.
    """
    current_width = GetTerminalSize()[1]
    if self._width != current_width:
      self.CalculateSize()

    if size is not None:
      self._size = size

    elapse_time = time.time() - self._start_time
    # The first call happens right after construction; never divide by a zero
    # interval.
    speed = self._size / elapse_time if elapse_time > 0 else 0.0

    if len(self._name) <= self._name_max:
      name = self._name
    else:
      name = self._name[:self._name_max - 4] + ' ...'

    line = ('%*s' % (-self._name_max, name) + self.SizeToHuman(self._size) +
            self.SpeedToHuman(speed) + self.DurationToClock(elapse_time))
    if self._max > 2:
      width = max(0, min(self._max, int(self._max * percentage / 100.0)))
      line += (' [' + '#' * width + ' ' * (self._max - width) + ']' +
               '%4d%%' % int(percentage))
    sys.stdout.write(line + '\r')
    sys.stdout.flush()

  def End(self):
    """Draw the bar at 100% and move to a new line."""
    self.SetProgress(100.0)
    sys.stdout.write('\n')
    sys.stdout.flush()
+
+
class DaemonState(object):
  """Overlord connection state kept by the client daemon.

  The daemon hands this object to CLI processes over JSON-RPC (see the
  daemon's State() method), so attributes should stay JSON serializable.
  """
  def __init__(self):
    # Digest of the script that created this state, compared by the CLI to
    # detect an out-of-date daemon.
    self.version_sha1sum = GetVersionDigest()

    # Overlord server endpoint and HTTP basic-auth credentials.
    self.host = None
    self.port = None
    self.ssl = False
    self.username = None
    self.password = None

    # SSH tunnelling: user-supplied host and pid of the ssh master process.
    # `ssh` itself is not read anywhere in this file.
    self.ssh = False
    self.orig_host = None
    self.ssh_pid = None

    # Currently selected client machine id.
    self.selected_mid = None

    # Port forwards: local port -> (mid, remote port, forwarder pid).
    self.forwards = {}

    # Cached client listing and the time.time() it was fetched.
    self.listing = []
    self.last_list = 0
+
+
class OverlordClientDaemon(object):
  """Overlord Client Daemon.

  Background process holding a DaemonState between `ovl` invocations. The
  state and its operations are exposed through a JSON-RPC server on
  _OVERLORD_CLIENT_DAEMON_RPC_ADDR.
  """
  def __init__(self):
    self._state = DaemonState()
    self._server = None

  def Start(self):
    self.StartRPCServer()

  def StartRPCServer(self):
    """Bind the JSON-RPC server and serve it from a forked child process.

    The socket is bound before forking, so the address is already listening
    when this method returns in the parent. Only the child serves requests.
    """
    self._server = SimpleJSONRPCServer(_OVERLORD_CLIENT_DAEMON_RPC_ADDR,
                                       logRequests=False)
    exports = [
        (self.State, 'State'),
        (self.Ping, 'Ping'),
        (self.GetPid, 'GetPid'),
        (self.Connect, 'Connect'),
        (self.Clients, 'Clients'),
        (self.SelectClient, 'SelectClient'),
        (self.AddForward, 'AddForward'),
        (self.RemoveForward, 'RemoveForward'),
        (self.RemoveAllForward, 'RemoveAllForward'),
    ]
    for func, name in exports:
      self._server.register_function(func, name)

    pid = os.fork()
    if pid == 0:
      # Start a new session so signals aimed at the invoking terminal's
      # foreground job (e.g. Ctrl-C) do not reach the daemon.
      os.setsid()
      exit_code = 0
      try:
        self._server.serve_forever()
      except Exception as e:  # pylint: disable=W0703
        logging.exception(e)
        exit_code = 1
      finally:
        # Never fall through into the caller's (CLI) code in the child.
        os._exit(exit_code)  # pylint: disable=W0212
    else:
      # The parent only talks to the daemon as an RPC client.
      self._server.server_close()

  @staticmethod
  def GetRPCServer():
    """Returns the Overlord client daemon RPC server, or None if unreachable."""
    server = jsonrpclib.Server('http://%s:%d' %
                               _OVERLORD_CLIENT_DAEMON_RPC_ADDR)
    try:
      server.Ping()
    except Exception:
      return None
    return server

  def State(self):
    return self._state

  def Ping(self):
    return True

  def GetPid(self):
    return os.getpid()

  def _UrlOpen(self, url):
    """Wrapper for urllib2.urlopen.

    It selects the correct HTTP scheme according to self._state.ssl and adds
    HTTP basic auth headers.
    """
    url = MakeRequestUrl(self._state, url)
    request = urllib2.Request(url)
    if self._state.username is not None and self._state.password is not None:
      request.add_header(*BasicAuthHeader(self._state.username,
                                          self._state.password))
    return urllib2.urlopen(request)

  def _GetJSON(self, path):
    url = '%s:%d%s' % (self._state.host, self._state.port, path)
    return json.loads(self._UrlOpen(url).read())

  def Connect(self, host, port=_OVERLORD_HTTP_PORT, ssh_pid=None,
              username=None, password=None, orig_host=None):
    """Point the daemon at an Overlord server and probe it.

    Plain HTTP is tried first. HTTPS is used when the HTTP response does not
    contain _OVERLORD_RESPONSE_KEYWORD, or when the HTTP request fails below
    the HTTP-status level (e.g. an HTTPS-only server dropping the request).
    On failure `host` is reset so the state does not claim a connection.

    Returns:
      True on success, the HTTP status code (int) on an HTTP error, or an
      error message (str) for any other failure.
    """
    self._state.username = username
    self._state.password = password
    self._state.host = host
    self._state.port = port
    self._state.ssl = False
    self._state.orig_host = orig_host
    self._state.ssh_pid = ssh_pid
    self._state.selected_mid = None
    # Never serve a client listing cached from a previous server.
    self._state.listing = []
    self._state.last_list = 0

    address = '%s:%d' % (host, port)
    try:
      try:
        use_ssl = _OVERLORD_RESPONSE_KEYWORD not in self._UrlOpen(address).read()
      except urllib2.HTTPError:
        raise
      except Exception:  # pylint: disable=W0703
        use_ssl = True
      if use_ssl:
        self._state.ssl = True
        self._UrlOpen(address)
    except urllib2.HTTPError as e:
      logging.exception(e)
      self._state.host = None
      return e.getcode()
    except Exception as e:
      logging.exception(e)
      self._state.host = None
      return str(e)
    return True

  def Clients(self):
    """Return sorted unique client mids, cached for _LIST_CACHE_TIMEOUT secs."""
    if time.time() - self._state.last_list <= _LIST_CACHE_TIMEOUT:
      return self._state.listing

    mids = [client['mid'] for client in self._GetJSON('/api/agents/list')]
    self._state.listing = sorted(list(set(mids)))
    self._state.last_list = time.time()
    return self._state.listing

  def SelectClient(self, mid):
    self._state.selected_mid = mid

  def AddForward(self, mid, remote, local, pid):
    """Record that forwarder process `pid` maps `local` to mid's `remote`."""
    self._state.forwards[local] = (mid, remote, pid)

  def RemoveForward(self, local_port):
    """Kill and forget the forwarder on local_port; unknown ports are ignored."""
    try:
      unused_mid, unused_remote, pid = self._state.forwards[local_port]
      KillGraceful(pid)
      del self._state.forwards[local_port]
    except (KeyError, OSError):
      pass

  def RemoveAllForward(self):
    """Kill every recorded forwarder and clear the table."""
    for unused_mid, unused_remote, pid in self._state.forwards.values():
      try:
        KillGraceful(pid)
      except OSError:
        pass
    self._state.forwards = {}
+
+
class TerminalWebSocketClient(WebSocketBaseClient):
  """Interactive terminal session attached to a client's remote TTY.

  Once opened, local stdin is switched to raw, non-blocking mode and relayed
  to the server from a background thread. Binary messages from the server
  are written to stdout. Terminal size changes are sent as in-band JSON
  control messages framed by _CONTROL_START/_CONTROL_END. Typing Enter, then
  _ESCAPE, then '.' ends the session.
  """
  def __init__(self, mid, *args, **kwargs):
    super(TerminalWebSocketClient, self).__init__(*args, **kwargs)
    self._mid = mid
    self._stdin_fd = sys.stdin.fileno()
    # Terminal attributes and file status flags of stdin before the input
    # thread changed them. They stay None until that thread saves them.
    self._old_termios = None
    self._old_stdin_flags = None

  def handshake_ok(self):
    pass

  def opened(self):
    # Last terminal size sent to the server. None guarantees the first poll
    # always sends the real size.
    nonlocals = {'size': None}

    def _ResizeWindow():
      size = GetTerminalSize()
      if size == nonlocals['size']:  # Size not changed, ignore
        return
      control = {'command': 'resize', 'params': list(size)}
      payload = chr(_CONTROL_START) + json.dumps(control) + chr(_CONTROL_END)
      nonlocals['size'] = size
      try:
        self.send(payload, binary=True)
      except Exception:
        pass

    def _FeedInput():
      flags = fcntl.fcntl(self._stdin_fd, fcntl.F_GETFL)
      self._old_stdin_flags = flags
      fcntl.fcntl(self._stdin_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)

      self._old_termios = termios.tcgetattr(self._stdin_fd)
      tty.setraw(self._stdin_fd)

      READY, ENTER_PRESSED, ESCAPE_PRESSED = range(3)

      try:
        state = READY
        while True:
          rd, unused_w, unused_x = select.select([sys.stdin], [], [], 0.5)

          # We can't install a signal handler in the main thread since it'll
          # interrupt the read/write system call (ws4py performing send/recv).
          # Use polling instead (select's timeout is 0.5 seconds)
          _ResizeWindow()

          if sys.stdin in rd:
            data = sys.stdin.read()

            # Scan for the Enter, _ESCAPE, '.' escape sequence.
            for x in data:
              if state == READY:
                state = ENTER_PRESSED if x == chr(0x0d) else READY
              elif state == ENTER_PRESSED:
                state = ESCAPE_PRESSED if x == _ESCAPE else READY
              elif state == ESCAPE_PRESSED:
                if x == '.':
                  self.close()
                  raise RuntimeError('quit')
                else:
                  state = READY

            self.send(data)
      except (KeyboardInterrupt, RuntimeError):
        pass

    t = threading.Thread(target=_FeedInput)
    t.daemon = True
    t.start()

  def closed(self, code, reason=None):
    # The input thread may not have modified the terminal yet.
    if self._old_termios is not None:
      termios.tcsetattr(self._stdin_fd, termios.TCSANOW, self._old_termios)
    if self._old_stdin_flags is not None:
      # O_NONBLOCK lives on the open file description shared with the
      # invoking shell; restore it so the shell's stdin is blocking again.
      fcntl.fcntl(self._stdin_fd, fcntl.F_SETFL, self._old_stdin_flags)
    print('Connection to %s closed.' % self._mid)

  def received_message(self, msg):
    if msg.is_binary:
      sys.stdout.write(msg.data)
      sys.stdout.flush()
+
+
class ShellWebSocketClient(WebSocketBaseClient):
  """WebSocket client for a one-shot remote shell command.

  Binary frames received from the server are copied verbatim to `output`.
  Other frames are ignored.
  """
  def __init__(self, output, *args, **kwargs):
    """Constructor.

    Args:
      output: file-like object (needs write() and flush()) receiving the
        command's output.
    """
    self.output = output
    super(ShellWebSocketClient, self).__init__(*args, **kwargs)

  def handshake_ok(self):
    pass

  def opened(self):
    pass

  def closed(self, code, reason=None):
    pass

  def received_message(self, msg):
    if not msg.is_binary:
      return
    self.output.write(msg.data)
    self.output.flush()
+
+
class ForwarderWebSocketClient(WebSocketBaseClient):
  """Relays one local TCP connection over a port-forwarding WebSocket."""
  def __init__(self, sock, *args, **kwargs):
    super(ForwarderWebSocketClient, self).__init__(*args, **kwargs)
    self._sock = sock
    self._stop = threading.Event()

  def handshake_ok(self):
    pass

  def opened(self):
    def _FeedInput():
      # The socket stays in blocking mode. select() below guarantees recv()
      # will not block, and received_message() relies on sendall() delivering
      # every byte (a non-blocking send can write partially or raise EAGAIN).
      try:
        while True:
          rd, unused_w, unused_x = select.select([self._sock], [], [], 0.5)
          if self._stop.is_set():
            break
          if self._sock in rd:
            data = self._sock.recv(_BUFSIZ)
            if not data:
              break
            self.send(data, binary=True)
      except Exception:
        pass
      finally:
        self._sock.close()
        self.close()

    t = threading.Thread(target=_FeedInput)
    t.daemon = True
    t.start()

  def closed(self, code, reason=None):
    self._stop.set()
    # Presumably invoked on the thread running ws.run() (a per-connection
    # thread in Forward); SystemExit ends just that thread.
    sys.exit(0)

  def received_message(self, msg):
    if msg.is_binary:
      self._sock.sendall(msg.data)
+
+
def Arg(*args, **kwargs):
  """Capture arguments for a later argparse `add_argument` call.

  Returns:
    A (positional_args_tuple, keyword_args_dict) pair.
  """
  packed = (args, kwargs)
  return packed
+
+
def Command(command, help_msg=None, args=None):
  """Decorator for adding argparse sub-command metadata to a method.

  The metadata dict ({'command', 'help', 'args'}) is stored on the function
  under the attribute name '__arg_attr' and collected later by
  ParseMethodSubCommands. The function itself is returned, so its name and
  docstring are preserved.

  Args:
    command: sub-command name, e.g. 'push'.
    help_msg: help text for the sub-command.
    args: list of (args, kwargs) pairs (see Arg()), each passed to
      parser.add_argument(). Defaults to no arguments.
  """
  if args is None:
    args = []

  def WrapFunc(func):
    # setattr avoids private-name mangling should this ever run in a class.
    setattr(func, '__arg_attr',
            {'command': command, 'help': help_msg, 'args': args})
    return func
  return WrapFunc
+
+
def ParseMethodSubCommands(cls):
  """Decorator for a class using the @Command decorator.

  Collects the '__arg_attr' metadata of every member of `cls` and appends it
  to the class's SUBCOMMANDS list, which is later used to construct the
  parser. Entries are sorted by command name so the help output is stable.
  """
  commands = [getattr(member, '__arg_attr')
              for member in vars(cls).values()
              if hasattr(member, '__arg_attr')]
  cls.SUBCOMMANDS.extend(sorted(commands, key=lambda attr: attr['command']))
  return cls
+
+
@ParseMethodSubCommands
class OverlordCLIClient(object):
  """Overlord command line interface client.

  Each method decorated with @Command implements one `ovl` sub-command.
  Connection state lives in the client daemon and is fetched over JSON-RPC
  by StartDaemon().
  """

  SUBCOMMANDS = []

  def __init__(self):
    self._parser = self._BuildParser()
    self._selected_mid = None
    self._server = None
    self._state = None

  @staticmethod
  def _ShellQuote(s):
    """Quote s as a single word for a POSIX shell."""
    return "'" + s.replace("'", "'\\''") + "'"

  def _GetAuthHeaders(self):
    """Return HTTP basic-auth headers for WebSocket requests, if configured."""
    if self._state.username is not None and self._state.password is not None:
      return [BasicAuthHeader(self._state.username, self._state.password)]
    return []

  def _WSUrl(self, path):
    """Return the ws:// (wss:// when using SSL) URL of path on the server."""
    return 'ws%s://%s:%d%s' % ('s' if self._state.ssl else '',
                                self._state.host, self._state.port, path)

  def _BuildParser(self):
    root_parser = argparse.ArgumentParser(prog='ovl')
    subparsers = root_parser.add_subparsers(help='sub-command')

    root_parser.add_argument('-s', dest='selected_mid', action='store',
                             default=None,
                             help='select target to execute command on')
    root_parser.add_argument('-S', dest='select_mid_before_action',
                             action='store_true', default=False,
                             help='select target before executing command')

    for attr in self.SUBCOMMANDS:
      parser = subparsers.add_parser(attr['command'], help=attr['help'])
      parser.set_defaults(which=attr['command'])
      for arg in attr['args']:
        parser.add_argument(*arg[0], **arg[1])

    return root_parser

  def Main(self):
    # We want to pass the rest of arguments after shell command directly to the
    # function without parsing it.
    try:
      index = sys.argv.index('shell')
    except ValueError:
      args = self._parser.parse_args()
    else:
      args = self._parser.parse_args(sys.argv[1:index + 1])

    command = args.which
    self._selected_mid = args.selected_mid

    if command == 'kill-server':
      self.KillServer()
      return

    self.StartDaemon()
    if command == 'status':
      self.Status()
      return
    elif command == 'connect':
      self.Connect(args)
      return

    # The following command requires connection to the server
    self.CheckConnection()

    if args.select_mid_before_action:
      self.SelectClient(store=False)

    if command == 'select':
      self.SelectClient(args)
    elif command == 'ls':
      self.ListClients()
    elif command == 'shell':
      command = sys.argv[sys.argv.index('shell') + 1:]
      self.Shell(command)
    elif command == 'push':
      self.Push(args)
    elif command == 'pull':
      self.Pull(args)
    elif command == 'forward':
      self.Forward(args)

  def _UrlOpen(self, url):
    """urllib2.urlopen with the right scheme and HTTP basic auth headers."""
    url = MakeRequestUrl(self._state, url)
    request = urllib2.Request(url)
    if self._state.username is not None and self._state.password is not None:
      request.add_header(*BasicAuthHeader(self._state.username,
                                          self._state.password))
    return urllib2.urlopen(request)

  def _HTTPPostFile(self, url, filename, progress=None, user=None, passwd=None):
    """Perform HTTP POST and upload file to Overlord.

    To minimize the external dependencies, we construct the multipart HTTP
    post request by ourselves and stream the file in _BUFSIZ chunks.

    Args:
      url: scheme-less URL ('host:port/path?query').
      filename: local file to upload.
      progress: optional callable progress(percentage, bytes_sent=None).
      user, passwd: HTTP basic-auth credentials, sent only if both are set.

    Returns:
      True if the server replied with HTTP 200.
    """
    url = MakeRequestUrl(self._state, url)
    size = os.stat(filename).st_size
    boundary = '-----------%s' % _HTTP_BOUNDARY_MAGIC
    CRLF = '\r\n'
    parse = urlparse.urlparse(url)

    part_headers = [
        '--' + boundary,
        'Content-Disposition: form-data; name="file"; '
        'filename="%s"' % os.path.basename(filename),
        'Content-Type: application/octet-stream',
        '', ''
    ]
    part_header = CRLF.join(part_headers)
    end_part = CRLF + '--' + boundary + '--' + CRLF

    content_length = len(part_header) + size + len(end_part)
    if parse.scheme == 'http':
      h = httplib.HTTP(parse.netloc)
    else:
      h = httplib.HTTPS(parse.netloc)

    post_path = url[url.index(parse.netloc) + len(parse.netloc):]
    h.putrequest('POST', post_path)
    h.putheader('Content-Length', str(content_length))
    h.putheader('Content-Type', 'multipart/form-data; boundary=%s' % boundary)

    if user and passwd:
      h.putheader(*BasicAuthHeader(user, passwd))
    h.endheaders()
    h.send(part_header)

    count = 0
    with open(filename, 'rb') as f:
      while True:
        data = f.read(_BUFSIZ)
        if not data:
          break
        count += len(data)
        if progress:
          # size can be 0 if the file grew after os.stat().
          progress(int(count * 100.0 / size) if size else 100, count)
        h.send(data)

    h.send(end_part)
    if progress:
      progress(100)

    if count != size:
      logging.warning('file changed during upload, upload may be truncated.')

    errcode, unused_x, unused_y = h.getreply()
    return errcode == 200

  def StartDaemon(self):
    """Connect to the client daemon, (re)starting it if absent or outdated.

    Raises:
      RuntimeError: the daemon could not be started.
    """
    self._server = OverlordClientDaemon.GetRPCServer()
    if self._server is None:
      print('* daemon not running, starting it now on port %d ... *' %
            _OVERLORD_CLIENT_DAEMON_PORT)
      OverlordClientDaemon().Start()
      time.sleep(1)
      self._server = OverlordClientDaemon.GetRPCServer()
      if self._server is None:
        raise RuntimeError('failed to start the ovl client daemon')
      print('* daemon started successfully *')

    self._state = self._server.State()
    sha1sum = GetVersionDigest()

    if sha1sum != self._state.version_sha1sum:
      print('ovl server is out of date. killing...')
      KillGraceful(self._server.GetPid())
      self.StartDaemon()

  def GetSSHControlFile(self, host):
    return _SSH_CONTROL_SOCKET_PREFIX + host

  def SSHTunnel(self, user, host, port, overlord_port=_OVERLORD_HTTP_PORT):
    """SSH forward the remote overlord server.

    Overlord server may not have its HTTP port open to the public network, in
    such case we can SSH forward the port to the same port on localhost.

    Args:
      user: SSH login name; empty for the ssh default.
      host: SSH server host.
      port: SSH server port.
      overlord_port: Overlord HTTP port to forward.

    Returns:
      pid of the ssh master process.

    Raises:
      RuntimeError: the tunnel could not be established.
    """
    control_file = self.GetSSHControlFile(host)
    try:
      os.unlink(control_file)
    except Exception:
      pass

    subprocess.Popen([
        'ssh', '-Nf',
        '-M',  # Enable master mode
        '-S', control_file,
        '-L', '%d:localhost:%d' % (overlord_port, overlord_port),
        '-p', str(port),
        '%s%s' % (user + '@' if user else '', host)
    ]).wait()

    p = subprocess.Popen([
        'ssh',
        '-S', control_file,
        '-O', 'check', host,
    ], stderr=subprocess.PIPE)
    unused_stdout, stderr = p.communicate()

    s = re.search(r'pid=(\d+)', stderr)
    if s:
      return int(s.group(1))

    raise RuntimeError('can not establish ssh connection')

  def CheckConnection(self):
    """Raise RuntimeError unless the server (and SSH tunnel, if any) is up."""
    if self._state.host is None:
      raise RuntimeError('not connected to any server, abort')

    try:
      self._server.Clients()
    except Exception:
      raise RuntimeError('remote server disconnected, abort')

    if self._state.ssh_pid is not None:
      try:
        os.kill(self._state.ssh_pid, 0)  # Existence check, sends nothing.
      except OSError:
        raise RuntimeError('ssh tunnel disconnected, please re-connect')

  def CheckClient(self):
    """Resolve the target mid (-s or stored selection) and verify it's live."""
    if self._selected_mid is None:
      if self._state.selected_mid is None:
        raise RuntimeError('No client is selected')
      self._selected_mid = self._state.selected_mid

    if self._selected_mid not in self._server.Clients():
      raise RuntimeError('client %s disappeared' % self._selected_mid)

  def CheckOutput(self, command):
    """Run a shell command on the selected client and return its output."""
    sio = StringIO.StringIO()
    ws = ShellWebSocketClient(
        sio,
        self._WSUrl('/api/agent/shell/%s?command=%s' %
                    (self._selected_mid, urllib2.quote(command))),
        headers=self._GetAuthHeaders())
    ws.connect()
    ws.run()
    return sio.getvalue()

  @Command('status', 'show Overlord connection status')
  def Status(self):
    if self._state.host is None:
      print('Not connected to any host.')
    else:
      if self._state.ssh_pid is not None:
        print('Connected to %s with SSH tunneling.' % self._state.orig_host)
      else:
        print('Connected to %s:%d.' % (self._state.host, self._state.port))

    if self._selected_mid is None:
      self._selected_mid = self._state.selected_mid

    if self._selected_mid is None:
      print('No client is selected.')
    else:
      print('Client %s selected.' % self._selected_mid)

  @Command('connect', 'connect to Overlord server', [
      Arg('host', metavar='HOST', type=str, default='localhost',
          help='Overlord hostname/IP'),
      Arg('port', metavar='PORT', type=int,
          default=_OVERLORD_HTTP_PORT, help='Overlord port'),
      Arg('-f', '--forward', dest='ssh_forward', default=False,
          action='store_true',
          help='connect with SSH forwarding to the host'),
      Arg('-p', '--ssh-port', dest='ssh_port', default=22,
          type=int, help='SSH server port for SSH forwarding'),
      Arg('-l', '--ssh-login', dest='ssh_login', default='',
          type=str, help='SSH server login name for SSH forwarding'),
      Arg('-u', '--user', dest='user', default=None,
          type=str, help='Overlord HTTP auth username'),
      Arg('-w', '--passwd', dest='passwd', default=None, type=str,
          help='Overlord HTTP auth password')])
  def Connect(self, args):
    ssh_pid = None
    host = args.host
    orig_host = args.host

    if args.ssh_forward:
      # Kill previous SSH tunnel
      self.KillSSHTunnel()

      # Forward the requested Overlord port, which is then used on localhost.
      ssh_pid = self.SSHTunnel(args.ssh_login, args.host, args.ssh_port,
                               args.port)
      host = 'localhost'

    status = self._server.Connect(host, args.port, ssh_pid, args.user,
                                  args.passwd, orig_host)
    if status is not True:
      if isinstance(status, int):
        if status == 401:
          msg = '401 Unauthorized'
        else:
          msg = 'HTTP %d' % status
      else:
        msg = status
      print('can not connect to %s: %s' % (host, msg))

  @Command('kill-server', 'kill overlord CLI client server')
  def KillServer(self):
    self._server = OverlordClientDaemon.GetRPCServer()
    if self._server is None:
      return

    # Main() calls this before StartDaemon(), so fetch the state here.
    self._state = self._server.State()

    # Kill the SSH tunnel and port forwarders: their pids are only tracked by
    # the daemon, so they can't be managed once it is gone.
    self.KillSSHTunnel()
    self._server.RemoveAllForward()

    # Kill server daemon
    KillGraceful(self._server.GetPid())

  def KillSSHTunnel(self):
    if self._state is not None and self._state.ssh_pid is not None:
      KillGraceful(self._state.ssh_pid)

  @Command('ls', 'list all clients')
  def ListClients(self):
    for client in self._server.Clients():
      print(client)

  @Command('select', 'select default client', [
      Arg('mid', metavar='mid', nargs='?', default=None)])
  def SelectClient(self, args=None, store=True):
    clients = self._server.Clients()

    mid = args.mid if args is not None else None
    if mid is None:
      print('Select from the following clients:')
      for i, client in enumerate(clients):
        print(' %d. %s' % (i + 1, client))

      print('\nSelection: ', end='')
      try:
        choice = int(raw_input()) - 1
      except ValueError:
        raise RuntimeError('select: invalid selection')
      # Reject 0 and negatives, which would otherwise index from the end.
      if not 0 <= choice < len(clients):
        raise RuntimeError('select: selection out of range')
      mid = clients[choice]
    else:
      if mid not in clients:
        raise RuntimeError('select: client %s does not exist' % mid)

    self._selected_mid = mid
    if store:
      self._server.SelectClient(mid)
      print('Client %s selected' % mid)

  @Command('shell', 'open a shell or execute a shell command', [
      Arg('command', metavar='CMD', nargs='?', help='command to execute')])
  def Shell(self, command=None):
    if command is None:
      command = []
    self.CheckClient()

    if not command:
      ws = TerminalWebSocketClient(
          self._selected_mid,
          self._WSUrl('/api/agent/tty/%s' % self._selected_mid),
          headers=self._GetAuthHeaders())
    else:
      cmd = ' '.join(command)
      ws = ShellWebSocketClient(
          sys.stdout,
          self._WSUrl('/api/agent/shell/%s?command=%s' %
                      (self._selected_mid, urllib2.quote(cmd))),
          headers=self._GetAuthHeaders())
    ws.connect()
    ws.run()

  @Command('push', 'push a file or directory to remote', [
      Arg('src', metavar='SOURCE'),
      Arg('dst', metavar='DESTINATION')])
  def Push(self, args):
    self.CheckClient()

    if not os.path.exists(args.src):
      raise RuntimeError('push: can not stat "%s": no such file or directory'
                         % args.src)

    if not os.access(args.src, os.R_OK):
      raise RuntimeError('push: can not open "%s" for reading' % args.src)

    def _push(src, dst):
      src_base = os.path.basename(src)
      mode = '0%o' % (0x1FF & os.stat(src).st_mode)
      # Quote query values so paths with '&', '+', spaces, etc. arrive intact.
      url = ('%s:%d/api/agent/upload/%s?dest=%s&perm=%s' %
             (self._state.host, self._state.port, self._selected_mid,
              urllib2.quote(dst), mode))
      try:
        self._UrlOpen(url + '&filename=%s' % urllib2.quote(src_base))
      except urllib2.HTTPError as e:
        msg = json.loads(e.read()).get('error', None)
        raise RuntimeError('push: %s' % msg)

      pbar = ProgressBar(src_base)
      uploaded = self._HTTPPostFile(url, src, pbar.SetProgress,
                                    self._state.username, self._state.password)
      pbar.End()
      if not uploaded:
        raise RuntimeError('push: failed to upload "%s"' % src)

    if os.path.isdir(args.src):
      dst_exists = self.CheckOutput(
          'stat %s >/dev/null 2>&1 && echo True || echo False' %
          self._ShellQuote(args.dst)).strip() == 'True'
      src_root = os.path.normpath(args.src)
      for root, unused_x, files in os.walk(args.src):
        # If destination directory does not exist, we should strip the first
        # layer of directory. For example: src_dir contains a single file 'A'
        #
        # push src_dir dest_dir
        #
        # If dest_dir exists, the resulting directory structure should be:
        # dest_dir/src_dir/A
        # If dest_dir does not exist, the resulting directory structure should
        # be:
        # dest_dir/A
        #
        # Only the last component of SOURCE is kept, so 'a/b/src_dir' behaves
        # like 'src_dir'.
        rel_root = os.path.relpath(root, src_root)
        if rel_root == os.curdir:
          rel_root = ''
        if dst_exists:
          rel_root = os.path.join(os.path.basename(src_root), rel_root)
        for name in files:
          _push(os.path.join(root, name),
                os.path.join(args.dst, rel_root, name))
    else:
      _push(args.src, args.dst)

  @Command('pull', 'pull a file or directory from remote', [
      Arg('src', metavar='SOURCE'),
      Arg('dst', metavar='DESTINATION', default='.', nargs='?')])
  def Pull(self, args):
    self.CheckClient()

    def _pull(src, dst, perm=0o644):
      url = ('%s:%d/api/agent/download/%s?filename=%s' %
             (self._state.host, self._state.port, self._selected_mid,
              urllib2.quote(src)))
      try:
        h = self._UrlOpen(url)
      except urllib2.HTTPError as e:
        msg = json.loads(e.read()).get('error', 'unknown error')
        raise RuntimeError('pull: %s' % msg)
      except KeyboardInterrupt:
        return

      dst_dir = os.path.dirname(dst)
      if dst_dir and not os.path.isdir(dst_dir):
        os.makedirs(dst_dir)

      pbar = ProgressBar(os.path.basename(src))
      try:
        with open(dst, 'wb') as f:
          os.fchmod(f.fileno(), perm)
          # Content-Length may be absent or 0 (empty file); never divide by 0.
          total_size = int(h.headers.get('Content-Length') or 0)
          downloaded_size = 0
          while True:
            data = h.read(_BUFSIZ)
            downloaded_size += len(data)
            if total_size:
              percentage = float(downloaded_size) * 100 / total_size
            else:
              percentage = 0.0
            pbar.SetProgress(percentage, downloaded_size)
            if not data:
              break
            f.write(data)
      finally:
        h.close()
      pbar.End()

    # Use find to get a listing of all files under a root directory. The
    # '-printf' format emits each file's octal mode and path, tab separated.
    output = self.CheckOutput(
        'cd $HOME; '
        'stat %(src)s >/dev/null && '
        'find %(src)s \'(\' -type f -o -type l \')\' -printf \'%%m\t%%p\n\''
        % {'src': self._ShellQuote(args.src)})

    # We got error from the stat command
    if output.startswith('stat: '):
      sys.stderr.write(output)
      return

    entries = []
    for line in output.split('\n'):
      if not line:
        continue
      if '\t' not in line:
        # Not a listing line (e.g. a find error); report it and move on.
        sys.stderr.write(line + '\n')
        continue
      perm, src_path = line.split('\t', 1)
      entries.append((int(perm, base=8), src_path))

    if not entries:
      # e.g. an empty directory: nothing to transfer.
      return

    # find prints SOURCE itself only when SOURCE is a file or symlink; a
    # directory with a single file must still keep its structure.
    if len(entries) == 1 and entries[0][1] == args.src:
      perm, src_path = entries[0]
      if os.path.isdir(args.dst):
        dst = os.path.join(args.dst, os.path.basename(src_path))
      else:
        dst = args.dst
      _pull(src_path, dst, perm)
      return

    common_prefix = os.path.dirname(args.src)
    if not os.path.exists(args.dst):
      common_prefix = args.src

    for perm, src_path in entries:
      rel_dst = src_path[len(common_prefix):].lstrip('/')
      _pull(src_path, os.path.join(args.dst, rel_dst), perm)

  @Command('forward', 'forward remote port to local port', [
      Arg('--list', dest='list_all', action='store_true', default=False,
          help='list all port forwarding sessions'),
      Arg('--remove', metavar='LOCAL_PORT', dest='remove', type=int,
          default=None,
          help='remove port forwarding for local port LOCAL_PORT'),
      Arg('--remove-all', dest='remove_all', action='store_true',
          default=False, help='remove all port forwarding'),
      Arg('remote', metavar='REMOTE_PORT', type=int, nargs='?'),
      Arg('local', metavar='LOCAL_PORT', type=int, nargs='?')])
  def Forward(self, args):
    if args.list_all:
      max_len = 10
      if len(self._state.forwards):
        max_len = max([len(v[0]) for v in self._state.forwards.values()])

      print('%-*s %-8s %-8s' % (max_len, 'Client', 'Remote', 'Local'))
      for local in sorted(self._state.forwards.keys()):
        value = self._state.forwards[local]
        print('%-*s %-8s %-8s' % (max_len, value[0], value[1], local))
      return

    if args.remove_all:
      self._server.RemoveAllForward()
      return

    if args.remove is not None:
      self._server.RemoveForward(args.remove)
      return

    self.CheckClient()

    if args.remote is None:
      raise RuntimeError('forward: REMOTE_PORT is required')
    if args.local is None:
      args.local = args.remote
    remote = int(args.remote)
    local = int(args.local)

    def HandleConnection(conn):
      ws = ForwarderWebSocketClient(
          conn,
          self._WSUrl('/api/agent/forward/%s?port=%d' %
                      (self._selected_mid, remote)),
          headers=self._GetAuthHeaders())
      try:
        ws.connect()
        ws.run()
      except Exception as e:
        print('error: %s' % e)
      finally:
        ws.close()

    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    server.bind(('0.0.0.0', local))
    server.listen(5)

    pid = os.fork()
    if pid == 0:
      # New session: the forwarder outlives this command and should not get
      # signals sent to the invoking terminal's foreground job.
      os.setsid()
      while True:
        conn, unused_addr = server.accept()
        t = threading.Thread(target=HandleConnection, args=(conn,))
        t.daemon = True
        t.start()
    else:
      # Only the child accepts connections.
      server.close()
      self._server.AddForward(self._selected_mid, remote, local, pid)
+
+
def main():
  """Entry point of the ovl CLI.

  Errors are reported on stdout as before. The process now exits with
  status 1 on failure or Ctrl-C, so scripts can detect it.
  """
  logging.basicConfig(level=logging.INFO)

  # Add DaemonState to JSONRPC lib classes so it can cross the JSON-RPC
  # boundary between the daemon and the CLI.
  Config.instance().classes.add(DaemonState)

  ovl = OverlordCLIClient()
  try:
    ovl.Main()
  except KeyboardInterrupt:
    print('Ctrl-C received, abort')
    sys.exit(1)
  except Exception as e:
    print('error: %s' % e)
    sys.exit(1)


if __name__ == '__main__':
  main()