overlord: ghost: buffer handling of leftover data

When the overlord control channel switches into a data channel after
registration, the leftover data needs to be interpreted as part of the
input data stream. Instead of dealing with the leftover data separately,
implement a socket wrapper class that supports 'UnRecv'. UnRecv pushes
the data into a buffer, which is consumed when Recv() is next called.

BUG=none
TEST=`go/src/overlord/test/overlord_e2e_unittest.py`

Change-Id: Ia3068528ceeded5c49f8ea905fb7cc8cedb22a1f
Reviewed-on: https://chromium-review.googlesource.com/323923
Commit-Ready: Wei-Ning Huang <wnhuang@chromium.org>
Tested-by: Wei-Ning Huang <wnhuang@chromium.org>
Reviewed-by: Rong Chang <rongchang@chromium.org>
diff --git a/py/tools/ghost.py b/py/tools/ghost.py
index 998a02a..8f7de48 100755
--- a/py/tools/ghost.py
+++ b/py/tools/ghost.py
@@ -69,6 +69,41 @@
   pass
 
 
+class BufferedSocket(socket.socket):
+  """A socket that lets callers push data back for the next Recv() call.
+
+  Recv() does not read from the socket while >= bufsize bytes are buffered.
+  """
+  def __init__(self, *args, **kwargs):
+    super(BufferedSocket, self).__init__(*args, **kwargs)
+    self._buf = ''  # Data pushed back by UnRecv(), served first by Recv().
+
+  def Recv(self, bufsize, flags=0):
+    """Return up to bufsize bytes, serving UnRecv()'d data first.
+
+    Buffered data is kept if the underlying recv() raises (e.g. timeout).
+    """
+    if self._buf and len(self._buf) >= bufsize:
+      ret, self._buf = self._buf[:bufsize], self._buf[bufsize:]
+      return ret
+    data = super(BufferedSocket, self).recv(bufsize - len(self._buf), flags)
+    ret, self._buf = self._buf + data, ''  # Clear only after recv succeeds.
+    return ret
+
+  def UnRecv(self, buf):
+    """Push buf back so it is returned before any not-yet-read data."""
+    self._buf = buf + self._buf
+
+  def Send(self, *args, **kwargs):
+    """Plain socket send(); it may send only part of the data."""
+    return super(BufferedSocket, self).send(*args, **kwargs)
+
+  def RecvBuf(self):
+    """Return and clear the buffered data without reading from the socket."""
+    ret, self._buf = self._buf, ''
+    return ret
+
+
 class Ghost(object):
   """Ghost implements the client protocol of Overlord.
 
@@ -134,7 +169,6 @@
     self._reset = threading.Event()
 
     # RPC
-    self._buf = ''  # Read buffer
     self._requests = {}
     self._queue = Queue.Queue()
 
@@ -357,7 +391,6 @@
   def Reset(self):
     """Reset state and clear request handlers."""
     self._reset.clear()
-    self._buf = ''
     self._last_ping = 0
     self._requests = {}
     self.LoadProperties()
@@ -365,7 +398,7 @@
 
   def SendMessage(self, msg):
     """Serialize the message and send it through the socket."""
-    self._sock.send(json.dumps(msg) + _SEPARATOR)
+    self._sock.Send(json.dumps(msg) + _SEPARATOR)
 
   def SendRequest(self, name, args, handler=None,
                   timeout=_REQUEST_TIMEOUT_SECS):
@@ -382,8 +415,8 @@
     msg = {'rid': omsg['rid'], 'response': status, 'params': params}
     self.SendMessage(msg)
 
-  def HandleTTYControl(self, fd, control_string):
-    msg = json.loads(control_string)
+  def HandleTTYControl(self, fd, control_str):
+    msg = json.loads(control_str)
     command = msg['command']
     params = msg['params']
     if command == 'resize':
@@ -434,43 +467,48 @@
         attr[5] = termios.B115200
         termios.tcsetattr(fd, termios.TCSANOW, attr)
 
-      control_state = None
-      control_string = ''
-      write_buffer = ''
+      nonlocals = {'control_state': None, 'control_str': ''}
+
+      def _ProcessBuffer(buf):
+        write_buffer = ''
+        while buf:
+          if nonlocals['control_state']:
+            if chr(_CONTROL_END) in buf:
+              index = buf.index(chr(_CONTROL_END))
+              nonlocals['control_str'] += buf[:index]
+              self.HandleTTYControl(fd, nonlocals['control_str'])
+              nonlocals['control_state'] = None
+              nonlocals['control_str'] = ''
+              buf = buf[index+1:]
+            else:
+              nonlocals['control_str'] += buf
+              buf = ''
+          else:
+            if chr(_CONTROL_START) in buf:
+              nonlocals['control_state'] = _CONTROL_START
+              index = buf.index(chr(_CONTROL_START))
+              write_buffer += buf[:index]
+              buf = buf[index+1:]
+            else:
+              write_buffer += buf
+              buf = ''
+
+        if write_buffer:
+          os.write(fd, write_buffer)
+
+      _ProcessBuffer(self._sock.RecvBuf())
+
       while True:
         rd, unused_wd, unused_xd = select.select([self._sock, fd], [], [])
 
         if fd in rd:
-          self._sock.send(os.read(fd, _BUFSIZE))
+          self._sock.Send(os.read(fd, _BUFSIZE))
 
         if self._sock in rd:
-          ret = self._sock.recv(_BUFSIZE)
-          if len(ret) == 0:
+          buf = self._sock.Recv(_BUFSIZE)
+          if len(buf) == 0:
             raise RuntimeError('connection terminated')
-          while ret:
-            if control_state:
-              if chr(_CONTROL_END) in ret:
-                index = ret.index(chr(_CONTROL_END))
-                control_string += ret[:index]
-                self.HandleTTYControl(fd, control_string)
-                control_state = None
-                control_string = ''
-                ret = ret[index+1:]
-              else:
-                control_string += ret
-                ret = ''
-            else:
-              if chr(_CONTROL_START) in ret:
-                control_state = _CONTROL_START
-                index = ret.index(chr(_CONTROL_START))
-                write_buffer += ret[:index]
-                ret = ret[index+1:]
-              else:
-                write_buffer += ret
-                ret = ''
-          if write_buffer:
-            os.write(fd, write_buffer)
-            write_buffer = ''
+          _ProcessBuffer(buf)
     except Exception as e:
       logging.error('SpawnTTYServer: %s', e)
     finally:
@@ -503,18 +541,19 @@
     make_non_block(p.stderr)
 
     try:
+      p.stdin.write(self._sock.RecvBuf())
 
       while True:
         rd, unused_wd, unused_xd = select.select(
             [p.stdout, p.stderr, self._sock], [], [self._sock])
         if p.stdout in rd:
-          self._sock.send(p.stdout.read(_BUFSIZE))
+          self._sock.Send(p.stdout.read(_BUFSIZE))
 
         if p.stderr in rd:
-          self._sock.send(p.stderr.read(_BUFSIZE))
+          self._sock.Send(p.stderr.read(_BUFSIZE))
 
         if self._sock in rd:
-          ret = self._sock.recv(_BUFSIZE)
+          ret = self._sock.Recv(_BUFSIZE)
           if len(ret) == 0:
             raise RuntimeError('connection terminated')
 
@@ -575,7 +614,7 @@
           data = f.read(_BLOCK_SIZE)
           if len(data) == 0:
             break
-          self._sock.send(data)
+          self._sock.Send(data)
     except Exception as e:
       logging.error('StartDownloadServer: %s', e)
     finally:
@@ -600,10 +639,12 @@
         if self._file_op[2]:
           os.fchmod(f.fileno(), self._file_op[2])
 
+        f.write(self._sock.RecvBuf())
+
         while True:
           rd, unused_wd, unused_xd = select.select([self._sock], [], [])
           if self._sock in rd:
-            buf = self._sock.recv(_BLOCK_SIZE)
+            buf = self._sock.Recv(_BLOCK_SIZE)
             if len(buf) == 0:
               break
             f.write(buf)
@@ -628,16 +669,13 @@
       src_sock.connect(('localhost', self._port))
       src_sock.setblocking(False)
 
-      # Pass the leftovers of the previous buffer
-      if self._buf:
-        src_sock.send(self._buf)
-        self._buf = ''
+      src_sock.send(self._sock.RecvBuf())
 
       while True:
         rd, unused_wd, unused_xd = select.select([self._sock, src_sock], [], [])
 
         if self._sock in rd:
-          data = self._sock.recv(_BUFSIZE)
+          data = self._sock.Recv(_BUFSIZE)
           if len(data) == 0:
             raise RuntimeError('connection terminated')
           src_sock.send(data)
@@ -646,7 +684,7 @@
           data = src_sock.recv(_BUFSIZE)
           if len(data) == 0:
             break
-          self._sock.send(data)
+          self._sock.Send(data)
     except Exception as e:
       logging.error('SpawnPortForwardServer: %s', e)
     finally:
@@ -757,14 +795,14 @@
     else:
       logging.warning('Received unsolicited response, ignored')
 
-  def ParseMessage(self, single=True):
+  def ParseMessage(self, buf, single=True):
     if single:
-      index = self._buf.index(_SEPARATOR)
-      msgs_json = [self._buf[:index]]
-      self._buf = self._buf[index + 2:]
+      index = buf.index(_SEPARATOR)
+      msgs_json = [buf[:index]]
+      self._sock.UnRecv(buf[index + 2:])
     else:
-      msgs_json = self._buf.split(_SEPARATOR)
-      self._buf = msgs_json.pop()
+      msgs_json = buf.split(_SEPARATOR)
+      self._sock.UnRecv(msgs_json.pop())
 
     for msg_json in msgs_json:
       try:
@@ -809,15 +847,14 @@
                                                   _PING_INTERVAL / 2)
 
         if self._sock in rds:
-          data = self._sock.recv(_BUFSIZE)
+          data = self._sock.Recv(_BUFSIZE)
 
           # Socket is closed
           if len(data) == 0:
             self.Reset()
             break
 
-          self._buf += data
-          self.ParseMessage(self._register_status != SUCCESS)
+          self.ParseMessage(data, self._register_status != SUCCESS)
 
         if (self._mode == self.AGENT and
             self.Timestamp() - self._last_ping > _PING_INTERVAL):
@@ -864,7 +901,7 @@
       try:
         logging.info('Trying %s:%d ...', *addr)
         self.Reset()
-        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self._sock = BufferedSocket(socket.AF_INET, socket.SOCK_STREAM)
         self._sock.settimeout(_CONNECT_TIMEOUT)
         self._sock.connect(addr)