overlord: implement file download function
Implement a file download function that allows downloading files from the DUT.
BUG=chromium:465674
TEST=manually
Change-Id: I49eca5c413a5328d48982953ea9d502fefb3b53d
Reviewed-on: https://chromium-review.googlesource.com/275043
Tested-by: Wei-Ning Huang <wnhuang@chromium.org>
Reviewed-by: Hsu Wei-Cheng <mojahsu@chromium.org>
Commit-Queue: Wei-Ning Huang <wnhuang@chromium.org>
diff --git a/py/tools/ghost.py b/py/tools/ghost.py
index 38614db..6b3b027 100755
--- a/py/tools/ghost.py
+++ b/py/tools/ghost.py
@@ -13,6 +13,7 @@
import Queue
import re
import select
+import signal
import socket
import subprocess
import sys
@@ -43,6 +44,8 @@
_CONTROL_START = 128
_CONTROL_END = 129
+_BLOCK_SIZE = 4096
+
RESPONSE_SUCCESS = 'success'
RESPONSE_FAILED = 'failed'
@@ -60,20 +63,21 @@
Ghost provide terminal/shell/logcat functionality and manages the client
side connectivity.
"""
- NONE, AGENT, TERMINAL, SHELL, LOGCAT = range(5)
+ NONE, AGENT, TERMINAL, SHELL, LOGCAT, FILE = range(6)
MODE_NAME = {
NONE: 'NONE',
AGENT: 'Agent',
TERMINAL: 'Terminal',
SHELL: 'Shell',
- LOGCAT: 'Logcat'
+ LOGCAT: 'Logcat',
+ FILE: 'File'
}
RANDOM_MID = '##random_mid##'
- def __init__(self, overlord_addrs, mode=AGENT, mid=None, sid=None,
- command=None):
+ def __init__(self, overlord_addrs, mode=AGENT, mid=None, sid=None, bid=None,
+ command=None, file_op=None):
"""Constructor.
Args:
@@ -81,13 +85,18 @@
mode: client mode, either AGENT, SHELL or LOGCAT
mid: a str to set for machine ID. If mid equals Ghost.RANDOM_MID, machine
id is randomly generated.
- sid: session id. If the connection is requested by overlord, sid should
+ sid: session ID. If the connection is requested by overlord, sid should
be set to the corresponding session id assigned by overlord.
- shell: the command to execute when we are in SHELL mode.
+ bid: browser ID. Identifies the browser which started the session.
+ command: the command to execute when we are in SHELL mode.
+ file_op: a tuple (action, filepath). action is either 'download' or
+ 'upload'.
"""
- assert mode in [Ghost.AGENT, Ghost.TERMINAL, Ghost.SHELL]
+ assert mode in [Ghost.AGENT, Ghost.TERMINAL, Ghost.SHELL, Ghost.FILE]
if mode == Ghost.SHELL:
assert command is not None
+ if mode == Ghost.FILE:
+ assert file_op is not None
self._overlord_addrs = overlord_addrs
self._connected_addr = None
@@ -96,13 +105,17 @@
self._sock = None
self._machine_id = self.GetMachineID()
self._client_id = sid if sid is not None else str(uuid.uuid4())
+ self._browser_id = bid
self._properties = {}
self._shell_command = command
+ self._file_op = file_op
self._buf = ''
self._requests = {}
self._reset = threading.Event()
self._last_ping = 0
self._queue = Queue.Queue()
+ self._download_queue = Queue.Queue()
+ self._session_map = {} # Stores the mapping between ttyname and browser_id
def LoadPropertiesFromFile(self, filename):
try:
@@ -111,7 +124,7 @@
except Exception as e:
logging.exception('LoadPropertiesFromFile: ' + str(e))
- def SpawnGhost(self, mode, sid, command=None):
+ def SpawnGhost(self, mode, sid=None, bid=None, command=None, file_op=None):
"""Spawn a child ghost with specific mode.
Returns:
@@ -119,7 +132,8 @@
"""
pid = os.fork()
if pid == 0:
- g = Ghost([self._connected_addr], mode, Ghost.RANDOM_MID, sid, command)
+ g = Ghost([self._connected_addr], mode, Ghost.RANDOM_MID, sid, bid,
+ command, file_op)
g.Start()
sys.exit(0)
else:
@@ -191,7 +205,7 @@
import factory_common # pylint: disable=W0612
from cros.factory.test import event_log
with open(event_log.DEVICE_ID_PATH) as f:
- return f.read()
+ return f.read().strip()
except Exception:
pass
@@ -233,7 +247,7 @@
def SendRequest(self, name, args, handler=None,
timeout=_REQUEST_TIMEOUT_SECS):
if handler and not callable(handler):
- raise RequestError('Invalid requiest handler for msg "%s"' % name)
+ raise RequestError('Invalid request handler for msg "%s"' % name)
rid = str(uuid.uuid4())
msg = {'rid': rid, 'timeout': timeout, 'name': name, 'params': args}
@@ -263,9 +277,23 @@
pid, fd = os.forkpty()
if pid == 0:
+ # Register the mapping of browser_id and ttyname
+ ttyname = os.readlink('/proc/%d/fd/0' % os.getpid())
+ try:
+ server = GhostRPCServer()
+ server.RegisterTTY(self._browser_id, ttyname)
+ except Exception:
+ # If ghost is launched without RPC server, the call will fail but we
+ # can ignore it.
+ pass
+
+ # The directory that contains the current running ghost script
+ script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
+
env = os.environ.copy()
env['USER'] = os.getenv('USER', 'root')
env['HOME'] = os.getenv('HOME', '/root')
+ env['PATH'] = os.getenv('PATH') + ':%s' % script_dir
os.chdir(env['HOME'])
os.execve(_SHELL, [_SHELL], env)
else:
@@ -351,6 +379,31 @@
logging.info('SpawnShellServer: terminated')
sys.exit(0)
+  def InitiateFileOperation(self, _):
+    """Ask overlord to begin the operation this FILE-mode ghost was spawned for.
+
+    Only the 'download' action is acted upon: the size and base name of the
+    target file are sent in a 'request_to_download' request tagged with this
+    ghost's browser ID. Any other action is currently a no-op.
+
+    Args:
+      _: argument supplied by the caller; unused.
+    """
+    action = self._file_op[0]
+    if action != 'download':
+      return
+
+    filepath = self._file_op[1]
+    file_size = os.stat(filepath).st_size
+    params = {
+        'bid': self._browser_id,
+        'filename': os.path.basename(filepath),
+        'size': file_size,
+    }
+    self.SendRequest('request_to_download', params)
+
+  def StartDownloadServer(self):
+    """Stream the requested file over the ghost socket, then exit.
+
+    Invoked when overlord sends 'clear_to_download'. The file named by
+    self._file_op[1] is read in _BLOCK_SIZE chunks and written to the socket.
+    Errors are logged rather than raised. The socket is always closed
+    afterwards, and the method finishes by calling sys.exit(0).
+    """
+    logging.info('StartDownloadServer: started')
+
+    try:
+      with open(self._file_op[1], 'rb') as f:
+        while True:
+          data = f.read(_BLOCK_SIZE)
+          if not data:
+            break
+          # send() may transmit only part of the buffer and report the count;
+          # sendall() keeps writing until every byte is sent, so the download
+          # is never silently truncated.
+          self._sock.sendall(data)
+    except Exception as e:
+      logging.error('StartDownloadServer: %s', e)
+    finally:
+      self._sock.close()
+
+    logging.info('StartDownloadServer: terminated')
+    sys.exit(0)
def Ping(self):
def timeout_handler(x):
@@ -362,12 +415,19 @@
def HandleRequest(self, msg):
if msg['name'] == 'terminal':
- self.SpawnGhost(self.TERMINAL, msg['params']['sid'])
+ self.SpawnGhost(self.TERMINAL, msg['params']['sid'],
+ bid=msg['params']['bid'])
self.SendResponse(msg, RESPONSE_SUCCESS)
elif msg['name'] == 'shell':
self.SpawnGhost(self.SHELL, msg['params']['sid'],
- msg['params']['command'])
+ command=msg['params']['command'])
self.SendResponse(msg, RESPONSE_SUCCESS)
+ elif msg['name'] == 'file_download':
+ self.SpawnGhost(self.FILE, msg['params']['sid'],
+ file_op=('download', msg['params']['filename']))
+ self.SendResponse(msg, RESPONSE_SUCCESS)
+ elif msg['name'] == 'clear_to_download':
+ self.StartDownloadServer()
def HandleResponse(self, response):
rid = str(response['rid'])
@@ -402,9 +462,17 @@
for rid in self._requests.keys()[:]:
request_time, timeout, handler = self._requests[rid]
if self.Timestamp() - request_time > timeout:
- handler(None)
+ if callable(handler):
+ handler(None)
+ else:
+ logging.error('Request %s timeout', rid)
del self._requests[rid]
+  def InitiateDownload(self):
+    """Spawn a FILE-mode ghost for the next queued download request.
+
+    Takes one (ttyname, filename) entry from the download queue, which is
+    filled by the AddToDownloadQueue RPC. It then looks up the browser ID
+    registered for that TTY via the RegisterTTY RPC.
+
+    TTY registration is best-effort, so the TTY may be unknown. In that case
+    the request is dropped with an error log. Raising KeyError instead would
+    abort the Listen loop.
+    """
+    ttyname, filename = self._download_queue.get()
+    if ttyname not in self._session_map:
+      logging.error('InitiateDownload: no browser registered for tty %s, '
+                    'dropping download request for %s', ttyname, filename)
+      return
+    bid = self._session_map[ttyname]
+    self.SpawnGhost(self.FILE, bid=bid, file_op=('download', filename))
+
def Listen(self):
try:
while True:
@@ -414,10 +482,14 @@
self._buf += self._sock.recv(_BUFSIZE)
self.ParseMessage()
- if self.Timestamp() - self._last_ping > _PING_INTERVAL:
+ if (self._mode == self.AGENT and
+ self.Timestamp() - self._last_ping > _PING_INTERVAL):
self.Ping()
self.ScanForTimeoutRequests()
+ if not self._download_queue.empty():
+ self.InitiateDownload()
+
if self._reset.is_set():
self.Reset()
break
@@ -455,7 +527,8 @@
handler = {
Ghost.AGENT: registered,
Ghost.TERMINAL: self.SpawnPTYServer,
- Ghost.SHELL: self.SpawnShellServer
+ Ghost.SHELL: self.SpawnShellServer,
+ Ghost.FILE: self.InitiateFileOperation,
}[self._mode]
# Machine ID may change if MAC address is used (USB-ethernet dongle
@@ -478,6 +551,12 @@
logging.info('Received reconnect request from RPC server, reconnecting...')
self._reset.set()
+  def AddToDownloadQueue(self, ttyname, filename):
+    """RPC endpoint: enqueue a download request for the Listen loop.
+
+    Args:
+      ttyname: name of the TTY whose terminal issued the request.
+      filename: path of the file to send (DownloadFile passes an absolute
+        path).
+    """
+    request = (ttyname, filename)
+    self._download_queue.put(request)
+
+  def RegisterTTY(self, browser_id, ttyname):
+    """RPC endpoint: record that ttyname belongs to the browser browser_id.
+
+    A later registration for the same ttyname replaces the earlier one.
+    """
+    self._session_map.update({ttyname: browser_id})
+
def StartLanDiscovery(self):
"""Start to listen to LAN discovery packet at
_OVERLORD_LAN_DISCOVERY_PORT."""
@@ -523,9 +602,12 @@
t.start()
def StartRPCServer(self):
+ logging.info("RPC Server: started")
rpc_server = SimpleJSONRPCServer((_DEFAULT_BIND_ADDRESS, _GHOST_RPC_PORT),
logRequests=False)
rpc_server.register_function(self.Reconnect, 'Reconnect')
+ rpc_server.register_function(self.RegisterTTY, 'RegisterTTY')
+ rpc_server.register_function(self.AddToDownloadQueue, 'AddToDownloadQueue')
t = threading.Thread(target=rpc_server.serve_forever)
t.daemon = True
t.start()
@@ -541,6 +623,9 @@
logging.info('MID: %s', self._machine_id)
logging.info('CID: %s', self._client_id)
+ # We don't care about child processes' return codes, so no wait is needed.
+ signal.signal(signal.SIGCHLD, signal.SIG_IGN)
+
if lan_disc:
self.StartLanDiscovery()
@@ -575,6 +660,22 @@
return jsonrpclib.Server('http://localhost:%d' % _GHOST_RPC_PORT)
+def DownloadFile(filename):
+  """Queue a local file for download through the running ghost agent.
+
+  Called from main() when --download is given. The file is validated locally.
+  It is then handed to the ghost RPC server together with the name of the TTY
+  on stdin, which the agent maps to a browser ID.
+
+  Exits with status 1 on any validation error and with status 0 once the
+  request has been queued. Errors from the RPC call itself propagate.
+
+  Args:
+    filename: path of the file to download, relative or absolute.
+  """
+  filepath = os.path.abspath(filename)
+  if not os.path.exists(filepath):
+    logging.error("file `%s' does not exist", filename)
+    sys.exit(1)
+
+  # Directories and other non-regular files pass os.path.exists(), but the
+  # download server cannot open and stream them, so reject them here.
+  if not os.path.isfile(filepath):
+    logging.error("`%s' is not a regular file", filename)
+    sys.exit(1)
+
+  # Check if we actually have permission to read the file
+  if not os.access(filepath, os.R_OK):
+    logging.error("can not open %s for reading", filepath)
+    sys.exit(1)
+
+  # The agent identifies the requesting browser by TTY name, so stdin must be
+  # a terminal; os.ttyname() raises OSError otherwise.
+  try:
+    ttyname = os.ttyname(0)
+  except OSError:
+    logging.error('stdin is not a terminal; --download must be run from a '
+                  'ghost terminal session')
+    sys.exit(1)
+
+  server = GhostRPCServer()
+  server.AddToDownloadQueue(ttyname, filepath)
+  sys.exit(0)
+
+
def main():
logger = logging.getLogger()
logger.setLevel(logging.INFO)
@@ -589,13 +690,19 @@
parser.add_argument('--no-rpc-server', dest='rpc_server',
action='store_false', default=True,
help='disable RPC server')
- parser.add_argument('--prop-file', dest='prop_file', type=str, default=None,
+ parser.add_argument('--prop-file', metavar='PROP_FILE', dest="prop_file",
+ type=str, default=None,
help='file containing the JSON representation of client '
'properties')
+ parser.add_argument('--download', metavar='FILE', dest='download', type=str,
+ default=None, help='file to download')
parser.add_argument('overlord_ip', metavar='OVERLORD_IP', type=str,
nargs='*', help='overlord server address')
args = parser.parse_args()
+ if args.download:
+ DownloadFile(args.download)
+
addrs = [('localhost', _OVERLORD_PORT)]
addrs += [(x, _OVERLORD_PORT) for x in args.overlord_ip]