ghost: update machine ID with mlb_serial_number once it's set

When mlb_number is set for the device in the factory test, ghost should
automatically update the machine ID to the MLB number and reconnect with
overlord. We implement a JSON-RPC server in Ghost containing the
'Reconnect' method, which is later called by the 'scan' pytest.

BUG=chromium:470876
TEST=run test on device; the ghost client should reconnect with the new
MLB number once the MLB number is set.

Change-Id: Idef9b16f77fcca432a3a9d3faac71d852813a90a
Reviewed-on: https://chromium-review.googlesource.com/265403
Tested-by: Wei-Ning Huang <wnhuang@chromium.org>
Reviewed-by: Hung-Te Lin <hungte@chromium.org>
Commit-Queue: Wei-Ning Huang <wnhuang@chromium.org>
diff --git a/py/tools/ghost.py b/py/tools/ghost.py
index de5bfb8..5d7006d 100755
--- a/py/tools/ghost.py
+++ b/py/tools/ghost.py
@@ -19,6 +19,11 @@
 import time
 import uuid
 
+import jsonrpclib
+from jsonrpclib.SimpleJSONRPCServer import SimpleJSONRPCServer
+
+
+_GHOST_RPC_PORT = 4499
 
 _OVERLORD_PORT = 4455
 _OVERLORD_LAN_DISCOVERY_PORT = 4456
@@ -30,6 +35,7 @@
 _PING_INTERVAL = 5
 _REQUEST_TIMEOUT_SECS = 60
 _SHELL = os.getenv('SHELL', '/bin/bash')
+_DEFAULT_BIND_ADDRESS = '0.0.0.0'
 
 RESPONSE_SUCCESS = 'success'
 RESPONSE_FAILED = 'failed'
@@ -86,7 +92,7 @@
     self._shell_command = command
     self._buf = ''
     self._requests = {}
-    self._reset = False
+    self._reset = threading.Event()
     self._last_ping = 0
     self._queue = Queue.Queue()
 
@@ -106,7 +112,7 @@
     pid = os.fork()
     if pid == 0:
       g = Ghost([self._connected_addr], mode, True, sid, command)
-      g.Start(True)
+      g.Start()
       sys.exit(0)
     else:
       return pid
@@ -179,7 +185,7 @@
 
   def Reset(self):
     """Reset state and clear request handlers."""
-    self._reset = False
+    self._reset.clear()
     self._buf = ""
     self._last_ping = 0
     self._requests = {}
@@ -337,7 +343,7 @@
           self.Ping()
         self.ScanForTimeoutRequests()
 
-        if self._reset:
+        if self._reset.is_set():
           self.Reset()
           break
     except socket.error:
@@ -358,7 +364,7 @@
       non_local['addr'] = addr
       def registered(response):
         if response is None:
-          self._reset = True
+          self._reset.set()
           raise RuntimeError('Register request timeout')
         logging.info('Registered with Overlord at %s:%d', *non_local['addr'])
         self._queue.put("pause", True)
@@ -393,60 +399,77 @@
 
     raise RuntimeError("Cannot connect to any server")
 
+  def Reconnect(self):
+    logging.info('Received reconnect request from RPC server, reconnecting...')
+    self._reset.set()
+
   def StartLanDiscovery(self):
     """Start to listen to LAN discovery packet at
     _OVERLORD_LAN_DISCOVERY_PORT."""
-    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-    s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
-    try:
-      s.bind(('0.0.0.0', _OVERLORD_LAN_DISCOVERY_PORT))
-    except socket.error as e:
-      logging.error("LAN discovery: %s, abort", e)
-      return
 
-    logging.info('LAN Discovery: started')
-    while True:
-      rd, _, _ = select.select([s], [], [], 1)
-
-      if s in rd:
-        data, source_addr = s.recvfrom(_BUFSIZE)
-        parts = data.split()
-        if parts[0] == 'OVERLORD':
-          ip, port = parts[1].split(':')
-          if len(ip.strip()) == 0:
-            ip = source_addr[0]
-          self._queue.put((ip, int(port)), True)
-
+    def thread_func():
+      s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+      s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+      s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
       try:
-        obj = self._queue.get(False)
-      except Queue.Empty:
-        pass
-      else:
-        if type(obj) is not str:
-          self._queue.put(obj)
-        elif obj == 'pause':
-          logging.info('LAN Discovery: paused')
-          while True:
-            obj = self._queue.get(True)
-            if obj == 'resume':
-              logging.info('LAN Discovery: resumed')
-              break
+        s.bind(('0.0.0.0', _OVERLORD_LAN_DISCOVERY_PORT))
+      except socket.error as e:
+        logging.error("LAN discovery: %s, abort", e)
+        return
+
+      logging.info('LAN Discovery: started')
+      while True:
+        rd, _, _ = select.select([s], [], [], 1)
+
+        if s in rd:
+          data, source_addr = s.recvfrom(_BUFSIZE)
+          parts = data.split()
+          if parts[0] == 'OVERLORD':
+            ip, port = parts[1].split(':')
+            if not ip:
+              ip = source_addr[0]
+            self._queue.put((ip, int(port)), True)
+
+        try:
+          obj = self._queue.get(False)
+        except Queue.Empty:
+          pass
+        else:
+          if type(obj) is not str:
+            self._queue.put(obj)
+          elif obj == 'pause':
+            logging.info('LAN Discovery: paused')
+            while obj != 'resume':
+              obj = self._queue.get(True)
+            logging.info('LAN Discovery: resumed')
+
+    t = threading.Thread(target=thread_func)
+    t.daemon = True
+    t.start()
+
+  def StartRPCServer(self):
+    rpc_server = SimpleJSONRPCServer((_DEFAULT_BIND_ADDRESS, _GHOST_RPC_PORT),
+                                     logRequests=False)
+    rpc_server.register_function(self.Reconnect, 'Reconnect')
+    t = threading.Thread(target=rpc_server.serve_forever)
+    t.daemon = True
+    t.start()
 
   def ScanGateway(self):
     for addr in [(x, _OVERLORD_PORT) for x in self.GetGateWayIP()]:
       if addr not in self._overlord_addrs:
         self._overlord_addrs.append(addr)
 
-  def Start(self, no_lan_disc=False):
+  def Start(self, lan_disc=False, rpc_server=False):
     logging.info('%s started', self.MODE_NAME[self._mode])
     logging.info('MID: %s', self._machine_id)
     logging.info('CID: %s', self._client_id)
 
-    if not no_lan_disc:
-      t = threading.Thread(target=self.StartLanDiscovery)
-      t.daemon = True
-      t.start()
+    if lan_disc:
+      self.StartLanDiscovery()
+
+    if rpc_server:
+      self.StartRPCServer()
 
     try:
       while True:
@@ -472,6 +495,10 @@
       sys.exit(0)
 
 
+def GhostRPCServer():
+  return jsonrpclib.Server('http://localhost:%d' % _GHOST_RPC_PORT)
+
+
 def main():
   logger = logging.getLogger()
   logger.setLevel(logging.INFO)
@@ -479,8 +506,11 @@
   parser = argparse.ArgumentParser()
   parser.add_argument('--rand-mid', dest='rand_mid', action='store_true',
                       default=False, help='use random machine ID')
-  parser.add_argument('--no-lan-disc', dest='no_lan_disc', action='store_true',
-                      default=False, help='disable LAN discovery')
+  parser.add_argument('--no-lan-disc', dest='lan_disc', action='store_false',
+                      default=True, help='disable LAN discovery')
+  parser.add_argument('--no-rpc-server', dest='rpc_server',
+                      action='store_false', default=True,
+                      help='disable RPC server')
   parser.add_argument("--prop-file", dest="prop_file", type=str, default=None,
                       help='file containing the JSON representation of client '
                            'properties')
@@ -494,7 +524,7 @@
   g = Ghost(addrs, Ghost.AGENT, args.rand_mid)
   if args.prop_file:
     g.LoadPropertiesFromFile(args.prop_file)
-  g.Start(args.no_lan_disc)
+  g.Start(args.lan_disc, args.rpc_server)
 
 
 if __name__ == '__main__':