blob: 55875d22cb1c109759ecb4299b8625d0061a64a6 [file] [log] [blame]
Wei-Ning Huang91aaeed2015-09-24 14:51:56 +08001#!/usr/bin/python -u
2# -*- coding: utf-8 -*-
3#
4# Copyright 2015 The Chromium OS Authors. All rights reserved.
5# Use of this source code is governed by a BSD-style license that can be
6# found in the LICENSE file.
7
8from __future__ import print_function
9
10import argparse
11import ast
12import base64
13import fcntl
14import hashlib
15import httplib
16import json
17import jsonrpclib
18import logging
19import os
20import re
21import select
22import signal
23import socket
24import StringIO
25import struct
26import subprocess
27import sys
28import tempfile
29import termios
30import threading
31import time
32import tty
33import urllib2
34import urlparse
35
36from jsonrpclib.SimpleJSONRPCServer import SimpleJSONRPCServer
37from jsonrpclib.config import Config
38from ws4py.client import WebSocketBaseClient
39
# Python version >= 2.7.9 enables SSL certificate verification by default;
# bypass it (presumably so that Overlord servers using self-signed
# certificates can be reached).
# NOTE(review): this turns off certificate verification for every HTTPS
# connection made by this process, which exposes them to man-in-the-middle
# attacks. Consider verifying against a known server certificate instead.
try:
  import ssl
  # pylint: disable=W0212
  ssl._create_default_https_context = ssl._create_unverified_context
except (ImportError, AttributeError):
  # No ssl module, or a Python older than 2.7.9 (no
  # _create_unverified_context, and no verification by default anyway).
  pass
47
48
# Typed right after Enter in an interactive terminal session, this character
# starts the local escape sequence ("<Enter> ~ ." closes the session).
_ESCAPE = '~'
# Chunk size for file, HTTP and socket reads.
_BUFSIZ = 8192

# Network endpoints.
_OVERLORD_PORT = 4455
_OVERLORD_HTTP_PORT = 9000
_OVERLORD_CLIENT_DAEMON_PORT = 4488
_OVERLORD_CLIENT_DAEMON_RPC_ADDR = ('127.0.0.1', _OVERLORD_CLIENT_DAEMON_PORT)

# Timeouts in seconds.
_DEFAULT_HTTP_TIMEOUT = 30
_LIST_CACHE_TIMEOUT = 2
# Miscellaneous defaults.
_DEFAULT_TERMINAL_WIDTH = 80
_RETRY_TIMES = 3

# Multipart boundary magic for uploads: echo -n overlord | md5sum
_HTTP_BOUNDARY_MAGIC = '9246f080c855a69012707ab53489b921'

# Byte values that wrap an in-band JSON control message (e.g. a terminal
# resize) inside the terminal websocket stream.
_CONTROL_START, _CONTROL_END = 128, 129

# SSH ControlMaster socket paths are this prefix followed by the host name.
_SSH_CONTROL_SOCKET_PREFIX = os.path.join(tempfile.gettempdir(),
                                          'ovl-ssh-control-')

# A string that will always be included in the response of
# GET http://OVERLORD_SERVER:_OVERLORD_HTTP_PORT
_OVERLORD_RESPONSE_KEYWORD = '<html>'
72
73
def GetVersionDigest():
  """Return the SHA-1 hex digest of the currently executing script.

  The CLI compares this digest with the one recorded by the running daemon to
  detect a daemon started from a different version of the script. The file
  is read in binary mode so the digest covers the exact bytes on disk.
  """
  with open(__file__, 'rb') as f:
    return hashlib.sha1(f.read()).hexdigest()
78
79
def KillGraceful(pid, wait_secs=1):
  """Kill a process gracefully.

  Sends SIGTERM, waits up to `wait_secs` seconds for the process to go away,
  then sends SIGKILL to make sure it is killed. Unlike a fixed sleep, this
  returns as soon as the process no longer exists.

  Args:
    pid: id of the process to kill.
    wait_secs: maximum number of seconds to wait between SIGTERM and SIGKILL.

  Errors from os.kill (no such process, not permitted) are ignored: this is a
  best-effort operation and never raises OSError.
  """
  try:
    os.kill(pid, signal.SIGTERM)
  except OSError:
    return  # Already gone, or we may not signal it; SIGKILL would fail too.

  deadline = time.time() + wait_secs
  while time.time() < deadline:
    try:
      os.kill(pid, 0)  # Signal 0 only checks that the process still exists.
    except OSError:
      return
    time.sleep(0.1)

  try:
    os.kill(pid, signal.SIGKILL)
  except OSError:
    pass
89
90
def AutoRetry(action_name, retries):
  """Decorator for retrying a failing function call.

  The decorated function is called up to `retries` times, until one call
  completes without raising an exception. Each failure is reported, and a
  final message is printed when every attempt failed. The first positional
  argument of the call (e.g. the source path of the push/pull helpers) is
  used as the subject of those messages.

  Args:
    action_name: verb used in the final failure message, e.g. 'push'.
    retries: maximum number of attempts.

  Returns:
    A decorator. The decorated function returns the wrapped function's result
    on success, or None if all attempts failed.
  """
  import functools

  def Wrap(func):
    @functools.wraps(func)
    def Loop(*args, **kwargs):
      # Name the failing item after the first positional argument (e.g. the
      # path being pushed/pulled); fall back to '' for argument-less calls.
      subject = args[0] if args else ''
      for unused_attempt in range(retries):
        try:
          return func(*args, **kwargs)
        except Exception as e:
          print('error: %s: %s: retrying ...' % (subject, e))
      print('error: failed to %s %s' % (action_name, subject))
    return Loop
  return Wrap
106
107
Wei-Ning Huang91aaeed2015-09-24 14:51:56 +0800108def BasicAuthHeader(user, password):
109 """Return HTTP basic auth header."""
110 credential = base64.b64encode('%s:%s' % (user, password))
111 return ('Authorization', 'Basic %s' % credential)
112
113
def GetTerminalSize():
  """Retrieve the terminal window size.

  Queries stdin first and falls back to stdout, so that redirecting stdin
  (e.g. `ovl push ... < /dev/null`) does not make the lookup fail.

  Returns:
    A (lines, columns) tuple, or (0, 0) when neither stdin nor stdout is a
    terminal. ProgressBar already treats a zero width as unknown and uses
    _DEFAULT_TERMINAL_WIDTH instead.
  """
  for fd in (0, 1):
    try:
      packed = fcntl.ioctl(fd, termios.TIOCGWINSZ,
                           struct.pack('HHHH', 0, 0, 0, 0))
    except (IOError, OSError):
      continue
    lines, columns, unused_x, unused_y = struct.unpack('HHHH', packed)
    return lines, columns
  return 0, 0
120
121
def MakeRequestUrl(state, url):
  """Prefix a scheme-less URL with http:// or https://.

  Args:
    state: object with an `ssl` attribute (e.g. DaemonState).
    url: URL without scheme, e.g. 'host:9000/api/agents/list'.

  Returns:
    'https://' + url when state.ssl is truthy, otherwise 'http://' + url.
  """
  scheme = 'https' if state.ssl else 'http'
  return '%s://%s' % (scheme, url)
124
125
class ProgressBar(object):
  """Single-line progress indicator written to stdout.

  Every update rewrites the current line (it ends with a carriage return).
  The line holds the name, the transferred size, the average speed and the
  elapsed time. When the terminal is wide enough, it also shows a bar with
  the percentage done.
  """
  SIZE_WIDTH = 11
  SPEED_WIDTH = 10
  DURATION_WIDTH = 6
  PERCENTAGE_WIDTH = 8

  # Unit suffixes for successive powers of 1024.
  _SIZE_UNITS = ('B', 'KiB', 'MiB', 'GiB', 'TiB')
  _SPEED_UNITS = ('B', 'K', 'M', 'G', 'T')

  def __init__(self, name):
    """Create the bar and immediately draw it at 0%.

    Args:
      name: label shown at the start of the line (truncated if too long).
    """
    self._start_time = time.time()
    self._name = name
    self._size = 0
    self._width = 0
    self._name_width = 0
    self._name_max = 0
    self._stat_width = 0
    self._max = 0
    self.CalculateSize()
    self.SetProgress(0)

  def CalculateSize(self):
    """Recompute the column layout from the current terminal width."""
    self._width = GetTerminalSize()[1] or _DEFAULT_TERMINAL_WIDTH
    self._name_width = int(self._width * 0.3)
    self._name_max = self._name_width
    self._stat_width = self.SIZE_WIDTH + self.SPEED_WIDTH + self.DURATION_WIDTH
    # Width left for the bar itself; SetProgress omits the bar when this is
    # too small.
    self._max = (self._width - self._name_width - self._stat_width -
                 self.PERCENTAGE_WIDTH)

  @staticmethod
  def _Scale(value, num_units):
    """Express `value` in the largest fitting power of 1024.

    Returns:
      A (scaled_value, unit_index) tuple. unit_index is capped at
      num_units - 1 so that arbitrarily large values still get a unit.
    """
    index = 0
    while index < num_units - 1 and value >= 1024 ** (index + 1):
      index += 1
    if index == 0:
      return value, 0
    return value / (1024.0 ** index), index

  def SizeToHuman(self, size_in_bytes):
    """Format a byte count, e.g. '    1.5 KiB'."""
    value, index = self._Scale(size_in_bytes, len(self._SIZE_UNITS))
    return ' %6.1f %3s' % (value, self._SIZE_UNITS[index])

  def SpeedToHuman(self, speed_in_bs):
    """Format a speed given in bytes per second, e.g. '    1.5K/s'."""
    value, index = self._Scale(speed_in_bs, len(self._SPEED_UNITS))
    return ' %6.1f%s/s' % (value, self._SPEED_UNITS[index])

  def DurationToClock(self, duration):
    """Format a duration in seconds as ' MM:SS'."""
    return ' %02d:%02d' % (duration / 60, duration % 60)

  def SetProgress(self, percentage, size=None):
    """Redraw the bar.

    Args:
      percentage: completion in percent (0-100).
      size: number of bytes transferred so far; keeps the previous value when
        None.
    """
    if self._width != GetTerminalSize()[1]:
      self.CalculateSize()

    if size is not None:
      self._size = size

    elapse_time = time.time() - self._start_time
    # A zero interval is possible on coarse clocks (SetProgress is called
    # right after construction); avoid dividing by it.
    speed = self._size / float(elapse_time) if elapse_time > 0 else 0.0

    if len(self._name) <= self._name_max:
      name = self._name
    else:
      name = self._name[:self._name_max - 4] + ' ...'

    line = ('%*s' % (-self._name_max, name) +
            self.SizeToHuman(self._size) + self.SpeedToHuman(speed) +
            self.DurationToClock(elapse_time))
    if self._max > 2:
      width = int(self._max * percentage / 100.0)
      line += (' [' + '#' * width + ' ' * (self._max - width) + ']' +
               '%4d%%' % int(percentage))
    sys.stdout.write(line + '\r')
    sys.stdout.flush()

  def End(self):
    """Draw the bar at 100% and move to the next line."""
    self.SetProgress(100.0)
    sys.stdout.write('\n')
    sys.stdout.flush()
214
215
class DaemonState(object):
  """Overlord connection state kept by the client daemon.

  The daemon returns this object to the CLI through its JSON-RPC State()
  call, and main() registers the class with jsonrpclib's Config, so its
  attributes should stay JSON-friendly.
  """
  def __init__(self):
    # Digest of the script that started the daemon (see GetVersionDigest).
    self.version_sha1sum = GetVersionDigest()
    # Overlord server endpoint and transport.
    self.host = self.port = self.orig_host = None
    self.ssl = self.ssh = False
    # Pid of the SSH tunnel process when connected through SSH forwarding.
    self.ssh_pid = None
    # HTTP basic auth credentials.
    self.username = self.password = None
    # Default client chosen with `ovl select`.
    self.selected_mid = None
    # Port forwards: local port -> (mid, remote port, forwarder pid).
    self.forwards = {}
    # Cached client list and the time.time() at which it was fetched.
    self.listing, self.last_list = [], 0
232
233
class OverlordClientDaemon(object):
  """Overlord client daemon.

  Keeps the connection state (a DaemonState) between `ovl` invocations and
  exposes it, plus a few operations, through a JSON-RPC server listening on
  _OVERLORD_CLIENT_DAEMON_RPC_ADDR.
  """
  def __init__(self):
    self._state = DaemonState()
    self._server = None

  def Start(self):
    """Start the daemon's RPC server in a background process."""
    self.StartRPCServer()

  def StartRPCServer(self):
    """Create the JSON-RPC server and serve it from a forked child.

    The parent returns immediately; the child serves requests and never
    returns to the caller.
    """
    self._server = SimpleJSONRPCServer(_OVERLORD_CLIENT_DAEMON_RPC_ADDR,
                                       logRequests=False)
    exports = [
        (self.State, 'State'),
        (self.Ping, 'Ping'),
        (self.GetPid, 'GetPid'),
        (self.Connect, 'Connect'),
        (self.Clients, 'Clients'),
        (self.SelectClient, 'SelectClient'),
        (self.AddForward, 'AddForward'),
        (self.RemoveForward, 'RemoveForward'),
        (self.RemoveAllForward, 'RemoveAllForward'),
    ]
    for func, name in exports:
      self._server.register_function(func, name)

    pid = os.fork()
    if pid == 0:
      try:
        self._server.serve_forever()
      except Exception as e:
        logging.exception(e)
      finally:
        # The child is a copy of the CLI process: if serving ever stops it
        # must exit rather than return into the CLI command that forked it.
        os._exit(1)  # pylint: disable=W0212
    else:
      # Only the child serves requests. Close the parent's copy of the
      # listening socket so it is not inherited by processes the CLI forks
      # later (e.g. port forwarders), which would keep the port bound.
      self._server.server_close()

  @staticmethod
  def GetRPCServer():
    """Returns the Overlord client daemon RPC server."""
    server = jsonrpclib.Server('http://%s:%d' %
                               _OVERLORD_CLIENT_DAEMON_RPC_ADDR)
    try:
      server.Ping()
    except Exception:
      return None
    return server

  def State(self):
    return self._state

  def Ping(self):
    return True

  def GetPid(self):
    return os.getpid()

  def _UrlOpen(self, url):
    """Wrapper for urllib2.urlopen.

    It selects the correct HTTP scheme according to self._state.ssl and adds
    HTTP basic auth headers when credentials are set.
    """
    url = MakeRequestUrl(self._state, url)
    request = urllib2.Request(url)
    if self._state.username is not None and self._state.password is not None:
      request.add_header(*BasicAuthHeader(self._state.username,
                                          self._state.password))
    return urllib2.urlopen(request, timeout=_DEFAULT_HTTP_TIMEOUT)

  def _GetJSON(self, path):
    """GET `path` on the connected server and decode the JSON response."""
    url = '%s:%d%s' % (self._state.host, self._state.port, path)
    response = self._UrlOpen(url)
    try:
      return json.loads(response.read())
    finally:
      response.close()

  def Connect(self, host, port=_OVERLORD_HTTP_PORT, ssh_pid=None,
              username=None, password=None, orig_host=None):
    """Connect the daemon to an Overlord server.

    Plain HTTP is tried first; HTTPS is used when the HTTP response does not
    contain _OVERLORD_RESPONSE_KEYWORD.

    Returns:
      True on success, the HTTP status code on an HTTP error, or the error
      message string for any other failure.
    """
    self._state.username = username
    self._state.password = password
    self._state.host = host
    self._state.port = port
    self._state.ssl = False
    self._state.orig_host = orig_host
    self._state.ssh_pid = ssh_pid
    self._state.selected_mid = None

    url = '%s:%d' % (host, port)
    try:
      response = self._UrlOpen(url)
      try:
        body = response.read()
      finally:
        response.close()
      # Probably not an HTTP server, try HTTPS
      if _OVERLORD_RESPONSE_KEYWORD not in body:
        self._state.ssl = True
        self._UrlOpen(url).close()
    except urllib2.HTTPError as e:
      logging.exception(e)
      return e.getcode()
    except Exception as e:
      logging.exception(e)
      return str(e)
    return True

  def Clients(self):
    """Return the sorted, de-duplicated list of client mids.

    The list is cached for _LIST_CACHE_TIMEOUT seconds.
    """
    if time.time() - self._state.last_list <= _LIST_CACHE_TIMEOUT:
      return self._state.listing

    mids = [client['mid'] for client in self._GetJSON('/api/agents/list')]
    self._state.listing = sorted(set(mids))
    self._state.last_list = time.time()
    return self._state.listing

  def SelectClient(self, mid):
    self._state.selected_mid = mid

  def AddForward(self, mid, remote, local, pid):
    """Record a port forward handled by process `pid`."""
    self._state.forwards[local] = (mid, remote, pid)

  def RemoveForward(self, local_port):
    """Kill the forwarder for `local_port` and forget it (no-op if unknown)."""
    try:
      unused_mid, unused_remote, pid = self._state.forwards[local_port]
      KillGraceful(pid)
      del self._state.forwards[local_port]
    except (KeyError, OSError):
      pass

  def RemoveAllForward(self):
    """Kill every forwarder process and clear the forward table."""
    for unused_mid, unused_remote, pid in self._state.forwards.values():
      try:
        KillGraceful(pid)
      except OSError:
        pass
    self._state.forwards = {}
356
357
class TerminalWebSocketClient(WebSocketBaseClient):
  """Websocket client for an interactive terminal (tty) session.

  Puts the local terminal in raw, non-blocking mode, forwards keystrokes to
  the websocket, and writes binary messages from the websocket to stdout.
  Typing <Enter> ~ . closes the session. Terminal size changes are sent to
  the remote side as in-band control messages.
  """
  def __init__(self, mid, *args, **kwargs):
    super(TerminalWebSocketClient, self).__init__(*args, **kwargs)
    self._mid = mid
    self._stdin_fd = sys.stdin.fileno()
    # stdin's terminal attributes and file status flags, saved before they
    # are changed so closed() can restore them. Both are set by the input
    # thread started in opened(), so either may still be None in closed().
    self._old_termios = None
    self._old_fcntl_flags = None

  def handshake_ok(self):
    pass

  def opened(self):
    """Start the thread that feeds local input into the websocket."""
    nonlocals = {'size': (80, 40)}

    def _ResizeWindow():
      """Send a resize control message when the terminal size changed."""
      size = GetTerminalSize()
      if size == nonlocals['size']:
        return  # Size not changed, ignore.
      control = {'command': 'resize', 'params': list(size)}
      payload = chr(_CONTROL_START) + json.dumps(control) + chr(_CONTROL_END)
      nonlocals['size'] = size
      try:
        self.send(payload, binary=True)
      except Exception:
        pass

    def _FeedInput():
      """Read raw keystrokes from stdin and send them to the websocket."""
      self._old_fcntl_flags = fcntl.fcntl(sys.stdin, fcntl.F_GETFL)
      fcntl.fcntl(sys.stdin, fcntl.F_SETFL,
                  self._old_fcntl_flags | os.O_NONBLOCK)

      self._old_termios = termios.tcgetattr(self._stdin_fd)
      tty.setraw(self._stdin_fd)

      READY, ENTER_PRESSED, ESCAPE_PRESSED = range(3)

      try:
        state = READY
        while True:
          rd, unused_w, unused_x = select.select([sys.stdin], [], [], 0.5)

          # We can't install a signal handler in the main thread since it'll
          # interrupt the read/write system call (ws4py performing send/recv).
          # Use polling instead (select's timeout is 0.5 seconds)
          _ResizeWindow()

          if sys.stdin in rd:
            data = sys.stdin.read()

            # Scan for the escape sequence <Enter> ~ .
            for x in data:
              if state == ESCAPE_PRESSED and x == '.':
                self.close()
                raise RuntimeError('quit')
              if x == chr(0x0d):
                # Enter (re)starts the sequence from any state, so that e.g.
                # <Enter> <Enter> ~ . is recognized too.
                state = ENTER_PRESSED
              elif state == ENTER_PRESSED and x == _ESCAPE:
                state = ESCAPE_PRESSED
              else:
                state = READY

            self.send(data)
      except (KeyboardInterrupt, RuntimeError):
        pass

    t = threading.Thread(target=_FeedInput)
    t.daemon = True
    t.start()

  def closed(self, code, reason=None):
    """Restore the local terminal and report that the session ended."""
    if self._old_termios is not None:
      termios.tcsetattr(self._stdin_fd, termios.TCSANOW, self._old_termios)
    if self._old_fcntl_flags is not None:
      # O_NONBLOCK lives on the open file description shared with the parent
      # shell; leaving it set can make later readers of the terminal fail
      # with EAGAIN.
      fcntl.fcntl(self._stdin_fd, fcntl.F_SETFL, self._old_fcntl_flags)
    print('Connection to %s closed.' % self._mid)

  def received_message(self, msg):
    if msg.is_binary:
      sys.stdout.write(msg.data)
      sys.stdout.flush()
433
434
class ShellWebSocketClient(WebSocketBaseClient):
  """Websocket client that copies a remote command's output to a file object.

  Only binary messages are written out; other messages are ignored.
  """
  def __init__(self, output, *args, **kwargs):
    """Constructor.

    Args:
      output: output file object; needs write() and flush().
    """
    self.output = output
    super(ShellWebSocketClient, self).__init__(*args, **kwargs)

  # ws4py connection callbacks: intentionally no-ops for this client.
  def handshake_ok(self):
    pass

  def opened(self):
    pass

  def closed(self, code, reason=None):
    pass

  def received_message(self, msg):
    if not msg.is_binary:
      return
    sink = self.output
    sink.write(msg.data)
    sink.flush()
458
459
class ForwarderWebSocketClient(WebSocketBaseClient):
  """Websocket client relaying data between a local socket and a websocket.

  Bytes read from the local socket are sent as binary websocket messages.
  Binary messages received from the websocket are written to the socket.
  """
  def __init__(self, sock, *args, **kwargs):
    super(ForwarderWebSocketClient, self).__init__(*args, **kwargs)
    self._sock = sock
    # Set by closed() to stop the socket-reading thread.
    self._stop = threading.Event()

  def handshake_ok(self):
    pass

  def opened(self):
    """Start the thread that feeds local socket data into the websocket."""
    def _FeedInput():
      try:
        # The socket stays blocking: select() guarantees recv() won't block
        # here, and received_message() needs blocking sendall() so that no
        # data is dropped when the socket buffer is full.
        while True:
          rd, unused_w, unused_x = select.select([self._sock], [], [], 0.5)
          if self._stop.is_set():
            break
          if self._sock in rd:
            data = self._sock.recv(_BUFSIZ)
            if not data:
              break  # Local peer closed the connection.
            self.send(data, binary=True)
      except Exception:
        pass
      finally:
        self._sock.close()
        self.close()

    t = threading.Thread(target=_FeedInput)
    t.daemon = True
    t.start()

  def closed(self, code, reason=None):
    self._stop.set()
    # Raises SystemExit in the calling thread; in this file the client is run
    # from a per-connection thread, so only that thread is ended.
    sys.exit(0)

  def received_message(self, msg):
    if msg.is_binary:
      # send() may write only part of the buffer; sendall() writes it all.
      self._sock.sendall(msg.data)
499
500
def Arg(*args, **kwargs):
  """Describe one argparse argument for the @Command decorator.

  Returns:
    An (args, kwargs) tuple, later passed to
    parser.add_argument(*args, **kwargs).
  """
  return args, kwargs
503
504
def Command(command, help_msg=None, args=None):
  """Decorator for adding argparse parameter for a method.

  Args:
    command: sub-command name.
    help_msg: help text for the sub-command.
    args: list of Arg(...) tuples describing the sub-command's arguments.

  Returns:
    A decorator. It wraps the method and attaches the sub-command description
    as the `__arg_attr` attribute, which ParseMethodSubCommands collects.
  """
  arg_list = [] if args is None else args

  def WrapFunc(func):
    def Wrapped(*call_args, **call_kwargs):
      return func(*call_args, **call_kwargs)
    setattr(Wrapped, '__arg_attr',
            dict(command=command, help=help_msg, args=arg_list))
    return Wrapped
  return WrapFunc
516
517
def ParseMethodSubCommands(cls):
  """Decorator for a class using the @Command decorator.

  Collects the sub-command description that @Command attached to each method
  of `cls` and appends it to the class's SUBCOMMANDS list, which is later used
  to construct the argument parser. Descriptions are appended in sub-command
  name order: class __dict__ iteration order is arbitrary on Python 2, which
  made the generated help output unstable.

  Returns:
    `cls` itself.
  """
  attrs = [getattr(method, '__arg_attr') for method in cls.__dict__.values()
           if hasattr(method, '__arg_attr')]
  cls.SUBCOMMANDS.extend(sorted(attrs, key=lambda attr: attr['command']))
  return cls
528
529
@ParseMethodSubCommands
class OverlordCLIClient(object):
  """Overlord command line interface client.

  Parses the command line and makes sure the client daemon is running,
  starting or restarting it if needed. It then runs the requested
  sub-command. Connection state (server, credentials, selected client, port
  forwards) lives in the daemon and is fetched with its State() call.
  """

  # Filled by @ParseMethodSubCommands from the @Command-decorated methods.
  SUBCOMMANDS = []

  def __init__(self):
    self._parser = self._BuildParser()
    self._selected_mid = None
    self._server = None
    self._state = None

  def _BuildParser(self):
    """Build the argparse parser from SUBCOMMANDS."""
    root_parser = argparse.ArgumentParser(prog='ovl')
    subparsers = root_parser.add_subparsers(help='sub-command')

    root_parser.add_argument('-s', dest='selected_mid', action='store',
                             default=None,
                             help='select target to execute command on')
    root_parser.add_argument('-S', dest='select_mid_before_action',
                             action='store_true', default=False,
                             help='select target before executing command')

    for attr in self.SUBCOMMANDS:
      parser = subparsers.add_parser(attr['command'], help=attr['help'])
      parser.set_defaults(which=attr['command'])
      for arg in attr['args']:
        parser.add_argument(*arg[0], **arg[1])

    return root_parser

  def Main(self):
    """Parse sys.argv and dispatch to the requested sub-command."""
    # We want to pass the rest of arguments after shell command directly to the
    # function without parsing it.
    try:
      index = sys.argv.index('shell')
    except ValueError:
      args = self._parser.parse_args()
    else:
      args = self._parser.parse_args(sys.argv[1:index + 1])

    command = args.which
    self._selected_mid = args.selected_mid

    if command == 'kill-server':
      self.KillServer()
      return

    self.StartDaemon()
    if command == 'status':
      self.Status()
      return
    elif command == 'connect':
      self.Connect(args)
      return

    # The following command requires connection to the server
    self.CheckConnection()

    if args.select_mid_before_action:
      self.SelectClient(store=False)

    if command == 'select':
      self.SelectClient(args)
    elif command == 'ls':
      self.ListClients()
    elif command == 'shell':
      command = sys.argv[sys.argv.index('shell') + 1:]
      self.Shell(command)
    elif command == 'push':
      self.Push(args)
    elif command == 'pull':
      self.Pull(args)
    elif command == 'forward':
      self.Forward(args)

  def _UrlOpen(self, url):
    """Open a scheme-less `url` with the right scheme and basic auth."""
    url = MakeRequestUrl(self._state, url)
    request = urllib2.Request(url)
    if self._state.username is not None and self._state.password is not None:
      request.add_header(*BasicAuthHeader(self._state.username,
                                          self._state.password))
    return urllib2.urlopen(request, timeout=_DEFAULT_HTTP_TIMEOUT)

  def _HTTPPostFile(self, url, filename, progress=None, user=None, passwd=None):
    """Perform HTTP POST and upload file to Overlord.

    To minimize the external dependencies, we construct the HTTP post request
    by ourselves.

    Args:
      url: scheme-less upload URL.
      filename: local file to upload.
      progress: optional callback progress(percentage, size=None).
      user, passwd: optional HTTP basic auth credentials.

    Returns:
      True when the server replied with HTTP 200.
    """
    url = MakeRequestUrl(self._state, url)
    size = os.stat(filename).st_size
    boundary = '-----------%s' % _HTTP_BOUNDARY_MAGIC
    CRLF = '\r\n'
    parse = urlparse.urlparse(url)

    part_headers = [
        '--' + boundary,
        'Content-Disposition: form-data; name="file"; '
        'filename="%s"' % os.path.basename(filename),
        'Content-Type: application/octet-stream',
        '', ''
    ]
    part_header = CRLF.join(part_headers)
    end_part = CRLF + '--' + boundary + '--' + CRLF

    content_length = len(part_header) + size + len(end_part)
    if parse.scheme == 'http':
      h = httplib.HTTP(parse.netloc)
    else:
      h = httplib.HTTPS(parse.netloc)

    try:
      post_path = url[url.index(parse.netloc) + len(parse.netloc):]
      h.putrequest('POST', post_path)
      h.putheader('Content-Length', content_length)
      h.putheader('Content-Type', 'multipart/form-data; boundary=%s' % boundary)

      if user and passwd:
        h.putheader(*BasicAuthHeader(user, passwd))
      h.endheaders()
      h.send(part_header)

      count = 0
      # Binary mode: the file content is sent byte for byte.
      with open(filename, 'rb') as f:
        while True:
          data = f.read(_BUFSIZ)
          if not data:
            break
          count += len(data)
          # `size` is 0 only if the file was empty when stat'ed and grew
          # since; skip the percentage rather than divide by zero.
          if progress and size:
            progress(int(count * 100.0 / size), count)
          h.send(data)

      h.send(end_part)
      if progress:
        progress(100)

      if count != size:
        logging.warning('file changed during upload, upload may be truncated.')

      errcode, unused_x, unused_y = h.getreply()
      return errcode == 200
    finally:
      h.close()

  def StartDaemon(self):
    """Connect to the client daemon, starting or restarting it if needed.

    Raises:
      RuntimeError: the daemon could not be started.
    """
    self._server = OverlordClientDaemon.GetRPCServer()
    if self._server is None:
      print('* daemon not running, starting it now on port %d ... *' %
            _OVERLORD_CLIENT_DAEMON_PORT)
      OverlordClientDaemon().Start()
      time.sleep(1)
      self._server = OverlordClientDaemon.GetRPCServer()
      if self._server is None:
        raise RuntimeError('failed to start the overlord client daemon')
      print('* daemon started successfully *')

    self._state = self._server.State()
    sha1sum = GetVersionDigest()

    if sha1sum != self._state.version_sha1sum:
      print('ovl server is out of date. killing...')
      KillGraceful(self._server.GetPid())
      self.StartDaemon()

  def GetSSHControlFile(self, host):
    return _SSH_CONTROL_SOCKET_PREFIX + host

  def SSHTunnel(self, user, host, port):
    """SSH forward the remote overlord server.

    Overlord server may not have port 9000 open to the public network, in such
    case we can SSH forward the port to localhost.

    Returns:
      The pid of the SSH master process.

    Raises:
      RuntimeError: the SSH connection could not be established.
    """

    control_file = self.GetSSHControlFile(host)
    try:
      os.unlink(control_file)
    except Exception:
      pass

    subprocess.Popen([
        'ssh', '-Nf',
        '-M',  # Enable master mode
        '-S', control_file,
        '-L', '9000:localhost:9000',
        '-p', str(port),
        '%s%s' % (user + '@' if user else '', host)
    ]).wait()

    p = subprocess.Popen([
        'ssh',
        '-S', control_file,
        '-O', 'check', host,
    ], stderr=subprocess.PIPE)
    unused_stdout, stderr = p.communicate()

    s = re.search(r'pid=(\d+)', stderr)
    if s:
      return int(s.group(1))

    raise RuntimeError('can not establish ssh connection')

  def CheckConnection(self):
    """Raise RuntimeError unless the daemon is connected to a live server."""
    if self._state.host is None:
      raise RuntimeError('not connected to any server, abort')

    try:
      self._server.Clients()
    except Exception:
      raise RuntimeError('remote server disconnected, abort')

    if self._state.ssh_pid is not None:
      ret = subprocess.Popen(['kill', '-0', str(self._state.ssh_pid)],
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE).wait()
      if ret != 0:
        raise RuntimeError('ssh tunnel disconnected, please re-connect')

  def CheckClient(self):
    """Resolve the target client and make sure it is still connected."""
    if self._selected_mid is None:
      if self._state.selected_mid is None:
        raise RuntimeError('No client is selected')
      self._selected_mid = self._state.selected_mid

    if self._selected_mid not in self._server.Clients():
      raise RuntimeError('client %s disappeared' % self._selected_mid)

  def CheckOutput(self, command):
    """Run a shell command on the selected client and return its output."""
    headers = []
    if self._state.username is not None and self._state.password is not None:
      headers.append(BasicAuthHeader(self._state.username,
                                     self._state.password))

    scheme = 'ws%s://' % ('s' if self._state.ssl else '')
    sio = StringIO.StringIO()
    ws = ShellWebSocketClient(sio,
                              scheme + '%s:%d/api/agent/shell/%s?command=%s' %
                              (self._state.host, self._state.port,
                               self._selected_mid, urllib2.quote(command)),
                              headers=headers)
    ws.connect()
    ws.run()
    return sio.getvalue()

  @Command('status', 'show Overlord connection status')
  def Status(self):
    if self._state.host is None:
      print('Not connected to any host.')
    else:
      if self._state.ssh_pid is not None:
        print('Connected to %s with SSH tunneling.' % self._state.orig_host)
      else:
        print('Connected to %s:%d.' % (self._state.host, self._state.port))

    if self._selected_mid is None:
      self._selected_mid = self._state.selected_mid

    if self._selected_mid is None:
      print('No client is selected.')
    else:
      print('Client %s selected.' % self._selected_mid)

  @Command('connect', 'connect to Overlord server', [
      Arg('host', metavar='HOST', type=str, default='localhost',
          help='Overlord hostname/IP'),
      Arg('port', metavar='PORT', type=int,
          default=_OVERLORD_HTTP_PORT, help='Overlord port'),
      Arg('-f', '--forward', dest='ssh_forward', default=False,
          action='store_true',
          help='connect with SSH forwarding to the host'),
      Arg('-p', '--ssh-port', dest='ssh_port', default=22,
          type=int, help='SSH server port for SSH forwarding'),
      Arg('-l', '--ssh-login', dest='ssh_login', default='',
          type=str, help='SSH server login name for SSH forwarding'),
      Arg('-u', '--user', dest='user', default=None,
          type=str, help='Overlord HTTP auth username'),
      Arg('-w', '--passwd', dest='passwd', default=None, type=str,
          help='Overlord HTTP auth password')])
  def Connect(self, args):
    ssh_pid = None
    host = args.host
    orig_host = args.host

    if args.ssh_forward:
      # Kill previous SSH tunnel
      self.KillSSHTunnel()

      ssh_pid = self.SSHTunnel(args.ssh_login, args.host, args.ssh_port)
      host = 'localhost'

    status = self._server.Connect(host, args.port, ssh_pid, args.user,
                                  args.passwd, orig_host)
    if status is not True:
      if isinstance(status, int):
        if status == 401:
          msg = '401 Unauthorized'
        else:
          msg = 'HTTP %d' % status
      else:
        msg = status
      print('can not connect to %s: %s' % (host, msg))

  @Command('kill-server', 'kill overlord CLI client server')
  def KillServer(self):
    self._server = OverlordClientDaemon.GetRPCServer()
    if self._server is None:
      return

    self._state = self._server.State()

    # Kill SSH Tunnel
    self.KillSSHTunnel()

    # Kill server daemon
    KillGraceful(self._server.GetPid())

  def KillSSHTunnel(self):
    if self._state.ssh_pid is not None:
      KillGraceful(self._state.ssh_pid)

  @Command('ls', 'list all clients')
  def ListClients(self):
    for client in self._server.Clients():
      print(client)

  @Command('select', 'select default client', [
      Arg('mid', metavar='mid', nargs='?', default=None)])
  def SelectClient(self, args=None, store=True):
    """Select the target client, interactively when no mid is given.

    Args:
      args: parsed arguments with an optional `mid`.
      store: also save the selection in the daemon as the default client.
    """
    clients = self._server.Clients()

    mid = args.mid if args is not None else None
    if mid is None:
      print('Select from the following clients:')
      for i, client in enumerate(clients):
        print(' %d. %s' % (i + 1, client))

      print('\nSelection: ', end='')
      try:
        choice = int(raw_input()) - 1
      except ValueError:
        raise RuntimeError('select: invalid selection')
      # Check the range explicitly: 0 or negative input would otherwise index
      # the list from the end and silently pick the wrong client.
      if not 0 <= choice < len(clients):
        raise RuntimeError('select: selection out of range')
      mid = clients[choice]
    else:
      if mid not in clients:
        raise RuntimeError('select: client %s does not exist' % mid)

    self._selected_mid = mid
    if store:
      self._server.SelectClient(mid)
      print('Client %s selected' % mid)

  @Command('shell', 'open a shell or execute a shell command', [
      Arg('command', metavar='CMD', nargs='?', help='command to execute')])
  def Shell(self, command=None):
    """Open an interactive terminal, or run `command` (a list of words)."""
    if command is None:
      command = []
    self.CheckClient()

    headers = []
    if self._state.username is not None and self._state.password is not None:
      headers.append(BasicAuthHeader(self._state.username,
                                     self._state.password))

    scheme = 'ws%s://' % ('s' if self._state.ssl else '')
    if len(command) == 0:
      ws = TerminalWebSocketClient(self._selected_mid,
                                   scheme + '%s:%d/api/agent/tty/%s' %
                                   (self._state.host, self._state.port,
                                    self._selected_mid), headers=headers)
    else:
      cmd = ' '.join(command)
      ws = ShellWebSocketClient(sys.stdout,
                                scheme + '%s:%d/api/agent/shell/%s?command=%s' %
                                (self._state.host, self._state.port,
                                 self._selected_mid, urllib2.quote(cmd)),
                                headers=headers)
    ws.connect()
    ws.run()

  @Command('push', 'push a file or directory to remote', [
      Arg('src', metavar='SOURCE'),
      Arg('dst', metavar='DESTINATION')])
  def Push(self, args):
    self.CheckClient()

    if not os.path.exists(args.src):
      raise RuntimeError('push: can not stat "%s": no such file or directory'
                         % args.src)

    if not os.access(args.src, os.R_OK):
      raise RuntimeError('push: can not open "%s" for reading' % args.src)

    @AutoRetry('push', _RETRY_TIMES)
    def _push(src, dst):
      src_base = os.path.basename(src)

      # Local file is a link
      if os.path.islink(src):
        pbar = ProgressBar(src_base)
        link_path = os.readlink(src)
        self.CheckOutput('mkdir -p "%(dirname)s"; '
                         'if [ -d "%(dst)s" ]; then '
                         'ln -sf "%(link_path)s" "%(dst)s/%(link_name)s"; '
                         'else ln -sf "%(link_path)s" "%(dst)s"; fi' %
                         dict(dirname=os.path.dirname(dst),
                              link_path=link_path, dst=dst,
                              link_name=src_base))
        pbar.End()
        return

      mode = '0%o' % (0x1FF & os.stat(src).st_mode)
      # Quote the paths: characters such as spaces, '+', '&' or '#' would
      # otherwise corrupt the query string.
      url = ('%s:%d/api/agent/upload/%s?dest=%s&perm=%s' %
             (self._state.host, self._state.port, self._selected_mid,
              urllib2.quote(dst), mode))
      try:
        self._UrlOpen(url + '&filename=%s' % urllib2.quote(src_base)).close()
      except urllib2.HTTPError as e:
        msg = json.loads(e.read()).get('error', None)
        raise RuntimeError('push: %s' % msg)

      pbar = ProgressBar(src_base)
      self._HTTPPostFile(url, src, pbar.SetProgress,
                         self._state.username, self._state.password)
      pbar.End()

    if os.path.isdir(args.src):
      dst_exists = ast.literal_eval(self.CheckOutput(
          'stat %s >/dev/null 2>&1 && echo True || echo False' % args.dst))
      # Name of the source directory itself (e.g. 'src_dir' for
      # '/path/to/src_dir/'), used as the top-level remote directory.
      src_top = os.path.basename(os.path.normpath(args.src))
      for root, unused_x, files in os.walk(args.src):
        # If destination directory does not exist, we should strip the first
        # layer of directory. For example: src_dir contains a single file 'A'
        #
        # push src_dir dest_dir
        #
        # If dest_dir exists, the resulting directory structure should be:
        # dest_dir/src_dir/A
        # If dest_dir does not exist, the resulting directory structure should
        # be:
        # dest_dir/A
        #
        # Build the remote path from the directory's basename rather than the
        # local path as typed, so absolute or '../' sources still land under
        # dest_dir.
        rel_root = root[len(args.src):].lstrip('/')
        dst_root = os.path.join(src_top, rel_root) if dst_exists else rel_root
        for name in files:
          _push(os.path.join(root, name),
                os.path.join(args.dst, dst_root, name))
    else:
      _push(args.src, args.dst)

  @Command('pull', 'pull a file or directory from remote', [
      Arg('src', metavar='SOURCE'),
      Arg('dst', metavar='DESTINATION', default='.', nargs='?')])
  def Pull(self, args):
    self.CheckClient()

    @AutoRetry('pull', _RETRY_TIMES)
    def _pull(src, dst, ftype, perm=0o644, link=None):
      try:
        os.makedirs(os.path.dirname(dst))
      except OSError:
        pass  # Already exists, or dst has no directory part.

      src_base = os.path.basename(src)

      # Remote file is a link
      if ftype == 'l':
        pbar = ProgressBar(src_base)
        # lexists() also sees dangling symlinks, which would otherwise make
        # os.symlink() fail with EEXIST.
        if os.path.lexists(dst):
          os.remove(dst)
        os.symlink(link, dst)
        pbar.End()
        return

      url = ('%s:%d/api/agent/download/%s?filename=%s' %
             (self._state.host, self._state.port, self._selected_mid,
              urllib2.quote(src)))
      try:
        h = self._UrlOpen(url)
      except urllib2.HTTPError as e:
        msg = json.loads(e.read()).get('error', 'unknown error')
        raise RuntimeError('pull: %s' % msg)
      except KeyboardInterrupt:
        return

      try:
        pbar = ProgressBar(src_base)
        # Content-Length may be missing; then only the byte count is shown.
        content_length = h.headers.get('Content-Length')
        total_size = int(content_length) if content_length else 0
        downloaded_size = 0
        with open(dst, 'wb') as f:
          os.fchmod(f.fileno(), perm)
          while True:
            data = h.read(_BUFSIZ)
            if len(data) == 0:
              break
            downloaded_size += len(data)
            if total_size:
              percentage = float(downloaded_size) * 100 / total_size
            else:
              percentage = 0
            pbar.SetProgress(percentage, downloaded_size)
            f.write(data)
        pbar.End()
      finally:
        h.close()

    # Use find to get a listing of all files under a root directory. The 'stat'
    # command is used to retrieve the filename and it's filemode.
    output = self.CheckOutput(
        'cd $HOME; '
        'stat "%(src)s" >/dev/null && '
        'find "%(src)s" \'(\' -type f -o -type l \')\' '
        '-printf \'%%m\t%%p\t%%y\t%%l\n\''
        % {'src': args.src})

    # We got error from the stat command
    if output.startswith('stat: '):
      sys.stderr.write(output)
      return

    # Each entry is "<octal perm>\t<path>\t<type>\t<link target>".
    entries = [entry for entry in output.strip('\n').split('\n') if entry]
    if not entries:
      return  # Nothing to pull, e.g. SOURCE is a directory without files.
    common_prefix = os.path.dirname(args.src)

    if len(entries) == 1:
      entry = entries[0]
      perm, src_path, ftype, link = entry.split('\t', -1)
      if os.path.isdir(args.dst):
        dst = os.path.join(args.dst, os.path.basename(src_path))
      else:
        dst = args.dst
      _pull(src_path, dst, ftype, int(perm, base=8), link)
    else:
      if not os.path.exists(args.dst):
        common_prefix = args.src

      for entry in entries:
        perm, src_path, ftype, link = entry.split('\t', -1)
        rel_dst = src_path[len(common_prefix):].lstrip('/')
        _pull(src_path, os.path.join(args.dst, rel_dst), ftype,
              int(perm, base=8), link)

  @Command('forward', 'forward remote port to local port', [
      Arg('--list', dest='list_all', action='store_true', default=False,
          help='list all port forwarding sessions'),
      Arg('--remove', metavar='LOCAL_PORT', dest='remove', type=int,
          default=None,
          help='remove port forwarding for local port LOCAL_PORT'),
      Arg('--remove-all', dest='remove_all', action='store_true',
          default=False, help='remove all port forwarding'),
      Arg('remote', metavar='REMOTE_PORT', type=int, nargs='?'),
      Arg('local', metavar='LOCAL_PORT', type=int, nargs='?')])
  def Forward(self, args):
    if args.list_all:
      max_len = 10
      if self._state.forwards:
        max_len = max(len(v[0]) for v in self._state.forwards.values())

      print('%-*s %-8s %-8s' % (max_len, 'Client', 'Remote', 'Local'))
      for local in sorted(self._state.forwards.keys()):
        value = self._state.forwards[local]
        print('%-*s %-8s %-8s' % (max_len, value[0], value[1], local))
      return

    if args.remove_all:
      self._server.RemoveAllForward()
      return

    if args.remove is not None:
      self._server.RemoveForward(args.remove)
      return

    self.CheckClient()

    if args.remote is None:
      raise RuntimeError('forward: REMOTE_PORT is required')
    if args.local is None:
      args.local = args.remote
    remote = int(args.remote)
    local = int(args.local)

    def HandleConnection(conn):
      """Relay one accepted local connection through a websocket."""
      headers = []
      if self._state.username is not None and self._state.password is not None:
        headers.append(BasicAuthHeader(self._state.username,
                                       self._state.password))

      scheme = 'ws%s://' % ('s' if self._state.ssl else '')
      ws = ForwarderWebSocketClient(
          conn,
          scheme + '%s:%d/api/agent/forward/%s?port=%d' %
          (self._state.host, self._state.port, self._selected_mid, remote),
          headers=headers)
      try:
        ws.connect()
        ws.run()
      except Exception as e:
        print('error: %s' % e)
      finally:
        ws.close()

    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    server.bind(('0.0.0.0', local))
    server.listen(5)

    pid = os.fork()
    if pid == 0:
      # The forwarding child accepts connections forever.
      while True:
        conn, unused_addr = server.accept()
        t = threading.Thread(target=HandleConnection, args=(conn,))
        t.daemon = True
        t.start()
    else:
      # The child owns the listening socket from now on.
      server.close()
      self._server.AddForward(self._selected_mid, remote, local, pid)
1132
1133
def main():
  """Entry point: run the Overlord CLI and report failures.

  Errors are printed and the process exits with a non-zero status, so scripts
  invoking `ovl` can detect failures. The status is 1 for errors and 130 for
  Ctrl-C, the usual 128 + SIGINT.
  """
  logging.basicConfig(level=logging.INFO)

  # Add DaemonState to JSONRPC lib classes, presumably so DaemonState objects
  # returned by the daemon's State() call can be (de)serialized.
  Config.instance().classes.add(DaemonState)

  ovl = OverlordCLIClient()
  try:
    ovl.Main()
  except KeyboardInterrupt:
    print('Ctrl-C received, abort')
    sys.exit(130)
  except Exception as e:
    print('error: %s' % e)
    sys.exit(1)
1147
1148
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
  main()