overlord: implement file download function

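Implement a file download function that allows downloading files from
the DUT, either on request from the overlord server or by running
`ghost --download FILE` in a terminal opened through overlord.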

BUG=chromium:465674
TEST=manually

Change-Id: I49eca5c413a5328d48982953ea9d502fefb3b53d
Reviewed-on: https://chromium-review.googlesource.com/275043
Tested-by: Wei-Ning Huang <wnhuang@chromium.org>
Reviewed-by: Hsu Wei-Cheng <mojahsu@chromium.org>
Commit-Queue: Wei-Ning Huang <wnhuang@chromium.org>
diff --git a/py/tools/ghost.py b/py/tools/ghost.py
index 38614db..6b3b027 100755
--- a/py/tools/ghost.py
+++ b/py/tools/ghost.py
@@ -13,6 +13,7 @@
 import Queue
 import re
 import select
+import signal
 import socket
 import subprocess
 import sys
@@ -43,6 +44,8 @@
 _CONTROL_START = 128
 _CONTROL_END = 129
 
+_BLOCK_SIZE = 4096
+
 RESPONSE_SUCCESS = 'success'
 RESPONSE_FAILED = 'failed'
 
@@ -60,20 +63,21 @@
   Ghost provide terminal/shell/logcat functionality and manages the client
   side connectivity.
   """
-  NONE, AGENT, TERMINAL, SHELL, LOGCAT = range(5)
+  NONE, AGENT, TERMINAL, SHELL, LOGCAT, FILE = range(6)
 
   MODE_NAME = {
       NONE: 'NONE',
       AGENT: 'Agent',
       TERMINAL: 'Terminal',
       SHELL: 'Shell',
-      LOGCAT: 'Logcat'
+      LOGCAT: 'Logcat',
+      FILE: 'File'
       }
 
   RANDOM_MID = '##random_mid##'
 
-  def __init__(self, overlord_addrs, mode=AGENT, mid=None, sid=None,
-               command=None):
+  def __init__(self, overlord_addrs, mode=AGENT, mid=None, sid=None, bid=None,
+               command=None, file_op=None):
     """Constructor.
 
     Args:
@@ -81,13 +85,18 @@
       mode: client mode, either AGENT, SHELL or LOGCAT
       mid: a str to set for machine ID. If mid equals Ghost.RANDOM_MID, machine
         id is randomly generated.
-      sid: session id. If the connection is requested by overlord, sid should
+      sid: session ID. If the connection is requested by overlord, sid should
         be set to the corresponding session id assigned by overlord.
-      shell: the command to execute when we are in SHELL mode.
+      bid: browser ID. Identifies the browser which started the session.
+      command: the command to execute when we are in SHELL mode.
+      file_op: a tuple (action, filepath). action is either 'download' or
+        'upload'.
     """
-    assert mode in [Ghost.AGENT, Ghost.TERMINAL, Ghost.SHELL]
+    assert mode in [Ghost.AGENT, Ghost.TERMINAL, Ghost.SHELL, Ghost.FILE]
     if mode == Ghost.SHELL:
       assert command is not None
+    if mode == Ghost.FILE:
+      assert file_op is not None
 
     self._overlord_addrs = overlord_addrs
     self._connected_addr = None
@@ -96,13 +105,17 @@
     self._sock = None
     self._machine_id = self.GetMachineID()
     self._client_id = sid if sid is not None else str(uuid.uuid4())
+    self._browser_id = bid
     self._properties = {}
     self._shell_command = command
+    self._file_op = file_op
     self._buf = ''
     self._requests = {}
     self._reset = threading.Event()
     self._last_ping = 0
     self._queue = Queue.Queue()
+    self._download_queue = Queue.Queue()
+    self._session_map = {}  # Stores the mapping between ttyname and browser_id
 
   def LoadPropertiesFromFile(self, filename):
     try:
@@ -111,7 +124,7 @@
     except Exception as e:
       logging.exception('LoadPropertiesFromFile: ' + str(e))
 
-  def SpawnGhost(self, mode, sid, command=None):
+  def SpawnGhost(self, mode, sid=None, bid=None, command=None, file_op=None):
     """Spawn a child ghost with specific mode.
 
     Returns:
@@ -119,7 +132,8 @@
     """
     pid = os.fork()
     if pid == 0:
-      g = Ghost([self._connected_addr], mode, Ghost.RANDOM_MID, sid, command)
+      g = Ghost([self._connected_addr], mode, Ghost.RANDOM_MID, sid, bid,
+                command, file_op)
       g.Start()
       sys.exit(0)
     else:
@@ -191,7 +205,7 @@
       import factory_common  # pylint: disable=W0612
       from cros.factory.test import event_log
       with open(event_log.DEVICE_ID_PATH) as f:
-        return f.read()
+        return f.read().strip()
     except Exception:
       pass
 
@@ -233,7 +247,7 @@
   def SendRequest(self, name, args, handler=None,
                   timeout=_REQUEST_TIMEOUT_SECS):
     if handler and not callable(handler):
-      raise RequestError('Invalid requiest handler for msg "%s"' % name)
+      raise RequestError('Invalid request handler for msg "%s"' % name)
 
     rid = str(uuid.uuid4())
     msg = {'rid': rid, 'timeout': timeout, 'name': name, 'params': args}
@@ -263,9 +277,23 @@
 
     pid, fd = os.forkpty()
     if pid == 0:
+      # Register the mapping of browser_id and ttyname
+      ttyname = os.readlink('/proc/%d/fd/0' % os.getpid())
+      try:
+        server = GhostRPCServer()
+        server.RegisterTTY(self._browser_id, ttyname)
+      except Exception:
+        # If ghost is launched without RPC server, the call will fail but we
+        # can ignore it.
+        pass
+
+      # The directory that contains the current running ghost script
+      script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
+
       env = os.environ.copy()
       env['USER'] = os.getenv('USER', 'root')
       env['HOME'] = os.getenv('HOME', '/root')
+      env['PATH'] = os.getenv('PATH') + ':%s' % script_dir
       os.chdir(env['HOME'])
       os.execve(_SHELL, [_SHELL], env)
     else:
@@ -351,6 +379,50 @@
       logging.info('SpawnShellServer: terminated')
       sys.exit(0)
 
+  def InitiateFileOperation(self, _):
+    """Announce the pending file operation to the overlord server.
+
+    This is the FILE-mode entry of the mode-to-handler table; its argument is
+    unused. For a 'download' operation it sends a 'request_to_download'
+    request carrying the browser ID, the file's base name and its size in
+    bytes. Any other action is logged as unsupported.
+    """
+    action, filepath = self._file_op
+    if action == 'download':
+      size = os.stat(filepath).st_size
+      self.SendRequest('request_to_download',
+                       {'bid': self._browser_id,
+                        'filename': os.path.basename(filepath),
+                        'size': size})
+    else:
+      logging.error('InitiateFileOperation: unsupported action %r', action)
+
+  def StartDownloadServer(self):
+    """Stream the file named in self._file_op to overlord, then exit.
+
+    The file is read in _BLOCK_SIZE chunks and written to the overlord socket.
+    Errors are logged rather than raised; the socket is always closed and the
+    process then exits with status 0.
+    """
+    logging.info('StartDownloadServer: started')
+
+    try:
+      with open(self._file_op[1], 'rb') as f:
+        while True:
+          data = f.read(_BLOCK_SIZE)
+          if not data:
+            break
+          # socket.send() may write only part of the buffer (its return value
+          # was ignored), silently truncating the download; sendall() keeps
+          # sending until the whole chunk has been transmitted.
+          self._sock.sendall(data)
+    except Exception as e:
+      logging.error('StartDownloadServer: %s', e)
+    finally:
+      self._sock.close()
+
+    logging.info('StartDownloadServer: terminated')
+    sys.exit(0)

   def Ping(self):
     def timeout_handler(x):
@@ -362,12 +415,25 @@

   def HandleRequest(self, msg):
     if msg['name'] == 'terminal':
-      self.SpawnGhost(self.TERMINAL, msg['params']['sid'])
+      self.SpawnGhost(self.TERMINAL, msg['params']['sid'],
+                      bid=msg['params']['bid'])
       self.SendResponse(msg, RESPONSE_SUCCESS)
     elif msg['name'] == 'shell':
       self.SpawnGhost(self.SHELL, msg['params']['sid'],
-                      msg['params']['command'])
+                      command=msg['params']['command'])
       self.SendResponse(msg, RESPONSE_SUCCESS)
+    elif msg['name'] == 'file_download':
+      self.SpawnGhost(self.FILE, msg['params']['sid'],
+                      file_op=('download', msg['params']['filename']))
+      self.SendResponse(msg, RESPONSE_SUCCESS)
+    elif msg['name'] == 'clear_to_download':
+      # StartDownloadServer() always exits the process, and only a FILE-mode
+      # ghost has a file to send, so ignore this request in any other mode.
+      if self._mode == self.FILE:
+        self.StartDownloadServer()
+      else:
+        logging.error('HandleRequest: unexpected clear_to_download in %s mode',
+                      self.MODE_NAME[self._mode])

   def HandleResponse(self, response):
     rid = str(response['rid'])
@@ -402,9 +462,29 @@
     for rid in self._requests.keys()[:]:
       request_time, timeout, handler = self._requests[rid]
       if self.Timestamp() - request_time > timeout:
-        handler(None)
+        if callable(handler):
+          handler(None)
+        else:
+          logging.error('Request %s timeout', rid)
         del self._requests[rid]
 
+  def InitiateDownload(self):
+    """Spawn a FILE-mode ghost for the next queued download request.
+
+    Pops one (ttyname, filename) pair queued by AddToDownloadQueue() and looks
+    up the browser ID that RegisterTTY() recorded for that terminal. Requests
+    from terminals that were never registered (e.g. a shell not opened through
+    overlord) are logged and dropped instead of raising KeyError out of the
+    Listen() loop.
+    """
+    ttyname, filename = self._download_queue.get()
+    if ttyname not in self._session_map:
+      logging.error('InitiateDownload: no browser registered for %s, '
+                    'dropping download of %s', ttyname, filename)
+      return
+    bid = self._session_map[ttyname]
+    self.SpawnGhost(self.FILE, bid=bid, file_op=('download', filename))
+
   def Listen(self):
     try:
       while True:
@@ -414,10 +482,14 @@
           self._buf += self._sock.recv(_BUFSIZE)
           self.ParseMessage()
 
-        if self.Timestamp() - self._last_ping > _PING_INTERVAL:
+        if (self._mode == self.AGENT and
+            self.Timestamp() - self._last_ping > _PING_INTERVAL):
           self.Ping()
         self.ScanForTimeoutRequests()
 
+        if not self._download_queue.empty():
+          self.InitiateDownload()
+
         if self._reset.is_set():
           self.Reset()
           break
@@ -455,7 +527,8 @@
         handler = {
             Ghost.AGENT: registered,
             Ghost.TERMINAL: self.SpawnPTYServer,
-            Ghost.SHELL: self.SpawnShellServer
+            Ghost.SHELL: self.SpawnShellServer,
+            Ghost.FILE: self.InitiateFileOperation,
             }[self._mode]
 
         # Machine ID may change if MAC address is used (USB-ethernet dongle
@@ -478,6 +551,20 @@
     logging.info('Received reconnect request from RPC server, reconnecting...')
     self._reset.set()
 
+  def AddToDownloadQueue(self, ttyname, filename):
+    """RPC: queue `filename` for download to the browser owning `ttyname`.
+
+    The request is consumed by the Listen() loop via InitiateDownload().
+    NOTE(review): no permission check happens here; only the `--download`
+    command-line path (DownloadFile) checks os.access(). Any client that can
+    reach the RPC server can queue any path readable by this process.
+    """
+    self._download_queue.put((ttyname, filename))
+
+  def RegisterTTY(self, browser_id, ttyname):
+    """RPC: record that terminal `ttyname` belongs to browser `browser_id`."""
+    self._session_map[ttyname] = browser_id
+
   def StartLanDiscovery(self):
     """Start to listen to LAN discovery packet at
     _OVERLORD_LAN_DISCOVERY_PORT."""
@@ -523,9 +602,12 @@
     t.start()
 
   def StartRPCServer(self):
+    logging.info("RPC Server: started")
     rpc_server = SimpleJSONRPCServer((_DEFAULT_BIND_ADDRESS, _GHOST_RPC_PORT),
                                      logRequests=False)
     rpc_server.register_function(self.Reconnect, 'Reconnect')
+    rpc_server.register_function(self.RegisterTTY, 'RegisterTTY')
+    rpc_server.register_function(self.AddToDownloadQueue, 'AddToDownloadQueue')
     t = threading.Thread(target=rpc_server.serve_forever)
     t.daemon = True
     t.start()
@@ -541,6 +623,9 @@
     logging.info('MID: %s', self._machine_id)
     logging.info('CID: %s', self._client_id)
 
+    # We don't care about child processes' return codes, so no wait is needed.
+    signal.signal(signal.SIGCHLD, signal.SIG_IGN)
+
     if lan_disc:
       self.StartLanDiscovery()
 
@@ -575,6 +660,51 @@
   return jsonrpclib.Server('http://localhost:%d' % _GHOST_RPC_PORT)
 
 
+def DownloadFile(filename):
+  """Ask the local ghost agent to send a file to the browser.
+
+  The browser is identified through the terminal attached to stdin, which the
+  agent maps to a browser ID (see Ghost.RegisterTTY). This function always
+  terminates the process: exit status 1 on any error, 0 once the request has
+  been queued.
+
+  Args:
+    filename: path of the file to download, absolute or relative to the
+      current directory.
+  """
+  filepath = os.path.abspath(filename)
+  if not os.path.exists(filepath):
+    logging.error("file `%s' does not exist", filename)
+    sys.exit(1)
+
+  # The FILE-mode ghost reports os.stat() size and then streams the file, so
+  # only regular files make sense; reject directories and special files here
+  # where the user can see the error.
+  if not os.path.isfile(filepath):
+    logging.error("`%s' is not a regular file", filename)
+    sys.exit(1)
+
+  # Check if we actually have permission to read the file
+  if not os.access(filepath, os.R_OK):
+    logging.error("can not open %s for reading", filepath)
+    sys.exit(1)
+
+  # Downloads are routed by terminal name, so stdin must be a terminal.
+  try:
+    ttyname = os.ttyname(0)
+  except OSError:
+    logging.error('stdin is not a terminal, can not determine the session')
+    sys.exit(1)
+
+  try:
+    server = GhostRPCServer()
+    server.AddToDownloadQueue(ttyname, filepath)
+  except Exception as e:
+    logging.error('can not reach the ghost RPC server: %s', e)
+    sys.exit(1)
+  sys.exit(0)
+
+
 def main():
   logger = logging.getLogger()
   logger.setLevel(logging.INFO)
@@ -589,13 +690,19 @@
   parser.add_argument('--no-rpc-server', dest='rpc_server',
                       action='store_false', default=True,
                       help='disable RPC server')
-  parser.add_argument('--prop-file', dest='prop_file', type=str, default=None,
+  parser.add_argument('--prop-file', metavar='PROP_FILE', dest="prop_file",
+                      type=str, default=None,
                       help='file containing the JSON representation of client '
                            'properties')
+  parser.add_argument('--download', metavar='FILE', dest='download', type=str,
+                      default=None, help='file to download')
   parser.add_argument('overlord_ip', metavar='OVERLORD_IP', type=str,
                       nargs='*', help='overlord server address')
   args = parser.parse_args()
 
+  if args.download:
+    DownloadFile(args.download)
+
   addrs = [('localhost', _OVERLORD_PORT)]
   addrs += [(x, _OVERLORD_PORT) for x in args.overlord_ip]