overlord: implement ADB-like command line interface

Implement an ADB-like CLI client for Overlord. The client supports all
Overlord functions, including shell, file upload/download, port
forwarding, etc.

BUG=chromium:517520
TEST=manually test

Change-Id: I7211ead781ecf8f59acc3e8f2e4c07c898a858b8
Reviewed-on: https://chromium-review.googlesource.com/291940
Commit-Ready: Wei-Ning Huang <wnhuang@chromium.org>
Tested-by: Wei-Ning Huang <wnhuang@chromium.org>
Reviewed-by: Hung-Te Lin <hungte@chromium.org>
diff --git a/py/tools/ovl.py b/py/tools/ovl.py
new file mode 100755
index 0000000..c44620c
--- /dev/null
+++ b/py/tools/ovl.py
@@ -0,0 +1,1098 @@
+#!/usr/bin/python -u
+# -*- coding: utf-8 -*-
+#
+# Copyright 2015 The Chromium OS Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+from __future__ import print_function
+
+import argparse
+import ast
+import base64
+import fcntl
+import hashlib
+import httplib
+import json
+import jsonrpclib
+import logging
+import os
+import re
+import select
+import signal
+import socket
+import StringIO
+import struct
+import subprocess
+import sys
+import tempfile
+import termios
+import threading
+import time
+import tty
+import urllib2
+import urlparse
+
+from jsonrpclib.SimpleJSONRPCServer import SimpleJSONRPCServer
+from jsonrpclib.config import Config
+from ws4py.client import WebSocketBaseClient
+
+# Python version >= 2.7.9 enables SSL check by default, bypass it.
+# NOTE(review): this replaces ssl's default HTTPS context factory with an
+# unverified one, so HTTPS connections relying on that default (presumably
+# the httplib/urllib2 requests below) skip certificate verification. The
+# broad except keeps older Pythons working, where `ssl` may lack
+# _create_unverified_context.
+try:
+  import ssl
+  # pylint: disable=W0212
+  ssl._create_default_https_context = ssl._create_unverified_context
+except Exception:
+  pass
+
+
+# Escape character of the interactive terminal: <Enter> ~ . closes it.
+_ESCAPE = '~'
+# Chunk size for socket reads and file reads/writes.
+_BUFSIZ = 8192
+# NOTE(review): not referenced anywhere in this file.
+_OVERLORD_PORT = 4455
+# Default Overlord HTTP port (connect default and SSH forwarding port).
+_OVERLORD_HTTP_PORT = 9000
+# Local address of the client daemon's JSON-RPC server.
+_OVERLORD_CLIENT_DAEMON_PORT = 4488
+_OVERLORD_CLIENT_DAEMON_RPC_ADDR = ('127.0.0.1', _OVERLORD_CLIENT_DAEMON_PORT)
+
+# Seconds during which the daemon reuses its cached client listing.
+_LIST_CACHE_TIMEOUT = 2
+# Fallback width used when the terminal reports 0 columns.
+_DEFAULT_TERMINAL_WIDTH = 80
+
+# echo -n overlord | md5sum
+# Suffix of the multipart/form-data boundary used for file uploads.
+_HTTP_BOUNDARY_MAGIC = '9246f080c855a69012707ab53489b921'
+
+# Bytes framing a JSON control message (e.g. terminal resize) sent over the
+# terminal WebSocket.
+_CONTROL_START = 128
+_CONTROL_END = 129
+# Path prefix of the per-host ssh ControlMaster socket.
+_SSH_CONTROL_SOCKET_PREFIX = os.path.join(tempfile.gettempdir(),
+                                          'ovl-ssh-control-')
+
+# A string that will always be included in the response of
+# GET http://OVERLORD_SERVER:_OVERLORD_HTTP_PORT
+# (used to tell an HTTP server from an HTTPS one when connecting).
+_OVERLORD_RESPONSE_KEYWORD = '<html>'
+
+
+def GetVersionDigest():
+  """Compute the SHA-1 hex digest of this script's source.
+
+  The daemon stores this digest when it starts, so a client running a
+  different copy of the script can tell that the daemon is stale.
+  """
+  digest = hashlib.sha1()
+  with open(__file__, 'r') as script:
+    digest.update(script.read())
+  return digest.hexdigest()
+
+
+def KillGraceful(pid, wait_secs=1):
+  """Terminate a process, escalating to SIGKILL.
+
+  SIGTERM is sent first; after `wait_secs` seconds SIGKILL follows to make
+  sure the process is gone. OSError (e.g. the process no longer exists) is
+  ignored at either step.
+  """
+  try:
+    os.kill(pid, signal.SIGTERM)
+  except OSError:
+    return
+  time.sleep(wait_secs)
+  try:
+    os.kill(pid, signal.SIGKILL)
+  except OSError:
+    pass
+
+
+def BasicAuthHeader(user, password):
+  """Build an HTTP Basic authentication header.
+
+  Returns:
+    A (header_name, header_value) tuple usable with add_header/putheader.
+  """
+  token = base64.b64encode('%s:%s' % (user, password))
+  return 'Authorization', 'Basic ' + token
+
+
+def GetTerminalSize():
+  """Query the window size of the terminal on file descriptor 0.
+
+  Returns:
+    A (rows, columns) tuple as reported by the TIOCGWINSZ ioctl.
+  """
+  request = struct.pack('HHHH', 0, 0, 0, 0)
+  reply = fcntl.ioctl(0, termios.TIOCGWINSZ, request)
+  rows, cols = struct.unpack('HHHH', reply)[:2]
+  return rows, cols
+
+
+def MakeRequestUrl(state, url):
+  """Prefix `url` (host:port/path) with http:// or https:// per state.ssl."""
+  scheme = 'https' if state.ssl else 'http'
+  return scheme + '://' + url
+
+
+class ProgressBar(object):
+  """Single-line console progress bar for file transfers.
+
+  Every update rewrites the current line (ending in a carriage return) with
+  the name, transferred size, average speed, elapsed time and, when the
+  terminal is wide enough, a bar with a percentage.
+  """
+  SIZE_WIDTH = 11
+  SPEED_WIDTH = 10
+  DURATION_WIDTH = 6
+  PERCENTAGE_WIDTH = 8
+
+  def __init__(self, name):
+    self._start_time = time.time()
+    self._name = name
+    self._size = 0
+    self._width = 0
+    self._name_width = 0
+    self._name_max = 0
+    self._stat_width = 0
+    self._max = 0
+    self.CalculateSize()
+    self.SetProgress(0)
+
+  @staticmethod
+  def _TerminalWidth():
+    """Return the terminal width, or _DEFAULT_TERMINAL_WIDTH if unknown.
+
+    GetTerminalSize() fails when stdin is not a terminal (e.g. redirected
+    input) and may report 0 columns; neither should abort a transfer.
+    """
+    try:
+      columns = GetTerminalSize()[1]
+    except (IOError, OSError):
+      columns = 0
+    return columns or _DEFAULT_TERMINAL_WIDTH
+
+  @staticmethod
+  def _Scale(value, units):
+    """Scale `value` by powers of 1024 and pick the matching unit.
+
+    Values beyond the largest unit are expressed in that unit rather than
+    failing.
+
+    Returns:
+      A (scaled_value, unit) tuple.
+    """
+    for exponent, unit in enumerate(units):
+      if value < 1024 ** (exponent + 1) or exponent == len(units) - 1:
+        return value / (1024.0 ** exponent), unit
+
+  def CalculateSize(self):
+    """Recompute the column layout from the current terminal width."""
+    self._width = self._TerminalWidth()
+    self._name_width = int(self._width * 0.3)
+    self._name_max = self._name_width
+    self._stat_width = self.SIZE_WIDTH + self.SPEED_WIDTH + self.DURATION_WIDTH
+    self._max = (self._width - self._name_width - self._stat_width -
+                 self.PERCENTAGE_WIDTH)
+
+  def SizeToHuman(self, size_in_bytes):
+    """Format a byte count, e.g. '    1.5 MiB'."""
+    value, unit = self._Scale(size_in_bytes,
+                              ('B', 'KiB', 'MiB', 'GiB', 'TiB'))
+    return ' %6.1f %3s' % (value, unit)
+
+  def SpeedToHuman(self, speed_in_bs):
+    """Format a speed in bytes per second, e.g. '    1.5M/s'."""
+    value, unit = self._Scale(speed_in_bs, ('B', 'K', 'M', 'G', 'T'))
+    return ' %6.1f%s/s' % (value, unit)
+
+  def DurationToClock(self, duration):
+    """Format a duration in seconds as ' MM:SS'."""
+    return ' %02d:%02d' % (duration / 60, duration % 60)
+
+  def SetProgress(self, percentage, size=None):
+    """Redraw the bar.
+
+    Args:
+      percentage: completion in percent.
+      size: bytes transferred so far; the previous value is kept if None.
+    """
+    if self._width != self._TerminalWidth():
+      self.CalculateSize()
+
+    if size is not None:
+      self._size = size
+
+    elapse_time = time.time() - self._start_time
+    # The first redraw happens right after construction, when the elapsed
+    # time can still be 0 on a coarse clock.
+    speed = self._size / elapse_time if elapse_time > 0 else 0
+
+    size_str = self.SizeToHuman(self._size)
+    speed_str = self.SpeedToHuman(speed)
+    elapse_str = self.DurationToClock(elapse_time)
+
+    width = int(self._max * percentage / 100.0)
+    sys.stdout.write(
+        '%*s' % (- self._name_max,
+                 self._name if len(self._name) <= self._name_max else
+                 self._name[:self._name_max - 4] + ' ...') +
+        size_str + speed_str + elapse_str +
+        ((' [' + '#' * width + ' ' * (self._max - width) + ']' +
+          '%4d%%' % int(percentage)) if self._max > 2 else '') + '\r')
+    sys.stdout.flush()
+
+  def End(self):
+    """Draw the bar at 100% and move to the next line."""
+    self.SetProgress(100.0)
+    sys.stdout.write('\n')
+    sys.stdout.flush()
+
+
+class DaemonState(object):
+  """Mutable state held by the client daemon and returned over JSON-RPC.
+
+  main() registers this class with jsonrpclib's Config so instances can be
+  transferred through JSON-RPC.
+  """
+  def __init__(self):
+    # Digest of the script that started the daemon; clients compare it
+    # with their own to detect a stale daemon.
+    self.version_sha1sum = GetVersionDigest()
+
+    # Overlord server connection.
+    self.host, self.port = None, None
+    self.ssl = False
+    self.username, self.password = None, None
+
+    # SSH tunnelling.
+    self.ssh = False
+    self.ssh_pid = None
+    self.orig_host = None
+
+    # Selected client and port forwards (local port -> (mid, remote, pid)).
+    self.selected_mid = None
+    self.forwards = {}
+
+    # Cached client listing and the time it was fetched.
+    self.listing = []
+    self.last_list = 0
+
+
+class OverlordClientDaemon(object):
+  """Overlord Client Daemon.
+
+  A background process serving a JSON-RPC API on
+  _OVERLORD_CLIENT_DAEMON_RPC_ADDR. It keeps the connection state
+  (DaemonState) shared between successive `ovl` invocations.
+  """
+  def __init__(self):
+    self._state = DaemonState()
+    self._server = None
+
+  def Start(self):
+    self.StartRPCServer()
+
+  def StartRPCServer(self):
+    """Create the RPC server and fork a child process that serves it.
+
+    The child runs in its own session, so it is not tied to the invoking
+    terminal, and it always leaves through os._exit() instead of returning
+    into the caller's CLI code. The parent closes its copy of the listening
+    socket and returns.
+    """
+    self._server = SimpleJSONRPCServer(_OVERLORD_CLIENT_DAEMON_RPC_ADDR,
+                                       logRequests=False)
+    exports = [
+        (self.State, 'State'),
+        (self.Ping, 'Ping'),
+        (self.GetPid, 'GetPid'),
+        (self.Connect, 'Connect'),
+        (self.Clients, 'Clients'),
+        (self.SelectClient, 'SelectClient'),
+        (self.AddForward, 'AddForward'),
+        (self.RemoveForward, 'RemoveForward'),
+        (self.RemoveAllForward, 'RemoveAllForward'),
+    ]
+    for func, name in exports:
+      self._server.register_function(func, name)
+
+    pid = os.fork()
+    if pid == 0:
+      try:
+        # Detach from the controlling terminal so closing it does not kill
+        # the daemon.
+        os.setsid()
+        self._server.serve_forever()
+      except Exception:
+        logging.exception('ovl daemon terminated unexpectedly')
+      finally:
+        # Never fall back into the CLI code from the forked child.
+        os._exit(0)
+    else:
+      # Only the child serves requests.
+      self._server.server_close()
+
+  @staticmethod
+  def GetRPCServer():
+    """Returns the Overlord client daemon RPC server, or None if not running."""
+    server = jsonrpclib.Server('http://%s:%d' %
+                               _OVERLORD_CLIENT_DAEMON_RPC_ADDR)
+    try:
+      server.Ping()
+    except Exception:
+      return None
+    return server
+
+  def State(self):
+    return self._state
+
+  def Ping(self):
+    return True
+
+  def GetPid(self):
+    return os.getpid()
+
+  def _UrlOpen(self, url):
+    """Wrapper for urllib2.urlopen.
+
+    It selects correct HTTP scheme according to self._state.ssl and add HTTP
+    basic auth headers.
+    """
+    url = MakeRequestUrl(self._state, url)
+    request = urllib2.Request(url)
+    if self._state.username is not None and self._state.password is not None:
+      request.add_header(*BasicAuthHeader(self._state.username,
+                                          self._state.password))
+    return urllib2.urlopen(request)
+
+  def _GetJSON(self, path):
+    """GET `path` from the connected Overlord server and decode its JSON."""
+    url = '%s:%d%s' % (self._state.host, self._state.port, path)
+    return json.loads(self._UrlOpen(url).read())
+
+  def Connect(self, host, port=_OVERLORD_HTTP_PORT, ssh_pid=None,
+              username=None, password=None, orig_host=None):
+    """Record the server to use and probe whether it speaks HTTP or HTTPS.
+
+    Returns:
+      True on success, the HTTP status code on an HTTP error, or the error
+      message for any other failure.
+    """
+    self._state.username = username
+    self._state.password = password
+    self._state.host = host
+    self._state.port = port
+    self._state.ssl = False
+    self._state.orig_host = orig_host
+    self._state.ssh_pid = ssh_pid
+    self._state.selected_mid = None
+
+    try:
+      h = self._UrlOpen('%s:%d' % (host, port))
+      # Probably not an HTTP server, try HTTPS
+      if _OVERLORD_RESPONSE_KEYWORD not in h.read():
+        self._state.ssl = True
+        self._UrlOpen('%s:%d' % (host, port))
+    except urllib2.HTTPError as e:
+      logging.exception(e)
+      return e.getcode()
+    except Exception as e:
+      logging.exception(e)
+      return str(e)
+    return True
+
+  def Clients(self):
+    """Return the sorted, de-duplicated client IDs (cached briefly)."""
+    if time.time() - self._state.last_list <= _LIST_CACHE_TIMEOUT:
+      return self._state.listing
+
+    mids = [client['mid'] for client in self._GetJSON('/api/agents/list')]
+    self._state.listing = sorted(list(set(mids)))
+    self._state.last_list = time.time()
+    return self._state.listing
+
+  def SelectClient(self, mid):
+    self._state.selected_mid = mid
+
+  def AddForward(self, mid, remote, local, pid):
+    self._state.forwards[local] = (mid, remote, pid)
+
+  def RemoveForward(self, local_port):
+    """Kill the forwarder for `local_port`; unknown ports are ignored."""
+    try:
+      unused_mid, unused_remote, pid = self._state.forwards[local_port]
+      KillGraceful(pid)
+      del self._state.forwards[local_port]
+    except (KeyError, OSError):
+      pass
+
+  def RemoveAllForward(self):
+    """Kill every forwarder process and forget all forwards."""
+    for unused_mid, unused_remote, pid in self._state.forwards.values():
+      try:
+        KillGraceful(pid)
+      except OSError:
+        pass
+    self._state.forwards = {}
+
+
+class TerminalWebSocketClient(WebSocketBaseClient):
+  """Interactive terminal session to a client over a WebSocket.
+
+  Local stdin is put into raw, non-blocking mode and forwarded to the remote
+  TTY; binary messages from the server are written to stdout. Typing
+  <Enter> ~ . closes the session.
+  """
+  def __init__(self, mid, *args, **kwargs):
+    super(TerminalWebSocketClient, self).__init__(*args, **kwargs)
+    self._mid = mid
+    self._stdin_fd = sys.stdin.fileno()
+    # Saved terminal attributes and stdin file flags, restored in closed().
+    self._old_termios = None
+    self._old_stdin_flags = None
+
+  def handshake_ok(self):
+    pass
+
+  def opened(self):
+    nonlocals = {'size': (80, 40)}
+
+    def _ResizeWindow():
+      size = GetTerminalSize()
+      if size != nonlocals['size']:  # Size changed, tell the remote TTY.
+        control = {'command': 'resize', 'params': list(size)}
+        payload = chr(_CONTROL_START) + json.dumps(control) + chr(_CONTROL_END)
+        nonlocals['size'] = size
+        try:
+          self.send(payload, binary=True)
+        except Exception:
+          pass
+
+    def _FeedInput():
+      self._old_stdin_flags = fcntl.fcntl(sys.stdin, fcntl.F_GETFL)
+      fcntl.fcntl(sys.stdin, fcntl.F_SETFL,
+                  self._old_stdin_flags | os.O_NONBLOCK)
+
+      self._old_termios = termios.tcgetattr(self._stdin_fd)
+      tty.setraw(self._stdin_fd)
+
+      READY, ENTER_PRESSED, ESCAPE_PRESSED = range(3)
+
+      try:
+        state = READY
+        while True:
+          rd, unused_w, unused_x = select.select([sys.stdin], [], [], 0.5)
+
+          # We can't install a signal handler in the main thread since it'll
+          # interrupt the read/write system call (ws4py performing send/recv).
+          # Use polling instead (select's timeout is 0.5 seconds)
+          _ResizeWindow()
+
+          if sys.stdin in rd:
+            data = sys.stdin.read()
+
+            # Scan for the escape sequence <Enter> ~ . ; any other character
+            # resets the scanner, and every <Enter> starts a new attempt.
+            for x in data:
+              if state == ESCAPE_PRESSED and x == '.':
+                self.close()
+                raise RuntimeError('quit')
+              if state == ENTER_PRESSED and x == _ESCAPE:
+                state = ESCAPE_PRESSED
+              elif x == chr(0x0d):
+                state = ENTER_PRESSED
+              else:
+                state = READY
+
+            self.send(data)
+      except (KeyboardInterrupt, RuntimeError):
+        pass
+
+    t = threading.Thread(target=_FeedInput)
+    t.daemon = True
+    t.start()
+
+  def closed(self, code, reason=None):
+    # The connection may close before _FeedInput changed anything, so only
+    # restore what was actually saved.
+    if self._old_termios is not None:
+      termios.tcsetattr(self._stdin_fd, termios.TCSANOW, self._old_termios)
+    if self._old_stdin_flags is not None:
+      # Leaving stdin non-blocking would break the parent shell.
+      fcntl.fcntl(self._stdin_fd, fcntl.F_SETFL, self._old_stdin_flags)
+    print('Connection to %s closed.' % self._mid)
+
+  def received_message(self, msg):
+    if msg.is_binary:
+      sys.stdout.write(msg.data)
+      sys.stdout.flush()
+
+
+class ShellWebSocketClient(WebSocketBaseClient):
+  """WebSocket client that copies the binary messages it receives to a file."""
+
+  def __init__(self, output, *args, **kwargs):
+    """Constructor.
+
+    Args:
+      output: file-like object (write() and flush()) receiving the payload
+        of every binary message.
+    """
+    self.output = output
+    super(ShellWebSocketClient, self).__init__(*args, **kwargs)
+
+  def handshake_ok(self):
+    """No-op handshake hook."""
+
+  def opened(self):
+    """Nothing to send once connected."""
+
+  def closed(self, code, reason=None):
+    """Nothing to clean up."""
+
+  def received_message(self, msg):
+    """Write binary payloads to the output; other messages are dropped."""
+    if not msg.is_binary:
+      return
+    self.output.write(msg.data)
+    self.output.flush()
+
+
+class ForwarderWebSocketClient(WebSocketBaseClient):
+  """Relays one local TCP connection to a remote port over a WebSocket."""
+  def __init__(self, sock, *args, **kwargs):
+    super(ForwarderWebSocketClient, self).__init__(*args, **kwargs)
+    self._sock = sock
+    self._stop = threading.Event()
+
+  def handshake_ok(self):
+    pass
+
+  def opened(self):
+    def _FeedInput():
+      # The socket stays in blocking mode: select() guarantees recv() will not
+      # block, and received_message() relies on a blocking sendall() so data
+      # from the remote end is never partially written or dropped.
+      try:
+        while True:
+          rd, unused_w, unused_x = select.select([self._sock], [], [], 0.5)
+          if self._stop.is_set():
+            break
+          if self._sock in rd:
+            data = self._sock.recv(_BUFSIZ)
+            if not data:
+              break
+            self.send(data, binary=True)
+      except Exception:
+        pass
+      finally:
+        self._sock.close()
+        self.close()
+
+    t = threading.Thread(target=_FeedInput)
+    t.daemon = True
+    t.start()
+
+  def closed(self, code, reason=None):
+    self._stop.set()
+    # NOTE(review): SystemExit is raised in whichever thread calls closed(),
+    # presumably the one running ws.run() in Forward's HandleConnection.
+    sys.exit(0)
+
+  def received_message(self, msg):
+    if msg.is_binary:
+      self._sock.sendall(msg.data)
+
+
+def Arg(*args, **kwargs):
+  """Capture argparse add_argument() parameters for use with @Command.
+
+  Returns:
+    An (args, kwargs) tuple later passed verbatim to parser.add_argument().
+  """
+  positional, keyword = args, kwargs
+  return positional, keyword
+
+
+def Command(command, help_msg=None, args=None):
+  """Decorator tagging a method as a CLI sub-command.
+
+  The decorated function is returned unchanged (keeping its name and
+  docstring); the sub-command metadata is stored in its `__arg_attr`
+  attribute, which ParseMethodSubCommands collects.
+
+  Args:
+    command: sub-command name.
+    help_msg: help text shown by argparse.
+    args: list of Arg(...) tuples describing the sub-command's arguments.
+  """
+  if args is None:
+    args = []
+
+  def WrapFunc(func):
+    # This is not inside a class body, so no name mangling applies and the
+    # attribute is literally named '__arg_attr'.
+    setattr(func, '__arg_attr',
+            {'command': command, 'help': help_msg, 'args': args})
+    return func
+  return WrapFunc
+
+
+def ParseMethodSubCommands(cls):
+  """Class decorator registering the sub-commands declared with @Command.
+
+  The `__arg_attr` metadata of every @Command-decorated method is appended
+  to the class's SUBCOMMANDS list, which is later used to build the argument
+  parser. Entries are sorted by command name so the help output has a stable
+  order instead of following the class dict's arbitrary order.
+  """
+  attrs = [getattr(method, '__arg_attr') for method in cls.__dict__.values()
+           if hasattr(method, '__arg_attr')]
+  cls.SUBCOMMANDS.extend(sorted(attrs, key=lambda attr: attr['command']))
+  return cls
+
+
+@ParseMethodSubCommands
+class OverlordCLIClient(object):
+  """Overlord command line interface client.
+
+  Parses the command line, makes sure the client daemon is running and
+  dispatches to the sub-command handlers declared with @Command.
+  """
+
+  SUBCOMMANDS = []
+
+  def __init__(self):
+    self._parser = self._BuildParser()
+    self._selected_mid = None
+    self._server = None
+    self._state = None
+
+  def _BuildParser(self):
+    root_parser = argparse.ArgumentParser(prog='ovl')
+    subparsers = root_parser.add_subparsers(help='sub-command')
+
+    root_parser.add_argument('-s', dest='selected_mid', action='store',
+                             default=None,
+                             help='select target to execute command on')
+    root_parser.add_argument('-S', dest='select_mid_before_action',
+                             action='store_true', default=False,
+                             help='select target before executing command')
+
+    for attr in self.SUBCOMMANDS:
+      parser = subparsers.add_parser(attr['command'], help=attr['help'])
+      parser.set_defaults(which=attr['command'])
+      for arg in attr['args']:
+        parser.add_argument(*arg[0], **arg[1])
+
+    return root_parser
+
+  def Main(self):
+    # We want to pass the rest of arguments after shell command directly to the
+    # function without parsing it.
+    try:
+      index = sys.argv.index('shell')
+    except ValueError:
+      args = self._parser.parse_args()
+    else:
+      args = self._parser.parse_args(sys.argv[1:index + 1])
+
+    command = args.which
+    self._selected_mid = args.selected_mid
+
+    if command == 'kill-server':
+      self.KillServer()
+      return
+
+    self.StartDaemon()
+    if command == 'status':
+      self.Status()
+      return
+    elif command == 'connect':
+      self.Connect(args)
+      return
+
+    # The following command requires connection to the server
+    self.CheckConnection()
+
+    if args.select_mid_before_action:
+      self.SelectClient(store=False)
+
+    if command == 'select':
+      self.SelectClient(args)
+    elif command == 'ls':
+      self.ListClients()
+    elif command == 'shell':
+      command = sys.argv[sys.argv.index('shell') + 1:]
+      self.Shell(command)
+    elif command == 'push':
+      self.Push(args)
+    elif command == 'pull':
+      self.Pull(args)
+    elif command == 'forward':
+      self.Forward(args)
+
+  def _AuthHeaders(self):
+    """Return the HTTP basic auth headers for the current connection."""
+    if self._state.username is not None and self._state.password is not None:
+      return [BasicAuthHeader(self._state.username, self._state.password)]
+    return []
+
+  def _WebSocketUrl(self, path):
+    """Return the ws:// or wss:// URL of `path` on the Overlord server."""
+    return 'ws%s://%s:%d%s' % ('s' if self._state.ssl else '',
+                               self._state.host, self._state.port, path)
+
+  def _UrlOpen(self, url):
+    """urllib2.urlopen() wrapper adding the scheme and HTTP auth headers."""
+    request = urllib2.Request(MakeRequestUrl(self._state, url))
+    for header in self._AuthHeaders():
+      request.add_header(*header)
+    return urllib2.urlopen(request)
+
+  def _HTTPPostFile(self, url, filename, progress=None, user=None, passwd=None):
+    """Perform HTTP POST and upload file to Overlord.
+
+    To minimize the external dependencies, we construct the HTTP post request
+    by ourselves.
+
+    Args:
+      url: host:port/path of the upload endpoint (without scheme).
+      filename: local file to upload.
+      progress: optional callback progress(percentage[, bytes_sent]).
+      user: HTTP basic auth user; sent only together with passwd.
+      passwd: HTTP basic auth password.
+
+    Returns:
+      True if the server replied with HTTP 200.
+    """
+    url = MakeRequestUrl(self._state, url)
+    size = os.stat(filename).st_size
+    boundary = '-----------%s' % _HTTP_BOUNDARY_MAGIC
+    CRLF = '\r\n'
+    parse = urlparse.urlparse(url)
+
+    part_headers = [
+        '--' + boundary,
+        'Content-Disposition: form-data; name="file"; '
+        'filename="%s"' % os.path.basename(filename),
+        'Content-Type: application/octet-stream',
+        '', ''
+    ]
+    part_header = CRLF.join(part_headers)
+    end_part = CRLF + '--' + boundary + '--' + CRLF
+
+    content_length = len(part_header) + size + len(end_part)
+    if parse.scheme == 'http':
+      h = httplib.HTTP(parse.netloc)
+    else:
+      h = httplib.HTTPS(parse.netloc)
+
+    post_path = url[url.index(parse.netloc) + len(parse.netloc):]
+    h.putrequest('POST', post_path)
+    h.putheader('Content-Length', content_length)
+    h.putheader('Content-Type', 'multipart/form-data; boundary=%s' % boundary)
+
+    if user and passwd:
+      h.putheader(*BasicAuthHeader(user, passwd))
+    h.endheaders()
+    h.send(part_header)
+
+    count = 0
+    # Binary mode: the file is sent byte-for-byte.
+    with open(filename, 'rb') as f:
+      while True:
+        data = f.read(_BUFSIZ)
+        if not data:
+          break
+        count += len(data)
+        if progress:
+          progress(int(count * 100.0 / size), count)
+        h.send(data)
+
+    h.send(end_part)
+    if progress:  # `progress` is optional.
+      progress(100)
+
+    if count != size:
+      logging.warning('file changed during upload, upload may be truncated.')
+
+    errcode, unused_x, unused_y = h.getreply()
+    return errcode == 200
+
+  def StartDaemon(self):
+    """Connect to the client daemon, (re)starting it when needed.
+
+    Raises:
+      RuntimeError: the daemon could not be started.
+    """
+    self._server = OverlordClientDaemon.GetRPCServer()
+    if self._server is None:
+      print('* daemon not running, starting it now on port %d ... *' %
+            _OVERLORD_CLIENT_DAEMON_PORT)
+      OverlordClientDaemon().Start()
+      time.sleep(1)
+      self._server = OverlordClientDaemon.GetRPCServer()
+      if self._server is None:
+        raise RuntimeError('failed to start the ovl daemon')
+      print('* daemon started successfully *')
+
+    self._state = self._server.State()
+    sha1sum = GetVersionDigest()
+
+    if sha1sum != self._state.version_sha1sum:
+      print('ovl server is out of date.  killing...')
+      KillGraceful(self._server.GetPid())
+      self.StartDaemon()
+
+  def GetSSHControlFile(self, host):
+    return _SSH_CONTROL_SOCKET_PREFIX + host
+
+  def SSHTunnel(self, user, host, port, forward_port=_OVERLORD_HTTP_PORT):
+    """SSH forward the remote overlord server.
+
+    Overlord server may not have its HTTP port open to the public network, in
+    such case we can SSH forward the port to localhost.
+
+    Args:
+      user: SSH login name; empty for the default.
+      host: SSH server host.
+      port: SSH server port.
+      forward_port: Overlord HTTP port, forwarded to the same local port.
+
+    Returns:
+      PID of the ssh master process.
+    """
+
+    control_file = self.GetSSHControlFile(host)
+    try:
+      os.unlink(control_file)
+    except Exception:
+      pass
+
+    subprocess.Popen([
+        'ssh', '-Nf',
+        '-M',  # Enable master mode
+        '-S', control_file,
+        '-L', '%d:localhost:%d' % (forward_port, forward_port),
+        '-p', str(port),
+        '%s%s' % (user + '@' if user else '', host)
+    ]).wait()
+
+    p = subprocess.Popen([
+        'ssh',
+        '-S', control_file,
+        '-O', 'check', host,
+    ], stderr=subprocess.PIPE)
+    unused_stdout, stderr = p.communicate()
+
+    s = re.search(r'pid=(\d+)', stderr)
+    if s:
+      return int(s.group(1))
+
+    raise RuntimeError('can not establish ssh connection')
+
+  def CheckConnection(self):
+    if self._state.host is None:
+      raise RuntimeError('not connected to any server, abort')
+
+    try:
+      self._server.Clients()
+    except Exception:
+      raise RuntimeError('remote server disconnected, abort')
+
+    if self._state.ssh_pid is not None:
+      try:
+        os.kill(self._state.ssh_pid, 0)  # Signal 0 only checks existence.
+      except OSError:
+        raise RuntimeError('ssh tunnel disconnected, please re-connect')
+
+  def CheckClient(self):
+    if self._selected_mid is None:
+      if self._state.selected_mid is None:
+        raise RuntimeError('No client is selected')
+      self._selected_mid = self._state.selected_mid
+
+    if self._selected_mid not in self._server.Clients():
+      raise RuntimeError('client %s disappeared' % self._selected_mid)
+
+  def CheckOutput(self, command):
+    """Run `command` on the selected client and return its output."""
+    sio = StringIO.StringIO()
+    ws = ShellWebSocketClient(
+        sio, self._WebSocketUrl('/api/agent/shell/%s?command=%s' %
+                                (self._selected_mid, urllib2.quote(command))),
+        headers=self._AuthHeaders())
+    ws.connect()
+    ws.run()
+    return sio.getvalue()
+
+  @Command('status', 'show Overlord connection status')
+  def Status(self):
+    if self._state.host is None:
+      print('Not connected to any host.')
+    else:
+      if self._state.ssh_pid is not None:
+        print('Connected to %s with SSH tunneling.' % self._state.orig_host)
+      else:
+        print('Connected to %s:%d.' % (self._state.host, self._state.port))
+
+    if self._selected_mid is None:
+      self._selected_mid = self._state.selected_mid
+
+    if self._selected_mid is None:
+      print('No client is selected.')
+    else:
+      print('Client %s selected.' % self._selected_mid)
+
+  @Command('connect', 'connect to Overlord server', [
+      Arg('host', metavar='HOST', type=str, default='localhost', nargs='?',
+          help='Overlord hostname/IP'),
+      Arg('port', metavar='PORT', type=int, nargs='?',
+          default=_OVERLORD_HTTP_PORT, help='Overlord port'),
+      Arg('-f', '--forward', dest='ssh_forward', default=False,
+          action='store_true',
+          help='connect with SSH forwarding to the host'),
+      Arg('-p', '--ssh-port', dest='ssh_port', default=22,
+          type=int, help='SSH server port for SSH forwarding'),
+      Arg('-l', '--ssh-login', dest='ssh_login', default='',
+          type=str, help='SSH server login name for SSH forwarding'),
+      Arg('-u', '--user', dest='user', default=None,
+          type=str, help='Overlord HTTP auth username'),
+      Arg('-w', '--passwd', dest='passwd', default=None, type=str,
+          help='Overlord HTTP auth password')])
+  def Connect(self, args):
+    ssh_pid = None
+    host = args.host
+    orig_host = args.host
+
+    if args.ssh_forward:
+      # Kill previous SSH tunnel
+      self.KillSSHTunnel()
+
+      ssh_pid = self.SSHTunnel(args.ssh_login, args.host, args.ssh_port,
+                               args.port)
+      host = 'localhost'
+
+    status = self._server.Connect(host, args.port, ssh_pid, args.user,
+                                  args.passwd, orig_host)
+    if status is not True:
+      if isinstance(status, int):
+        if status == 401:
+          msg = '401 Unauthorized'
+        else:
+          msg = 'HTTP %d' % status
+      else:
+        msg = status
+      # Report the host the user asked for, not the tunnel's 'localhost'.
+      print('can not connect to %s: %s' % (orig_host, msg))
+
+  @Command('kill-server', 'kill overlord CLI client server')
+  def KillServer(self):
+    self._server = OverlordClientDaemon.GetRPCServer()
+    if self._server is None:
+      return
+
+    # Main() calls this before StartDaemon(), so the state (which holds the
+    # SSH tunnel PID) has not been fetched yet.
+    self._state = self._server.State()
+
+    # Stop port forwarders; once the daemon is gone nothing tracks them.
+    self._server.RemoveAllForward()
+
+    # Kill SSH Tunnel
+    self.KillSSHTunnel()
+
+    # Kill server daemon
+    KillGraceful(self._server.GetPid())
+
+  def KillSSHTunnel(self):
+    if self._state.ssh_pid is not None:
+      KillGraceful(self._state.ssh_pid)
+
+  @Command('ls', 'list all clients')
+  def ListClients(self):
+    for client in self._server.Clients():
+      print(client)
+
+  @Command('select', 'select default client', [
+      Arg('mid', metavar='mid', nargs='?', default=None)])
+  def SelectClient(self, args=None, store=True):
+    """Select a client, interactively when no mid is given.
+
+    Args:
+      args: parsed arguments carrying `mid`, or None to prompt.
+      store: also make the selection the daemon's default.
+    """
+    clients = self._server.Clients()
+
+    mid = args.mid if args is not None else None
+    if mid is None:
+      if not clients:
+        raise RuntimeError('select: no clients available')
+      print('Select from the following clients:')
+      for i, client in enumerate(clients):
+        print('    %d. %s' % (i + 1, client))
+
+      print('\nSelection: ', end='')
+      try:
+        choice = int(raw_input())
+      except ValueError:
+        raise RuntimeError('select: invalid selection')
+      # Check the range explicitly: 0 or negative input would otherwise
+      # index from the end of the list.
+      if not 1 <= choice <= len(clients):
+        raise RuntimeError('select: selection out of range')
+      mid = clients[choice - 1]
+    else:
+      if mid not in clients:
+        raise RuntimeError('select: client %s does not exist' % mid)
+
+    self._selected_mid = mid
+    if store:
+      self._server.SelectClient(mid)
+      print('Client %s selected' % mid)
+
+  @Command('shell', 'open a shell or execute a shell command', [
+      Arg('command', metavar='CMD', nargs='?', help='command to execute')])
+  def Shell(self, command=None):
+    """Open an interactive terminal, or run `command` (list of words)."""
+    if command is None:
+      command = []
+    self.CheckClient()
+
+    headers = self._AuthHeaders()
+    if not command:
+      ws = TerminalWebSocketClient(
+          self._selected_mid,
+          self._WebSocketUrl('/api/agent/tty/%s' % self._selected_mid),
+          headers=headers)
+    else:
+      cmd = ' '.join(command)
+      ws = ShellWebSocketClient(
+          sys.stdout,
+          self._WebSocketUrl('/api/agent/shell/%s?command=%s' %
+                             (self._selected_mid, urllib2.quote(cmd))),
+          headers=headers)
+    ws.connect()
+    ws.run()
+
+  @Command('push', 'push a file or directory to remote', [
+      Arg('src', metavar='SOURCE'),
+      Arg('dst', metavar='DESTINATION')])
+  def Push(self, args):
+    self.CheckClient()
+
+    if not os.path.exists(args.src):
+      raise RuntimeError('push: can not stat "%s": no such file or directory'
+                         % args.src)
+
+    if not os.access(args.src, os.R_OK):
+      raise RuntimeError('push: can not open "%s" for reading' % args.src)
+
+    def _push(src, dst):
+      src_base = os.path.basename(src)
+      mode = '0%o' % (0x1FF & os.stat(src).st_mode)
+      # Paths are quoted because they are embedded in the query string.
+      url = ('%s:%d/api/agent/upload/%s?dest=%s&perm=%s' %
+             (self._state.host, self._state.port, self._selected_mid,
+              urllib2.quote(dst), mode))
+      try:
+        self._UrlOpen(url + '&filename=%s' % urllib2.quote(src_base))
+      except urllib2.HTTPError as e:
+        msg = json.loads(e.read()).get('error', None)
+        raise RuntimeError('push: %s' % msg)
+
+      pbar = ProgressBar(src_base)
+      self._HTTPPostFile(url, src, pbar.SetProgress,
+                         self._state.username, self._state.password)
+      pbar.End()
+
+    if os.path.isdir(args.src):
+      dst_exists = ast.literal_eval(self.CheckOutput(
+          'stat "%s" >/dev/null 2>&1 && echo True || echo False' % args.dst))
+      # Name of the source directory itself (without any leading path).
+      src_name = os.path.basename(args.src.rstrip('/'))
+      for root, unused_x, files in os.walk(args.src):
+        # If destination directory does not exist, we should strip the first
+        # layer of directory. For example: src_dir contains a single file 'A'
+        #
+        # push src_dir dest_dir
+        #
+        # If dest_dir exists, the resulting directory structure should be:
+        #   dest_dir/src_dir/A
+        # If dest_dir does not exist, the resulting directory structure should
+        # be:
+        #   dest_dir/A
+        # Only the part of `root` below args.src is used, so absolute or
+        # '../' sources do not escape dest_dir.
+        rel_root = root[len(args.src):].lstrip('/')
+        dst_root = os.path.join(src_name, rel_root) if dst_exists else rel_root
+        for name in files:
+          _push(os.path.join(root, name),
+                os.path.join(args.dst, dst_root, name))
+    else:
+      _push(args.src, args.dst)
+
+  @Command('pull', 'pull a file or directory from remote', [
+      Arg('src', metavar='SOURCE'),
+      Arg('dst', metavar='DESTINATION', default='.', nargs='?')])
+  def Pull(self, args):
+    self.CheckClient()
+
+    def _pull(src, dst, perm=0644):
+      url = ('%s:%d/api/agent/download/%s?filename=%s' %
+             (self._state.host, self._state.port, self._selected_mid,
+              urllib2.quote(src)))
+      try:
+        h = self._UrlOpen(url)
+      except urllib2.HTTPError as e:
+        msg = json.loads(e.read()).get('error', 'unknown error')
+        raise RuntimeError('pull: %s' % msg)
+      except KeyboardInterrupt:
+        return
+
+      try:
+        os.makedirs(os.path.dirname(dst))
+      except Exception:
+        pass
+
+      pbar = ProgressBar(os.path.basename(src))
+      try:
+        # Binary mode: the file is written byte-for-byte.
+        with open(dst, 'wb') as f:
+          os.fchmod(f.fileno(), perm)
+          # Content-Length may be missing or 0; then the percentage stays at
+          # 0 until pbar.End() reports completion.
+          total_size = int(h.headers.get('Content-Length') or 0)
+          downloaded_size = 0
+          while True:
+            data = h.read(_BUFSIZ)
+            downloaded_size += len(data)
+            if total_size > 0:
+              percentage = float(downloaded_size) * 100 / total_size
+            else:
+              percentage = 0
+            pbar.SetProgress(percentage, downloaded_size)
+            if len(data) == 0:
+              break
+            f.write(data)
+      finally:
+        h.close()
+      pbar.End()
+
+    # 'stat' fails first, printing "stat: ...", when the source does not
+    # exist. Otherwise 'find' lists every regular file and symlink under the
+    # source (just the source itself for a file), one "<octal mode>\t<path>"
+    # line each.
+    output = self.CheckOutput(
+        'cd $HOME; '
+        'stat "%(src)s" >/dev/null && '
+        'find "%(src)s" \'(\' -type f -o -type l \')\' -printf \'%%m\t%%p\n\''
+        % {'src': args.src})
+
+    # We got error from the stat command
+    if output.startswith('stat: '):
+      sys.stderr.write(output)
+      return
+
+    # An empty directory yields no entries: nothing to pull.
+    entries = [entry for entry in output.strip().split('\n') if entry]
+    if not entries:
+      return
+    common_prefix = os.path.dirname(args.src)
+
+    if len(entries) == 1:
+      entry = entries[0]
+      perm, src_path = entry.split('\t')
+      if os.path.isdir(args.dst):
+        dst = os.path.join(args.dst, os.path.basename(src_path))
+      else:
+        dst = args.dst
+      _pull(src_path, dst, int(perm, base=8))
+    else:
+      if not os.path.exists(args.dst):
+        common_prefix = args.src
+
+      for entry in entries:
+        perm, src_path = entry.split('\t')
+        rel_dst = src_path[len(common_prefix):].lstrip('/')
+        _pull(src_path, os.path.join(args.dst, rel_dst), int(perm, base=8))
+
+  @Command('forward', 'forward remote port to local port', [
+      Arg('--list', dest='list_all', action='store_true', default=False,
+          help='list all port forwarding sessions'),
+      Arg('--remove', metavar='LOCAL_PORT', dest='remove', type=int,
+          default=None,
+          help='remove port forwarding for local port LOCAL_PORT'),
+      Arg('--remove-all', dest='remove_all', action='store_true',
+          default=False, help='remove all port forwarding'),
+      Arg('remote', metavar='REMOTE_PORT', type=int, nargs='?'),
+      Arg('local', metavar='LOCAL_PORT', type=int, nargs='?')])
+  def Forward(self, args):
+    if args.list_all:
+      max_len = 10
+      if len(self._state.forwards):
+        max_len = max([len(v[0]) for v in self._state.forwards.values()])
+
+      print('%-*s   %-8s  %-8s' % (max_len, 'Client', 'Remote', 'Local'))
+      for local in sorted(self._state.forwards.keys()):
+        value = self._state.forwards[local]
+        print('%-*s   %-8s  %-8s' % (max_len, value[0], value[1], local))
+      return
+
+    if args.remove_all:
+      self._server.RemoveAllForward()
+      return
+
+    if args.remove:
+      self._server.RemoveForward(args.remove)
+      return
+
+    if args.remote is None:
+      raise RuntimeError('forward: remote port not specified')
+
+    self.CheckClient()
+
+    if args.local is None:
+      args.local = args.remote
+    remote = int(args.remote)
+    local = int(args.local)
+
+    def HandleConnection(conn):
+      ws = ForwarderWebSocketClient(
+          conn,
+          self._WebSocketUrl('/api/agent/forward/%s?port=%d' %
+                             (self._selected_mid, remote)),
+          headers=self._AuthHeaders())
+      try:
+        ws.connect()
+        ws.run()
+      except Exception as e:
+        print('error: %s' % e)
+      finally:
+        ws.close()
+
+    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+    # NOTE(review): binds all interfaces, so the forwarded port is reachable
+    # from other hosts; consider 127.0.0.1 if that is not intended.
+    server.bind(('0.0.0.0', local))
+    server.listen(5)
+
+    pid = os.fork()
+    if pid == 0:
+      try:
+        # Keep forwarding after the invoking terminal goes away; the daemon
+        # records this PID and kills it on 'forward --remove'.
+        os.setsid()
+        while True:
+          conn, unused_addr = server.accept()
+          t = threading.Thread(target=HandleConnection, args=(conn,))
+          t.daemon = True
+          t.start()
+      except Exception:
+        logging.exception('port forwarder for local port %d stopped', local)
+      finally:
+        # Never fall back into the CLI code from the forked child.
+        os._exit(1)
+    else:
+      # Only the child accepts connections.
+      server.close()
+      self._server.AddForward(self._selected_mid, remote, local, pid)
+
+
+def main():
+  """Run the ovl CLI.
+
+  Returns:
+    Process exit status: 0 on success, 1 on error or Ctrl-C, so scripts can
+    detect failures.
+  """
+  logging.basicConfig(level=logging.INFO)
+
+  # Add DaemonState to JSONRPC lib classes
+  Config.instance().classes.add(DaemonState)
+
+  ovl = OverlordCLIClient()
+  try:
+    ovl.Main()
+  except KeyboardInterrupt:
+    print('Ctrl-C received, abort')
+    return 1
+  except Exception as e:
+    print('error: %s' % e)
+    return 1
+  return 0
+
+
+if __name__ == '__main__':
+  sys.exit(main())