overlord: ghost: buffer handling of leftover data
When the overlord control channel switches into a data channel after
registration, the leftover data needs to be interpreted as part of the
input data stream. Instead of handling the leftover separately, implement
a socket subclass, BufferedSocket, that supports 'UnRecv'. UnRecv pushes
data back into a buffer, which is consumed by the next Recv() call.
BUG=none
TEST=`go/src/overlord/test/overlord_e2e_unittest.py`
Change-Id: Ia3068528ceeded5c49f8ea905fb7cc8cedb22a1f
Reviewed-on: https://chromium-review.googlesource.com/323923
Commit-Ready: Wei-Ning Huang <wnhuang@chromium.org>
Tested-by: Wei-Ning Huang <wnhuang@chromium.org>
Reviewed-by: Rong Chang <rongchang@chromium.org>
diff --git a/py/tools/ghost.py b/py/tools/ghost.py
index 998a02a..8f7de48 100755
--- a/py/tools/ghost.py
+++ b/py/tools/ghost.py
@@ -69,6 +69,41 @@
pass
+class BufferedSocket(socket.socket):
+ """A buffered socket that supports unrecv.
+
+ Allow putting data back to the socket for the next Recv() call.
+ """
+ def __init__(self, *args, **kwargs):
+ super(BufferedSocket, self).__init__(*args, **kwargs)
+ self._buf = ''
+
+ def Recv(self, bufsize, flags=0):
+ if self._buf:
+ if len(self._buf) >= bufsize:
+ ret = self._buf[:bufsize]
+ self._buf = self._buf[bufsize:]
+ return ret
+ else:
+ ret = self._buf
+ self._buf = ''
+ return ret + super(BufferedSocket, self).recv(bufsize - len(ret), flags)
+ else:
+ return super(BufferedSocket, self).recv(bufsize, flags)
+
+ def UnRecv(self, buf):
+ self._buf = buf + self._buf
+
+ def Send(self, *args, **kwargs):
+ return super(BufferedSocket, self).send(*args, **kwargs)
+
+ def RecvBuf(self):
+ """Return and clear the buffered data without reading from the socket."""
+ ret = self._buf
+ self._buf = ''
+ return ret
+
+
class Ghost(object):
"""Ghost implements the client protocol of Overlord.
@@ -134,7 +169,6 @@
self._reset = threading.Event()
# RPC
- self._buf = '' # Read buffer
self._requests = {}
self._queue = Queue.Queue()
@@ -357,7 +391,6 @@
def Reset(self):
"""Reset state and clear request handlers."""
self._reset.clear()
- self._buf = ''
self._last_ping = 0
self._requests = {}
self.LoadProperties()
@@ -365,7 +398,7 @@
def SendMessage(self, msg):
"""Serialize the message and send it through the socket."""
- self._sock.send(json.dumps(msg) + _SEPARATOR)
+ self._sock.Send(json.dumps(msg) + _SEPARATOR)
def SendRequest(self, name, args, handler=None,
timeout=_REQUEST_TIMEOUT_SECS):
@@ -382,8 +415,8 @@
msg = {'rid': omsg['rid'], 'response': status, 'params': params}
self.SendMessage(msg)
- def HandleTTYControl(self, fd, control_string):
- msg = json.loads(control_string)
+ def HandleTTYControl(self, fd, control_str):
+ msg = json.loads(control_str)
command = msg['command']
params = msg['params']
if command == 'resize':
@@ -434,43 +467,48 @@
attr[5] = termios.B115200
termios.tcsetattr(fd, termios.TCSANOW, attr)
- control_state = None
- control_string = ''
- write_buffer = ''
+ nonlocals = {'control_state': None, 'control_str': ''}
+
+ def _ProcessBuffer(buf):
+ write_buffer = ''
+ while buf:
+ if nonlocals['control_state']:
+ if chr(_CONTROL_END) in buf:
+ index = buf.index(chr(_CONTROL_END))
+ nonlocals['control_str'] += buf[:index]
+ self.HandleTTYControl(fd, nonlocals['control_str'])
+ nonlocals['control_state'] = None
+ nonlocals['control_str'] = ''
+ buf = buf[index+1:]
+ else:
+ nonlocals['control_str'] += buf
+ buf = ''
+ else:
+ if chr(_CONTROL_START) in buf:
+ nonlocals['control_state'] = _CONTROL_START
+ index = buf.index(chr(_CONTROL_START))
+ write_buffer += buf[:index]
+ buf = buf[index+1:]
+ else:
+ write_buffer += buf
+ buf = ''
+
+ if write_buffer:
+ os.write(fd, write_buffer)
+
+ _ProcessBuffer(self._sock.RecvBuf())
+
while True:
rd, unused_wd, unused_xd = select.select([self._sock, fd], [], [])
if fd in rd:
- self._sock.send(os.read(fd, _BUFSIZE))
+ self._sock.Send(os.read(fd, _BUFSIZE))
if self._sock in rd:
- ret = self._sock.recv(_BUFSIZE)
- if len(ret) == 0:
+ buf = self._sock.Recv(_BUFSIZE)
+ if len(buf) == 0:
raise RuntimeError('connection terminated')
- while ret:
- if control_state:
- if chr(_CONTROL_END) in ret:
- index = ret.index(chr(_CONTROL_END))
- control_string += ret[:index]
- self.HandleTTYControl(fd, control_string)
- control_state = None
- control_string = ''
- ret = ret[index+1:]
- else:
- control_string += ret
- ret = ''
- else:
- if chr(_CONTROL_START) in ret:
- control_state = _CONTROL_START
- index = ret.index(chr(_CONTROL_START))
- write_buffer += ret[:index]
- ret = ret[index+1:]
- else:
- write_buffer += ret
- ret = ''
- if write_buffer:
- os.write(fd, write_buffer)
- write_buffer = ''
+ _ProcessBuffer(buf)
except Exception as e:
logging.error('SpawnTTYServer: %s', e)
finally:
@@ -503,18 +541,19 @@
make_non_block(p.stderr)
try:
+ p.stdin.write(self._sock.RecvBuf())
while True:
rd, unused_wd, unused_xd = select.select(
[p.stdout, p.stderr, self._sock], [], [self._sock])
if p.stdout in rd:
- self._sock.send(p.stdout.read(_BUFSIZE))
+ self._sock.Send(p.stdout.read(_BUFSIZE))
if p.stderr in rd:
- self._sock.send(p.stderr.read(_BUFSIZE))
+ self._sock.Send(p.stderr.read(_BUFSIZE))
if self._sock in rd:
- ret = self._sock.recv(_BUFSIZE)
+ ret = self._sock.Recv(_BUFSIZE)
if len(ret) == 0:
raise RuntimeError('connection terminated')
@@ -575,7 +614,7 @@
data = f.read(_BLOCK_SIZE)
if len(data) == 0:
break
- self._sock.send(data)
+ self._sock.Send(data)
except Exception as e:
logging.error('StartDownloadServer: %s', e)
finally:
@@ -600,10 +639,12 @@
if self._file_op[2]:
os.fchmod(f.fileno(), self._file_op[2])
+ f.write(self._sock.RecvBuf())
+
while True:
rd, unused_wd, unused_xd = select.select([self._sock], [], [])
if self._sock in rd:
- buf = self._sock.recv(_BLOCK_SIZE)
+ buf = self._sock.Recv(_BLOCK_SIZE)
if len(buf) == 0:
break
f.write(buf)
@@ -628,16 +669,13 @@
src_sock.connect(('localhost', self._port))
src_sock.setblocking(False)
- # Pass the leftovers of the previous buffer
- if self._buf:
- src_sock.send(self._buf)
- self._buf = ''
+ src_sock.send(self._sock.RecvBuf())
while True:
rd, unused_wd, unused_xd = select.select([self._sock, src_sock], [], [])
if self._sock in rd:
- data = self._sock.recv(_BUFSIZE)
+ data = self._sock.Recv(_BUFSIZE)
if len(data) == 0:
raise RuntimeError('connection terminated')
src_sock.send(data)
@@ -646,7 +684,7 @@
data = src_sock.recv(_BUFSIZE)
if len(data) == 0:
break
- self._sock.send(data)
+ self._sock.Send(data)
except Exception as e:
logging.error('SpawnPortForwardServer: %s', e)
finally:
@@ -757,14 +795,14 @@
else:
logging.warning('Received unsolicited response, ignored')
- def ParseMessage(self, single=True):
+ def ParseMessage(self, buf, single=True):
if single:
- index = self._buf.index(_SEPARATOR)
- msgs_json = [self._buf[:index]]
- self._buf = self._buf[index + 2:]
+ index = buf.index(_SEPARATOR)
+ msgs_json = [buf[:index]]
+ self._sock.UnRecv(buf[index + 2:])
else:
- msgs_json = self._buf.split(_SEPARATOR)
- self._buf = msgs_json.pop()
+ msgs_json = buf.split(_SEPARATOR)
+ self._sock.UnRecv(msgs_json.pop())
for msg_json in msgs_json:
try:
@@ -809,15 +847,14 @@
_PING_INTERVAL / 2)
if self._sock in rds:
- data = self._sock.recv(_BUFSIZE)
+ data = self._sock.Recv(_BUFSIZE)
# Socket is closed
if len(data) == 0:
self.Reset()
break
- self._buf += data
- self.ParseMessage(self._register_status != SUCCESS)
+ self.ParseMessage(data, self._register_status != SUCCESS)
if (self._mode == self.AGENT and
self.Timestamp() - self._last_ping > _PING_INTERVAL):
@@ -864,7 +901,7 @@
try:
logging.info('Trying %s:%d ...', *addr)
self.Reset()
- self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self._sock = BufferedSocket(socket.AF_INET, socket.SOCK_STREAM)
self._sock.settimeout(_CONNECT_TIMEOUT)
self._sock.connect(addr)