overlord: add python implementation of ghost

Python implementation is much smaller and thus save spaces.

BUG=chromium:443972
TEST=none

Change-Id: I3d9cae83d34bc02996d47597defde14d30ec21f5
Reviewed-on: https://chromium-review.googlesource.com/254952
Reviewed-by: Hung-Te Lin <hungte@chromium.org>
Commit-Queue: Wei-Ning Huang <wnhuang@chromium.org>
Tested-by: Wei-Ning Huang <wnhuang@chromium.org>
diff --git a/py/tools/ghost.py b/py/tools/ghost.py
new file mode 100755
index 0000000..e784ce4
--- /dev/null
+++ b/py/tools/ghost.py
@@ -0,0 +1,479 @@
+#!/usr/bin/python -u
+# -*- coding: utf-8 -*-
+#
+# Copyright 2015 The Chromium OS Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+import fcntl
+import json
+import logging
+import os
+import Queue
+import select
+import socket
+import subprocess
+import sys
+import threading
+import time
+import uuid
+
+
+_OVERLORD_PORT = 4455
+_OVERLORD_LAN_DISCOVERY_PORT = 4456
+
+_BUFSIZE = 8192
+_RETRY_INTERVAL = 2
+_SEPARATOR = '\r\n'
+_PING_TIMEOUT = 3
+_PING_INTERVAL = 5
+_REQUEST_TIMEOUT_SECS = 60
+_SHELL = os.getenv('SHELL', '/bin/bash')
+
+RESPONSE_SUCCESS = 'success'
+RESPONSE_FAILED = 'failed'
+
+
+class PingTimeoutError(Exception):
+  pass
+
+
+class RequestError(Exception):
+  pass
+
+
+class Ghost(object):
+  """Ghost implements the client protocol of Overlord.
+
+  Ghost provide terminal/shell/logcat functionality and manages the client
+  side connectivity.
+  """
+  NONE, AGENT, SHELL, LOGCAT, SLOGCAT = range(5)
+
+  MODE_NAME = {
+      NONE: 'NONE',
+      AGENT: 'Agent',
+      SHELL: 'Shell',
+      LOGCAT: 'Logcat',
+      SLOGCAT: 'Simple-Logcat'
+      }
+
+  def __init__(self, overlord_addrs, mode=AGENT, sid=None, filename=None):
+    """Constructor.
+
+    Args:
+      overlord_addrs: a list of possible address of overlord.
+      mode: client mode, either AGENT, SHELL or LOGCAT
+      sid: session id. If the connection is requested by overlord, sid should
+        be set to the corresponding session id assigned by overlord.
+      filename: the filename to cat when we are in LOGCAT mode.
+    """
+    assert mode in [Ghost.AGENT, Ghost.SHELL, Ghost.LOGCAT]
+    if mode == Ghost.LOGCAT:
+      assert filename is not None
+
+    self._overlord_addrs = overlord_addrs
+    self._mode = mode
+    self._sock = None
+    self._machine_id = self.GetMachineID()
+    self._client_id = sid if sid is not None else str(uuid.uuid4())
+    self._logcat_filename = filename
+    self._buf = ''
+    self._requests = {}
+    self._reset = False
+    self._last_ping = 0
+    self._queue = Queue.Queue()
+
+  def SpawnGhost(self, mode, sid, filename=None):
+    """Spawn a child ghost with specific mode.
+
+    Returns:
+      The spawned child process pid.
+    """
+    pid = os.fork()
+    if pid == 0:
+      g = Ghost(self._overlord_addrs, mode, sid, filename)
+      g.Start()
+      sys.exit(0)
+    else:
+      return pid
+
+  def Timestamp(self):
+    return int(time.time())
+
+  def GetGateWayIP(self):
+    with open('/proc/net/route', 'r') as f:
+      lines = f.readlines()
+
+    ips = []
+    for line in lines:
+      parts = line.split('\t')
+      if parts[2] == '00000000':
+        continue
+
+      try:
+        h = parts[2].decode('hex')
+        ips.append('%d.%d.%d.%d' % tuple(ord(x) for x in reversed(h)))
+      except TypeError:
+        pass
+
+    return ips
+
+  def GetMachineID(self):
+    """Generates machine-dependent ID string for a machine.
+    There are many ways to generate a machine ID:
+    1. factory device-data
+    2. /sys/class/dmi/id/product_uuid (only available on intel machines)
+    3. MAC address
+    We follow the listed order to generate machine ID, and fallback to the next
+    alternative if the previous doesn't work.
+    """
+    try:
+      p = subprocess.Popen('factory device-data | grep mlb_serial_number | '
+                           'cut -d " " -f 2', stdout=subprocess.PIPE,
+                           shell=True)
+      stdout, _ = p.communicate()
+      if stdout == '':
+        raise RuntimeError("empty mlb number")
+      return stdout.strip()
+    except Exception:
+      pass
+
+    try:
+      with open('/sys/class/dmi/id/product_uuid', 'r') as f:
+        return f.read().strip()
+    except Exception:
+      pass
+
+    try:
+      macs = []
+      ifaces = sorted(os.listdir('/sys/class/net'))
+      for iface in ifaces:
+        if iface == 'lo':
+          continue
+
+        with open('/sys/class/net/%s/address' % iface, 'r') as f:
+          macs.append(f.read().strip())
+
+      return ';'.join(macs)
+    except Exception:
+      pass
+
+    raise RuntimeError("can't generate machine ID")
+
+  def Reset(self):
+    """Reset state and clear request handlers."""
+    self._reset = False
+    self._buf = ""
+    self._last_ping = 0
+    self._requests = {}
+
+  def SendMessage(self, msg):
+    """Serialize the message and send it through the socket."""
+    self._sock.send(json.dumps(msg) + _SEPARATOR)
+
+  def SendRequest(self, name, args, handler=None,
+                  timeout=_REQUEST_TIMEOUT_SECS):
+    if handler and not callable(handler):
+      raise RequestError('Invalid requiest handler for msg "%s"' % name)
+
+    rid = str(uuid.uuid4())
+    msg = {'rid': rid, 'timeout': timeout, 'name': name, 'params': args}
+    self._requests[rid] = [self.Timestamp(), timeout, handler]
+    self.SendMessage(msg)
+
+  def SendResponse(self, omsg, status, params=None):
+    msg = {'rid': omsg['rid'], 'response': status, 'params': params}
+    self.SendMessage(msg)
+
+  def SpawnPTYServer(self, _):
+    """Spawn a PTY server and forward I/O to the TCP socket."""
+    logging.info('SpawnPTYServer: started')
+
+    pid, fd = os.forkpty()
+    if pid == 0:
+      env = os.environ.copy()
+      env['USER'] = os.getenv('USER', 'root')
+      env['HOME'] = os.getenv('HOME', '/root')
+      os.chdir(env['HOME'])
+      os.execve(_SHELL, [_SHELL], env)
+    else:
+      try:
+        while True:
+          rd, _, _ = select.select([self._sock, fd], [], [])
+
+          if fd in rd:
+            self._sock.send(os.read(fd, _BUFSIZE))
+
+          if self._sock in rd:
+            ret = self._sock.recv(_BUFSIZE)
+            if len(ret) == 0:
+              raise RuntimeError("socket closed")
+            os.write(fd, ret)
+      except (OSError, socket.error, RuntimeError):
+        self._sock.close()
+        logging.info('SpawnPTYServer: terminated')
+        sys.exit(0)
+
+  def SpawnLogcatServer(self, _):
+    """Spawn a Logcat server and forward output to the TCP socket."""
+    logging.info('SpawnLogcatServer: started')
+
+    p = subprocess.Popen('tail -n +0 -f "%s"' % self._logcat_filename,
+                         stdout=subprocess.PIPE, stderr=subprocess.PIPE,
+                         shell=True)
+
+    def make_non_block(fd):
+      fl = fcntl.fcntl(fd, fcntl.F_GETFL)
+      fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
+
+    make_non_block(p.stdout)
+    make_non_block(p.stderr)
+
+    try:
+      while True:
+        rd, _, _ = select.select([p.stdout, p.stderr, self._sock], [], [])
+
+        if p.stdout in rd:
+          self._sock.send(p.stdout.read(_BUFSIZE))
+
+        if p.stderr in rd:
+          self._sock.send(p.stderr.read(_BUFSIZE))
+
+        if self._sock in rd:
+          ret = self._sock.recv(_BUFSIZE)
+          if len(ret) == 0:
+            raise RuntimeError("socket closed")
+    except (OSError, socket.error, RuntimeError):
+      self._sock.close()
+      logging.info('SpawnLogcatServer: terminated')
+      sys.exit(0)
+
+
+  def Ping(self):
+    def timeout_handler(x):
+      if x is None:
+        raise PingTimeoutError
+
+    self._last_ping = self.Timestamp()
+    self.SendRequest('ping', {}, timeout_handler, 5)
+
+  def HandleShellRequest(self, msg):
+    params = msg['params']
+    stdout = stderr = err_msg = ""
+    try:
+      p = subprocess.Popen([params['cmd']] + params['args'],
+                           stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+      stdout, stderr = p.communicate()
+    except Exception as e:
+      err_msg = str(e)
+
+    self.SendResponse(msg, RESPONSE_SUCCESS,
+                      {'output': stdout + stderr, 'err_msg': err_msg})
+
+  def HandleRequest(self, msg):
+    if msg['name'] == 'shell':
+      self.HandleShellRequest(msg)
+    elif msg['name'] == 'terminal':
+      self.SpawnGhost(self.SHELL, msg['params']['sid'])
+      self.SendResponse(msg, RESPONSE_SUCCESS)
+    elif msg['name'] == 'logcat':
+      self.SpawnGhost(self.LOGCAT, msg['params']['sid'],
+                      msg['params']['filename'])
+      self.SendResponse(msg, RESPONSE_SUCCESS)
+
+  def HandleResponse(self, response):
+    rid = str(response['rid'])
+    if rid in self._requests:
+      handler = self._requests[rid][2]
+      del self._requests[rid]
+      if callable(handler):
+        handler(response)
+    else:
+      print(response, self._requests.keys())
+      logging.warning('Recvied unsolicited response, ignored')
+
+  def ParseMessage(self):
+    msgs_json = self._buf.split(_SEPARATOR)
+    self._buf = msgs_json.pop()
+
+    for msg_json in msgs_json:
+      try:
+        msg = json.loads(msg_json)
+      except ValueError:
+        # Ignore mal-formed message.
+        continue
+
+      if 'name' in msg:
+        self.HandleRequest(msg)
+      elif 'response' in msg:
+        self.HandleResponse(msg)
+      else:  # Ingnore mal-formed message.
+        pass
+
+  def ScanForTimeoutRequests(self):
+    for rid in self._requests.keys()[:]:
+      request_time, timeout, handler = self._requests[rid]
+      if self.Timestamp() - request_time > timeout:
+        handler(None)
+        del self._requests[rid]
+
+  def Listen(self):
+    try:
+      while True:
+        rds, _, _ = select.select([self._sock], [], [], _PING_INTERVAL / 2)
+
+        if self._sock in rds:
+          self._buf += self._sock.recv(_BUFSIZE)
+          self.ParseMessage()
+
+        if self.Timestamp() - self._last_ping > _PING_INTERVAL:
+          self.Ping()
+        self.ScanForTimeoutRequests()
+
+        if self._reset:
+          self.Reset()
+          break
+    except socket.error:
+      raise RuntimeError('Connection dropped')
+    except PingTimeoutError:
+      raise RuntimeError('Connection timeout')
+    finally:
+      self._sock.close()
+
+    self._queue.put('resume')
+
+    if self._mode != Ghost.AGENT:
+      sys.exit(1)
+
+  def Register(self):
+    non_local = {}
+    for addr in self._overlord_addrs:
+      non_local['addr'] = addr
+      def registered(response):
+        if response is None:
+          self._reset = True
+          raise RuntimeError('Register request timeout')
+        logging.info('Registered with Overlord at %s:%d', *non_local['addr'])
+        self._queue.put("pause", True)
+
+      try:
+        logging.info('Trying %s:%d ...', *addr)
+        self.Reset()
+        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self._sock.settimeout(_PING_TIMEOUT)
+        self._sock.connect(addr)
+
+        logging.info('Connection established, registering...')
+        handler = {
+            Ghost.AGENT: registered,
+            Ghost.SHELL: self.SpawnPTYServer,
+            Ghost.LOGCAT: self.SpawnLogcatServer
+            }[self._mode]
+
+        # Machine ID may change if MAC address is used (USB-ethernet dongle
+        # plugged/unplugged)
+        self._machine_id = self.GetMachineID()
+        self.SendRequest('register', {'mode': self._mode, 'mid': self._machine_id,
+                                      'cid': self._client_id}, handler)
+      except socket.error:
+        pass
+      else:
+        self._sock.settimeout(None)
+        self.Listen()
+
+    raise RuntimeError("Cannot connect to any server")
+
+  def StartLanDiscovery(self):
+    """Start to listen to LAN discovery packet at
+    _OVERLORD_LAN_DISCOVERY_PORT."""
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+    s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
+    try:
+      s.bind(('0.0.0.0', _OVERLORD_LAN_DISCOVERY_PORT))
+    except socket.error as e:
+      logging.error("LAN discovery: %s, abort", e)
+      return
+
+    logging.info('LAN Discovery: started')
+    while True:
+      rd, _, _ = select.select([s], [], [], 1)
+
+      if s in rd:
+        data, source_addr = s.recvfrom(_BUFSIZE)
+        parts = data.split()
+        if parts[0] == 'OVERLORD':
+          ip = source_addr[0]
+          port = int(parts[1].lstrip(':'))
+          addr = (ip, port)
+          self._queue.put(addr, True)
+
+      try:
+        obj = self._queue.get(False)
+      except Queue.Empty:
+        pass
+      else:
+        if type(obj) is not str:
+          self._queue.put(obj)
+        elif obj == 'pause':
+          logging.info('LAN Discovery: paused')
+          while True:
+            obj = self._queue.get(True)
+            if obj == 'resume':
+              logging.info('LAN Discovery: resumed')
+              break
+
+  def ScanGateway(self):
+    for addr in [(x, _OVERLORD_PORT) for x in self.GetGateWayIP()]:
+      if addr not in self._overlord_addrs:
+        self._overlord_addrs.append(addr)
+
+  def Start(self):
+    logging.info('%s started', self.MODE_NAME[self._mode])
+    logging.info('MID: %s', self._machine_id)
+    logging.info('CID: %s', self._client_id)
+
+    if self._mode == Ghost.AGENT:
+      t = threading.Thread(target=self.StartLanDiscovery)
+      t.daemon = True
+      t.start()
+
+    try:
+      while True:
+        try:
+          addr = self._queue.get(False)
+        except Queue.Empty:
+          pass
+        else:
+          if type(addr) == tuple and addr not in self._overlord_addrs:
+            logging.info('LAN Discovery: got overlord address %s:%d', *addr)
+            self._overlord_addrs.append(addr)
+
+        try:
+          self.ScanGateway()
+          self.Register()
+        except Exception as e:
+          logging.info(str(e) + ', retrying in %ds' % _RETRY_INTERVAL)
+          time.sleep(_RETRY_INTERVAL)
+
+        self.Reset()
+    except KeyboardInterrupt:
+      logging.error('Received keyboard interrupt, quit')
+      sys.exit(0)
+
+
+def main():
+  logger = logging.getLogger()
+  logger.setLevel(logging.INFO)
+
+  addrs = [('localhost', _OVERLORD_PORT)]
+  if len(sys.argv) > 1:
+    addrs += [(x, _OVERLORD_PORT) for x in sys.argv[1:]]
+
+  g = Ghost(addrs)
+  g.Start()
+
+
+if __name__ == '__main__':
+  main()