ghost: update machine ID with mlb_serial_number once it's set
When mlb_number is set for the device in the factory test, ghost should
automatically update the machine ID to the MLB number and reconnect with
overlord. We implement a JSON-RPC server in Ghost containing the
'Reconnect' method, which is later called by the 'scan' pytest.
BUG=chromium:470876
TEST=run test on device; the ghost client should reconnect with the new
MLB number once the MLB number is set.
Change-Id: Idef9b16f77fcca432a3a9d3faac71d852813a90a
Reviewed-on: https://chromium-review.googlesource.com/265403
Tested-by: Wei-Ning Huang <wnhuang@chromium.org>
Reviewed-by: Hung-Te Lin <hungte@chromium.org>
Commit-Queue: Wei-Ning Huang <wnhuang@chromium.org>
diff --git a/py/tools/ghost.py b/py/tools/ghost.py
index de5bfb8..5d7006d 100755
--- a/py/tools/ghost.py
+++ b/py/tools/ghost.py
@@ -19,6 +19,11 @@
import time
import uuid
+import jsonrpclib
+from jsonrpclib.SimpleJSONRPCServer import SimpleJSONRPCServer
+
+
+_GHOST_RPC_PORT = 4499
_OVERLORD_PORT = 4455
_OVERLORD_LAN_DISCOVERY_PORT = 4456
@@ -30,6 +35,7 @@
_PING_INTERVAL = 5
_REQUEST_TIMEOUT_SECS = 60
_SHELL = os.getenv('SHELL', '/bin/bash')
+_DEFAULT_BIND_ADDRESS = '0.0.0.0'
RESPONSE_SUCCESS = 'success'
RESPONSE_FAILED = 'failed'
@@ -86,7 +92,7 @@
self._shell_command = command
self._buf = ''
self._requests = {}
- self._reset = False
+ self._reset = threading.Event()
self._last_ping = 0
self._queue = Queue.Queue()
@@ -106,7 +112,7 @@
pid = os.fork()
if pid == 0:
g = Ghost([self._connected_addr], mode, True, sid, command)
- g.Start(True)
+ g.Start()
sys.exit(0)
else:
return pid
@@ -179,7 +185,7 @@
def Reset(self):
"""Reset state and clear request handlers."""
- self._reset = False
+ self._reset.clear()
self._buf = ""
self._last_ping = 0
self._requests = {}
@@ -337,7 +343,7 @@
self.Ping()
self.ScanForTimeoutRequests()
- if self._reset:
+ if self._reset.is_set():
self.Reset()
break
except socket.error:
@@ -358,7 +364,7 @@
non_local['addr'] = addr
def registered(response):
if response is None:
- self._reset = True
+ self._reset.set()
raise RuntimeError('Register request timeout')
logging.info('Registered with Overlord at %s:%d', *non_local['addr'])
self._queue.put("pause", True)
@@ -393,60 +399,77 @@
raise RuntimeError("Cannot connect to any server")
+ def Reconnect(self):
+ logging.info('Received reconnect request from RPC server, reconnecting...')
+ self._reset.set()
+
def StartLanDiscovery(self):
"""Start to listen to LAN discovery packet at
_OVERLORD_LAN_DISCOVERY_PORT."""
- s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
- s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
- try:
- s.bind(('0.0.0.0', _OVERLORD_LAN_DISCOVERY_PORT))
- except socket.error as e:
- logging.error("LAN discovery: %s, abort", e)
- return
- logging.info('LAN Discovery: started')
- while True:
- rd, _, _ = select.select([s], [], [], 1)
-
- if s in rd:
- data, source_addr = s.recvfrom(_BUFSIZE)
- parts = data.split()
- if parts[0] == 'OVERLORD':
- ip, port = parts[1].split(':')
- if len(ip.strip()) == 0:
- ip = source_addr[0]
- self._queue.put((ip, int(port)), True)
-
+ def thread_func():
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
try:
- obj = self._queue.get(False)
- except Queue.Empty:
- pass
- else:
- if type(obj) is not str:
- self._queue.put(obj)
- elif obj == 'pause':
- logging.info('LAN Discovery: paused')
- while True:
- obj = self._queue.get(True)
- if obj == 'resume':
- logging.info('LAN Discovery: resumed')
- break
+ s.bind(('0.0.0.0', _OVERLORD_LAN_DISCOVERY_PORT))
+ except socket.error as e:
+ logging.error("LAN discovery: %s, abort", e)
+ return
+
+ logging.info('LAN Discovery: started')
+ while True:
+ rd, _, _ = select.select([s], [], [], 1)
+
+ if s in rd:
+ data, source_addr = s.recvfrom(_BUFSIZE)
+ parts = data.split()
+ if parts[0] == 'OVERLORD':
+ ip, port = parts[1].split(':')
+ if not ip:
+ ip = source_addr[0]
+ self._queue.put((ip, int(port)), True)
+
+ try:
+ obj = self._queue.get(False)
+ except Queue.Empty:
+ pass
+ else:
+ if type(obj) is not str:
+ self._queue.put(obj)
+ elif obj == 'pause':
+ logging.info('LAN Discovery: paused')
+ while obj != 'resume':
+ obj = self._queue.get(True)
+ logging.info('LAN Discovery: resumed')
+
+ t = threading.Thread(target=thread_func)
+ t.daemon = True
+ t.start()
+
+ def StartRPCServer(self):
+ rpc_server = SimpleJSONRPCServer((_DEFAULT_BIND_ADDRESS, _GHOST_RPC_PORT),
+ logRequests=False)
+ rpc_server.register_function(self.Reconnect, 'Reconnect')
+ t = threading.Thread(target=rpc_server.serve_forever)
+ t.daemon = True
+ t.start()
def ScanGateway(self):
for addr in [(x, _OVERLORD_PORT) for x in self.GetGateWayIP()]:
if addr not in self._overlord_addrs:
self._overlord_addrs.append(addr)
- def Start(self, no_lan_disc=False):
+ def Start(self, lan_disc=False, rpc_server=False):
logging.info('%s started', self.MODE_NAME[self._mode])
logging.info('MID: %s', self._machine_id)
logging.info('CID: %s', self._client_id)
- if not no_lan_disc:
- t = threading.Thread(target=self.StartLanDiscovery)
- t.daemon = True
- t.start()
+ if lan_disc:
+ self.StartLanDiscovery()
+
+ if rpc_server:
+ self.StartRPCServer()
try:
while True:
@@ -472,6 +495,10 @@
sys.exit(0)
+def GhostRPCServer():
+ return jsonrpclib.Server('http://localhost:%d' % _GHOST_RPC_PORT)
+
+
def main():
logger = logging.getLogger()
logger.setLevel(logging.INFO)
@@ -479,8 +506,11 @@
parser = argparse.ArgumentParser()
parser.add_argument('--rand-mid', dest='rand_mid', action='store_true',
default=False, help='use random machine ID')
- parser.add_argument('--no-lan-disc', dest='no_lan_disc', action='store_true',
- default=False, help='disable LAN discovery')
+ parser.add_argument('--no-lan-disc', dest='lan_disc', action='store_false',
+ default=True, help='disable LAN discovery')
+ parser.add_argument('--no-rpc-server', dest='rpc_server',
+ action='store_false', default=True,
+ help='disable RPC server')
parser.add_argument("--prop-file", dest="prop_file", type=str, default=None,
help='file containing the JSON representation of client '
'properties')
@@ -494,7 +524,7 @@
g = Ghost(addrs, Ghost.AGENT, args.rand_mid)
if args.prop_file:
g.LoadPropertiesFromFile(args.prop_file)
- g.Start(args.no_lan_disc)
+ g.Start(args.lan_disc, args.rpc_server)
if __name__ == '__main__':