#!/usr/bin/python -u
# -*- coding: utf-8 -*-
#
# Copyright 2015 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
7
8import fcntl
9import json
10import logging
11import os
12import Queue
13import select
14import socket
15import subprocess
16import sys
17import threading
18import time
19import uuid
20
21
# TCP port used when connecting to an Overlord server.
_OVERLORD_PORT = 4455
# UDP port on which LAN discovery packets are received.
_OVERLORD_LAN_DISCOVERY_PORT = 4456

# Maximum number of bytes read per recv()/read() call.
_BUFSIZE = 8192
# Seconds to sleep before retrying after a failed registration round.
_RETRY_INTERVAL = 2
# Delimiter appended to every serialized JSON message on the wire.
_SEPARATOR = '\r\n'
# Socket timeout (seconds) applied while connecting and registering.
_PING_TIMEOUT = 3
# Minimum number of seconds between two pings sent to Overlord.
_PING_INTERVAL = 5
# Default timeout (seconds) attached to an outgoing request.
_REQUEST_TIMEOUT_SECS = 60
# Shell executed by the PTY server; falls back to bash when $SHELL is unset.
_SHELL = os.getenv('SHELL', '/bin/bash')

# Status strings carried in the 'response' field of a response message.
RESPONSE_SUCCESS = 'success'
RESPONSE_FAILED = 'failed'
35
36
class PingTimeoutError(Exception):
  """Raised when a ping request times out without receiving a response."""
39
40
class RequestError(Exception):
  """Raised when a request cannot be issued, e.g. its handler is invalid."""
43
44
class Ghost(object):
  """Ghost implements the client protocol of Overlord.

  Ghost provides terminal/shell/logcat functionality and manages the client
  side connectivity.
  """
  NONE, AGENT, SHELL, LOGCAT, SLOGCAT = range(5)

  MODE_NAME = {
      NONE: 'NONE',
      AGENT: 'Agent',
      SHELL: 'Shell',
      LOGCAT: 'Logcat',
      SLOGCAT: 'Simple-Logcat'
  }

  def __init__(self, overlord_addrs, mode=AGENT, sid=None, filename=None):
    """Constructor.

    Args:
      overlord_addrs: a list of possible (host, port) addresses of overlord.
      mode: client mode, either AGENT, SHELL or LOGCAT
      sid: session id. If the connection is requested by overlord, sid should
        be set to the corresponding session id assigned by overlord.
      filename: the filename to cat when we are in LOGCAT mode.

    Raises:
      AssertionError: if mode is unsupported, or LOGCAT mode has no filename.
      RuntimeError: if no machine ID can be generated.
    """
    assert mode in [Ghost.AGENT, Ghost.SHELL, Ghost.LOGCAT]
    if mode == Ghost.LOGCAT:
      assert filename is not None

    self._overlord_addrs = overlord_addrs
    self._mode = mode
    self._sock = None
    self._machine_id = self.GetMachineID()
    self._client_id = sid if sid is not None else str(uuid.uuid4())
    self._logcat_filename = filename
    self._buf = ''
    # Maps request id -> [send timestamp, timeout in seconds, handler].
    self._requests = {}
    self._reset = False
    self._last_ping = 0
    # Carries discovered (ip, port) tuples and 'pause'/'resume' tokens between
    # the main loop and the LAN discovery thread.
    self._queue = Queue.Queue()

  def SpawnGhost(self, mode, sid, filename=None):
    """Spawn a child ghost with specific mode.

    The child process runs a new Ghost until it finishes and then exits; the
    parent returns immediately.

    Returns:
      The spawned child process pid.
    """
    pid = os.fork()
    if pid == 0:
      g = Ghost(self._overlord_addrs, mode, sid, filename)
      g.Start()
      sys.exit(0)
    else:
      return pid

  def Timestamp(self):
    """Return the current time as whole seconds since the epoch."""
    return int(time.time())

  def GetGateWayIP(self):
    """Return the gateway IPs listed in /proc/net/route.

    Returns:
      A list of dotted-quad strings, one per route entry with a non-zero
      gateway. Lines that cannot be parsed (such as the header) are skipped.
    """
    with open('/proc/net/route', 'r') as f:
      lines = f.readlines()

    ips = []
    for line in lines:
      parts = line.split('\t')
      # Skip short lines and routes without a gateway.
      if len(parts) < 3 or parts[2] == '00000000':
        continue

      try:
        # The gateway column is hex with the address bytes in reverse order.
        h = parts[2].decode('hex')
        ips.append('%d.%d.%d.%d' % tuple(ord(x) for x in reversed(h)))
      except TypeError:
        # Not a hex string (e.g. the header row) or not exactly four bytes.
        pass

    return ips

  def GetMachineID(self):
    """Generates machine-dependent ID string for a machine.

    There are many ways to generate a machine ID:
    1. factory device-data
    2. /sys/class/dmi/id/product_uuid (only available on intel machines)
    3. MAC address
    We follow the listed order to generate machine ID, and fallback to the next
    alternative if the previous doesn't work.

    Returns:
      The machine ID string.

    Raises:
      RuntimeError: if none of the alternatives works.
    """
    # Each alternative is best-effort: any failure falls through to the next.
    try:
      p = subprocess.Popen('factory device-data | grep mlb_serial_number | '
                           'cut -d " " -f 2', stdout=subprocess.PIPE,
                           shell=True)
      stdout, _ = p.communicate()
      serial = stdout.strip()
      # Strip before testing so whitespace-only output is not used as an ID.
      if not serial:
        raise RuntimeError("empty mlb number")
      return serial
    except Exception:
      pass

    try:
      with open('/sys/class/dmi/id/product_uuid', 'r') as f:
        return f.read().strip()
    except Exception:
      pass

    try:
      macs = []
      ifaces = sorted(os.listdir('/sys/class/net'))
      for iface in ifaces:
        if iface == 'lo':
          continue

        with open('/sys/class/net/%s/address' % iface, 'r') as f:
          macs.append(f.read().strip())

      return ';'.join(macs)
    except Exception:
      pass

    raise RuntimeError("can't generate machine ID")

  def Reset(self):
    """Reset state and clear request handlers."""
    self._reset = False
    self._buf = ""
    self._last_ping = 0
    self._requests = {}

  def SendMessage(self, msg):
    """Serialize the message and send it through the socket."""
    # sendall: a plain send() may transmit only part of the message.
    self._sock.sendall(json.dumps(msg) + _SEPARATOR)

  def SendRequest(self, name, args, handler=None,
                  timeout=_REQUEST_TIMEOUT_SECS):
    """Send a request and remember its handler until a response or timeout.

    Args:
      name: request name.
      args: request parameters, sent as the 'params' field.
      handler: optional callable invoked with the response message, or with
        None when the request times out.
      timeout: number of seconds after which the request is timed out.

    Raises:
      RequestError: if handler is given but is not callable.
    """
    if handler and not callable(handler):
      raise RequestError('Invalid request handler for msg "%s"' % name)

    rid = str(uuid.uuid4())
    msg = {'rid': rid, 'timeout': timeout, 'name': name, 'params': args}
    self._requests[rid] = [self.Timestamp(), timeout, handler]
    self.SendMessage(msg)

  def SendResponse(self, omsg, status, params=None):
    """Send a response to the request message omsg.

    Args:
      omsg: the original request message; its 'rid' is echoed back.
      status: response status string, e.g. RESPONSE_SUCCESS.
      params: optional response payload.
    """
    msg = {'rid': omsg['rid'], 'response': status, 'params': params}
    self.SendMessage(msg)

  def SpawnPTYServer(self, _):
    """Spawn a PTY server and forward I/O to the TCP socket.

    Forks a shell attached to a new pseudo-terminal and relays data between
    the terminal and the socket until either side fails or closes, then exits
    the process.
    """
    logging.info('SpawnPTYServer: started')

    pid, fd = os.forkpty()
    if pid == 0:
      try:
        env = os.environ.copy()
        env['USER'] = os.getenv('USER', 'root')
        env['HOME'] = os.getenv('HOME', '/root')
        os.chdir(env['HOME'])
        os.execve(_SHELL, [_SHELL], env)
      except Exception:
        logging.exception('SpawnPTYServer: failed to start %s', _SHELL)
      # Only reached when the shell could not be started. Exit immediately so
      # the forked child never continues running the parent's code.
      os._exit(1)
    else:
      try:
        while True:
          rd, _, _ = select.select([self._sock, fd], [], [])

          if fd in rd:
            self._sock.sendall(os.read(fd, _BUFSIZE))

          if self._sock in rd:
            ret = self._sock.recv(_BUFSIZE)
            if len(ret) == 0:
              raise RuntimeError("socket closed")
            os.write(fd, ret)
      except (OSError, socket.error, RuntimeError):
        self._sock.close()
        logging.info('SpawnPTYServer: terminated')
        sys.exit(0)

  def SpawnLogcatServer(self, _):
    """Spawn a Logcat server and forward output to the TCP socket.

    Runs `tail -n +0 -f` on the logcat file and relays its stdout and stderr
    to the socket until the socket closes or tail exits, then exits the
    process.
    """
    logging.info('SpawnLogcatServer: started')

    # Pass an argument list instead of a shell string: the filename arrives in
    # a request message and must not be interpreted by a shell.
    p = subprocess.Popen(['tail', '-n', '+0', '-f', self._logcat_filename],
                         stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    def make_non_block(fd):
      fl = fcntl.fcntl(fd, fcntl.F_GETFL)
      fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)

    make_non_block(p.stdout)
    make_non_block(p.stderr)

    # Pipes still open; one is dropped once it reports end-of-file.
    streams = [p.stdout, p.stderr]

    try:
      while True:
        rd, _, _ = select.select(streams + [self._sock], [], [])

        for stream in [s for s in streams if s in rd]:
          data = stream.read(_BUFSIZE)
          if data:
            self._sock.sendall(data)
          else:
            # EOF: stop selecting on it, otherwise select() returns at once
            # forever and the loop spins.
            streams.remove(stream)

        if not streams:
          raise RuntimeError("tail exited")

        if self._sock in rd:
          ret = self._sock.recv(_BUFSIZE)
          if len(ret) == 0:
            raise RuntimeError("socket closed")
    except (OSError, socket.error, RuntimeError):
      self._sock.close()
      # Do not leave the tail process running after we exit.
      try:
        if p.poll() is None:
          p.kill()
      except OSError:
        pass
      logging.info('SpawnLogcatServer: terminated')
      sys.exit(0)

  def Ping(self):
    """Send a ping request; PingTimeoutError is raised if it times out."""
    def timeout_handler(x):
      # The handler receives None when the request timed out.
      if x is None:
        raise PingTimeoutError

    self._last_ping = self.Timestamp()
    # NOTE(review): the timeout is a literal 5 seconds rather than
    # _PING_TIMEOUT -- confirm which value is intended.
    self.SendRequest('ping', {}, timeout_handler, 5)

  def HandleShellRequest(self, msg):
    """Run the command described by a 'shell' request and send its output.

    Args:
      msg: request message whose 'params' holds 'cmd' (the executable) and
        'args' (a list of arguments).
    """
    params = msg['params']
    stdout = stderr = err_msg = ""
    try:
      p = subprocess.Popen([params['cmd']] + params['args'],
                           stdout=subprocess.PIPE, stderr=subprocess.PIPE)
      stdout, stderr = p.communicate()
    except Exception as e:
      err_msg = str(e)

    # NOTE(review): the status is RESPONSE_SUCCESS even when the command could
    # not be run; the failure is only reported through 'err_msg'. Confirm
    # whether Overlord expects RESPONSE_FAILED in that case.
    self.SendResponse(msg, RESPONSE_SUCCESS,
                      {'output': stdout + stderr, 'err_msg': err_msg})

  def HandleRequest(self, msg):
    """Dispatch an incoming request by name; unknown names are ignored."""
    if msg['name'] == 'shell':
      self.HandleShellRequest(msg)
    elif msg['name'] == 'terminal':
      self.SpawnGhost(self.SHELL, msg['params']['sid'])
      self.SendResponse(msg, RESPONSE_SUCCESS)
    elif msg['name'] == 'logcat':
      self.SpawnGhost(self.LOGCAT, msg['params']['sid'],
                      msg['params']['filename'])
      self.SendResponse(msg, RESPONSE_SUCCESS)

  def HandleResponse(self, response):
    """Pass a response message to the handler of the matching request.

    Responses whose 'rid' does not match a pending request are logged and
    ignored.
    """
    # .get(): a response without 'rid' is treated as unsolicited, not fatal.
    rid = str(response.get('rid'))
    if rid in self._requests:
      handler = self._requests[rid][2]
      del self._requests[rid]
      if callable(handler):
        handler(response)
    else:
      logging.warning('Received unsolicited response, ignored: %r', response)

  def ParseMessage(self):
    """Parse the complete messages in the receive buffer and dispatch them.

    Any trailing partial message stays in the buffer for the next call.
    """
    msgs_json = self._buf.split(_SEPARATOR)
    self._buf = msgs_json.pop()

    for msg_json in msgs_json:
      try:
        msg = json.loads(msg_json)
      except ValueError:
        # Ignore mal-formed message.
        continue

      if not isinstance(msg, dict):
        # Valid JSON but not an object: ignore mal-formed message.
        continue

      if 'name' in msg:
        self.HandleRequest(msg)
      elif 'response' in msg:
        self.HandleResponse(msg)
      else:  # Ignore mal-formed message.
        pass

  def ScanForTimeoutRequests(self):
    """Expire pending requests older than their timeout.

    Each expired request is removed and its handler, if any, is called with
    None.
    """
    # Iterate over a snapshot because entries are deleted inside the loop.
    for rid in list(self._requests):
      request_time, timeout, handler = self._requests[rid]
      if self.Timestamp() - request_time > timeout:
        # Remove first so the entry is gone even if the handler raises.
        del self._requests[rid]
        # Requests may be sent without a handler.
        if callable(handler):
          handler(None)

  def Listen(self):
    """Serve the connection until it drops, times out or is reset.

    Reads and dispatches incoming messages, pings Overlord periodically and
    expires timed-out requests. Non-AGENT ghosts exit the process when the
    loop ends normally.

    Raises:
      RuntimeError: if the connection is dropped or a ping times out.
    """
    try:
      while True:
        rds, _, _ = select.select([self._sock], [], [], _PING_INTERVAL / 2)

        if self._sock in rds:
          data = self._sock.recv(_BUFSIZE)
          if not data:
            # Peer closed the connection. Without this check select() keeps
            # reporting the socket readable and the loop spins.
            raise RuntimeError('Connection dropped')
          self._buf += data
          self.ParseMessage()

        if self.Timestamp() - self._last_ping > _PING_INTERVAL:
          self.Ping()
        self.ScanForTimeoutRequests()

        if self._reset:
          self.Reset()
          break
    except socket.error:
      raise RuntimeError('Connection dropped')
    except PingTimeoutError:
      raise RuntimeError('Connection timeout')
    finally:
      self._sock.close()
      # Resume LAN discovery on every exit path; otherwise a dropped
      # connection leaves the discovery thread paused forever.
      self._queue.put('resume')

    if self._mode != Ghost.AGENT:
      sys.exit(1)

  def Register(self):
    """Try each known Overlord address: connect, register and listen.

    Raises:
      RuntimeError: always, once every address has been tried (or earlier,
        when a connection fails while listening).
    """
    for addr in self._overlord_addrs:
      # Bind addr as a default so the handler reports the address it was
      # created for.
      def registered(response, addr=addr):
        if response is None:
          self._reset = True
          raise RuntimeError('Register request timeout')
        logging.info('Registered with Overlord at %s:%d', *addr)
        # Stop LAN discovery while we are registered.
        self._queue.put("pause", True)

      try:
        logging.info('Trying %s:%d ...', *addr)
        self.Reset()
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._sock.settimeout(_PING_TIMEOUT)
        self._sock.connect(addr)

        logging.info('Connection established, registering...')
        handler = {
            Ghost.AGENT: registered,
            Ghost.SHELL: self.SpawnPTYServer,
            Ghost.LOGCAT: self.SpawnLogcatServer
        }[self._mode]

        # Machine ID may change if MAC address is used (USB-ethernet dongle
        # plugged/unplugged)
        self._machine_id = self.GetMachineID()
        self.SendRequest('register', {'mode': self._mode,
                                      'mid': self._machine_id,
                                      'cid': self._client_id}, handler)
      except socket.error:
        # Release the socket of the failed attempt before trying the next
        # address.
        if self._sock is not None:
          self._sock.close()
      else:
        self._sock.settimeout(None)
        self.Listen()

    raise RuntimeError("Cannot connect to any server")

  def StartLanDiscovery(self):
    """Start to listen to LAN discovery packet at
    _OVERLORD_LAN_DISCOVERY_PORT.

    Discovered (ip, port) addresses are put on the queue. Discovery blocks
    after a 'pause' token is read from the queue until a 'resume' token
    arrives.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
    try:
      s.bind(('0.0.0.0', _OVERLORD_LAN_DISCOVERY_PORT))
    except socket.error as e:
      logging.error("LAN discovery: %s, abort", e)
      s.close()
      return

    logging.info('LAN Discovery: started')
    while True:
      rd, _, _ = select.select([s], [], [], 1)

      if s in rd:
        data, source_addr = s.recvfrom(_BUFSIZE)
        parts = data.split()
        # Packets come from the network: validate them so an empty or
        # malformed datagram cannot kill this thread.
        if len(parts) >= 2 and parts[0] == 'OVERLORD':
          try:
            port = int(parts[1].lstrip(':'))
          except ValueError:
            logging.warning('LAN Discovery: ignored malformed packet %r',
                            data)
          else:
            self._queue.put((source_addr[0], port), True)

      try:
        obj = self._queue.get(False)
      except Queue.Empty:
        pass
      else:
        if not isinstance(obj, str):
          # Not a control token (e.g. an address): put it back on the queue.
          self._queue.put(obj)
        elif obj == 'pause':
          logging.info('LAN Discovery: paused')
          while True:
            obj = self._queue.get(True)
            if obj == 'resume':
              logging.info('LAN Discovery: resumed')
              break

  def ScanGateway(self):
    """Add every gateway IP that is not yet known as an Overlord address."""
    for addr in [(x, _OVERLORD_PORT) for x in self.GetGateWayIP()]:
      if addr not in self._overlord_addrs:
        self._overlord_addrs.append(addr)

  def Start(self):
    """Run the ghost: keep trying to register until interrupted.

    In AGENT mode a daemon thread is also started for LAN discovery. The
    process exits on keyboard interrupt.
    """
    logging.info('%s started', self.MODE_NAME[self._mode])
    logging.info('MID: %s', self._machine_id)
    logging.info('CID: %s', self._client_id)

    if self._mode == Ghost.AGENT:
      t = threading.Thread(target=self.StartLanDiscovery)
      t.daemon = True
      t.start()

    try:
      while True:
        try:
          addr = self._queue.get(False)
        except Queue.Empty:
          pass
        else:
          # Control tokens read here are dropped; only new addresses are kept.
          if isinstance(addr, tuple) and addr not in self._overlord_addrs:
            logging.info('LAN Discovery: got overlord address %s:%d', *addr)
            self._overlord_addrs.append(addr)

        try:
          self.ScanGateway()
          self.Register()
        except Exception as e:
          logging.info('%s, retrying in %ds', e, _RETRY_INTERVAL)
          time.sleep(_RETRY_INTERVAL)

        self.Reset()
    except KeyboardInterrupt:
      logging.error('Received keyboard interrupt, quit')
      sys.exit(0)
465
def main():
  """Start an agent ghost for localhost plus any hosts named on argv."""
  logging.getLogger().setLevel(logging.INFO)

  # localhost is always tried first; each command-line argument adds a host.
  overlord_addrs = [('localhost', _OVERLORD_PORT)]
  overlord_addrs.extend((host, _OVERLORD_PORT) for host in sys.argv[1:])

  Ghost(overlord_addrs).Start()
476
477
# Run the agent only when this file is executed as a script.
if __name__ == '__main__':
  main()