blob: c44620c2f2477fa59146b04c2954ab30db1454d4 [file] [log] [blame]
Wei-Ning Huang91aaeed2015-09-24 14:51:56 +08001#!/usr/bin/python -u
2# -*- coding: utf-8 -*-
3#
4# Copyright 2015 The Chromium OS Authors. All rights reserved.
5# Use of this source code is governed by a BSD-style license that can be
6# found in the LICENSE file.
7
from __future__ import print_function

import argparse
import ast
import base64
import fcntl
import functools
import hashlib
import httplib
import json
import logging
import os
import re
import select
import signal
import socket
import StringIO
import struct
import subprocess
import sys
import tempfile
import termios
import threading
import time
import tty
import urllib2
import urlparse

import jsonrpclib
from jsonrpclib.SimpleJSONRPCServer import SimpleJSONRPCServer
from jsonrpclib.config import Config
from ws4py.client import WebSocketBaseClient
38from ws4py.client import WebSocketBaseClient
39
# Python version >= 2.7.9 enables SSL certificate verification by default;
# bypass it so HTTPS requests made through urllib2/httplib do not fail
# certificate checks.
# NOTE(review): this disables verification for every HTTPS request made by
# this process, not just those to the Overlord server.
try:
  import ssl
  # pylint: disable=W0212
  ssl._create_default_https_context = ssl._create_unverified_context
except (ImportError, AttributeError):
  # ImportError: Python built without SSL support.
  # AttributeError: Python < 2.7.9, which does not verify by default anyway.
  pass


# Character that, typed right after Enter, starts the terminal escape sequence.
_ESCAPE = '~'
# Chunk size for socket and file I/O.
_BUFSIZ = 8192
_OVERLORD_PORT = 4455
_OVERLORD_HTTP_PORT = 9000
_OVERLORD_CLIENT_DAEMON_PORT = 4488
_OVERLORD_CLIENT_DAEMON_RPC_ADDR = ('127.0.0.1', _OVERLORD_CLIENT_DAEMON_PORT)

# Seconds during which the daemon reuses its cached client listing.
_LIST_CACHE_TIMEOUT = 2
_DEFAULT_TERMINAL_WIDTH = 80

# echo -n overlord | md5sum
_HTTP_BOUNDARY_MAGIC = '9246f080c855a69012707ab53489b921'

# Byte values framing a JSON control message sent over the terminal socket.
_CONTROL_START = 128
_CONTROL_END = 129
_SSH_CONTROL_SOCKET_PREFIX = os.path.join(tempfile.gettempdir(),
                                          'ovl-ssh-control-')

# A string that will always be included in the response of
# GET http://OVERLORD_SERVER:_OVERLORD_HTTP_PORT
_OVERLORD_RESPONSE_KEYWORD = '<html>'
70
71
def GetVersionDigest():
  """Return the SHA-1 hex digest of the currently executing script.

  The file is read in binary mode so the digest covers the exact on-disk
  bytes; text mode may translate line endings and yields str (not bytes) on
  Python 3, which hashlib refuses.
  """
  with open(__file__, 'rb') as f:
    return hashlib.sha1(f.read()).hexdigest()
76
77
def KillGraceful(pid, wait_secs=1):
  """Stop process `pid`: SIGTERM, a grace period, then SIGKILL.

  Args:
    pid: id of the process to stop.
    wait_secs: seconds to sleep between SIGTERM and SIGKILL.

  The first OSError (e.g. no such process) ends the sequence silently.
  """
  steps = ((signal.SIGTERM, wait_secs), (signal.SIGKILL, None))
  try:
    for sig, pause in steps:
      os.kill(pid, sig)
      if pause is not None:
        time.sleep(pause)
  except OSError:
    pass
87
88
def BasicAuthHeader(user, password):
  """Return an HTTP basic auth header as a (name, value) tuple.

  Unicode credentials (e.g. those stored in the daemon state, which travel
  through JSON-RPC) are UTF-8 encoded before base64 encoding, so non-ASCII
  user names or passwords no longer raise UnicodeEncodeError.
  """
  credential = '%s:%s' % (user, password)
  if not isinstance(credential, bytes):
    credential = credential.encode('utf-8')
  encoded = base64.b64encode(credential)
  if not isinstance(encoded, str):
    # Python 3: b64encode returns bytes.
    encoded = encoded.decode('ascii')
  return ('Authorization', 'Basic %s' % encoded)
93
94
def GetTerminalSize(fd=0):
  """Retrieve the terminal window size.

  Args:
    fd: file descriptor to query; defaults to 0 (stdin).

  Returns:
    A (lines, columns) tuple, or (0, 0) when `fd` is not a terminal (e.g.
    stdin redirected from a file or pipe). The zero width lets callers such
    as ProgressBar fall back to _DEFAULT_TERMINAL_WIDTH instead of crashing.
  """
  ws = struct.pack('HHHH', 0, 0, 0, 0)
  try:
    ws = fcntl.ioctl(fd, termios.TIOCGWINSZ, ws)
  except (IOError, OSError):
    return 0, 0
  lines, columns, unused_x, unused_y = struct.unpack('HHHH', ws)
  return lines, columns
101
102
def MakeRequestUrl(state, url):
  """Prepend the scheme selected by `state.ssl` to `url` (host:port/path)."""
  scheme = 'https' if state.ssl else 'http'
  return '%s://%s' % (scheme, url)
105
106
class ProgressBar(object):
  """Single-line transfer progress bar, redrawn in place.

  Each redraw shows the (possibly truncated) name, the transferred size, the
  average speed, the elapsed time and, when the terminal is wide enough, a
  '#'-filled bar with the percentage. Redraws end with a carriage return so
  the next one overwrites the line; End() terminates it with a newline.
  """
  SIZE_WIDTH = 11
  SPEED_WIDTH = 10
  DURATION_WIDTH = 6
  PERCENTAGE_WIDTH = 8

  def __init__(self, name):
    self._start_time = time.time()
    self._name = name
    self._size = 0
    self._width = 0
    self._name_width = 0
    self._name_max = 0
    self._stat_width = 0
    self._max = 0
    self.CalculateSize()
    self.SetProgress(0)

  def CalculateSize(self):
    """Recompute the column widths from the current terminal width."""
    self._width = GetTerminalSize()[1] or _DEFAULT_TERMINAL_WIDTH
    self._name_width = int(self._width * 0.3)
    self._name_max = self._name_width
    self._stat_width = self.SIZE_WIDTH + self.SPEED_WIDTH + self.DURATION_WIDTH
    # Columns left for the bar itself.
    self._max = (self._width - self._name_width - self._stat_width -
                 self.PERCENTAGE_WIDTH)

  def SizeToHuman(self, size_in_bytes):
    """Format a byte count, e.g. '    1.5 KiB'; values >= 1024**4 use TiB."""
    for exponent, unit in ((4, 'TiB'), (3, 'GiB'), (2, 'MiB'), (1, 'KiB')):
      if size_in_bytes >= 1024 ** exponent:
        return ' %6.1f %3s' % (size_in_bytes / (1024.0 ** exponent), unit)
    return ' %6.1f %3s' % (size_in_bytes, 'B')

  def SpeedToHuman(self, speed_in_bs):
    """Format bytes/second, e.g. '    1.5K/s'; values >= 1024**4 use T."""
    for exponent, unit in ((4, 'T'), (3, 'G'), (2, 'M'), (1, 'K')):
      if speed_in_bs >= 1024 ** exponent:
        return ' %6.1f%s/s' % (speed_in_bs / (1024.0 ** exponent), unit)
    return ' %6.1f%s/s' % (speed_in_bs, 'B')

  def DurationToClock(self, duration):
    """Format a duration in seconds as ' MM:SS'."""
    return ' %02d:%02d' % (duration / 60, duration % 60)

  def SetProgress(self, percentage, size=None):
    """Redraw the bar at `percentage` (0-100), optionally updating the size.

    Args:
      percentage: completion percentage; values outside 0-100 are clamped
        when drawing the bar.
      size: bytes transferred so far; keeps the previous value when None.
    """
    # Compare against the effective width so a non-tty (width 0) does not
    # force a recalculation on every redraw.
    if self._width != (GetTerminalSize()[1] or _DEFAULT_TERMINAL_WIDTH):
      self.CalculateSize()

    if size is not None:
      self._size = size

    elapse_time = time.time() - self._start_time
    # time.time() may not have advanced since construction.
    speed = self._size / elapse_time if elapse_time > 0 else 0.0

    if len(self._name) <= self._name_max:
      name = self._name
    else:
      name = self._name[:self._name_max - 4] + ' ...'

    line = ('%*s' % (- self._name_max, name) +
            self.SizeToHuman(self._size) + self.SpeedToHuman(speed) +
            self.DurationToClock(elapse_time))
    if self._max > 2:
      width = max(0, min(self._max, int(self._max * percentage / 100.0)))
      line += (' [' + '#' * width + ' ' * (self._max - width) + ']' +
               '%4d%%' % int(percentage))
    sys.stdout.write(line + '\r')
    sys.stdout.flush()

  def End(self):
    """Draw the bar at 100% and move to the next line."""
    self.SetProgress(100.0)
    sys.stdout.write('\n')
    sys.stdout.flush()
195
196
class DaemonState(object):
  """Connection state kept by the Overlord client daemon.

  The daemon hands this object out through its State() RPC, and main()
  registers the class with jsonrpclib's Config, so the attribute names below
  are what CLI clients read.
  """
  def __init__(self):
    # Digest of the script that created the daemon; StartDaemon() compares it
    # with its own to restart an out-of-date daemon.
    self.version_sha1sum = GetVersionDigest()

    # Overlord server address and transport flags.
    self.host = self.port = None
    self.ssl = self.ssh = False

    # Host given by the user and SSH tunnel pid when SSH forwarding is used.
    self.orig_host = self.ssh_pid = None

    # HTTP basic auth credentials.
    self.username = self.password = None

    # Client machine id chosen with `select`.
    self.selected_mid = None

    # Port forwards: local port -> (mid, remote port, forwarder pid).
    self.forwards = {}

    # Cached client listing and the time.time() at which it was fetched.
    self.listing, self.last_list = [], 0
213
214
class OverlordClientDaemon(object):
  """Overlord Client Daemon.

  Serves JSON-RPC on _OVERLORD_CLIENT_DAEMON_RPC_ADDR so that successive
  `ovl` invocations share one DaemonState (server connection, credentials,
  selected client and port forwards).
  """
  def __init__(self):
    self._state = DaemonState()
    self._server = None

  def Start(self):
    """Start the daemon; see StartRPCServer()."""
    self.StartRPCServer()

  def StartRPCServer(self):
    """Create the RPC server and serve it from a forked child process.

    The calling (parent) process returns immediately. The child serves
    requests until it is killed and never returns to the caller.
    """
    self._server = SimpleJSONRPCServer(_OVERLORD_CLIENT_DAEMON_RPC_ADDR,
                                       logRequests=False)
    exports = [
        (self.State, 'State'),
        (self.Ping, 'Ping'),
        (self.GetPid, 'GetPid'),
        (self.Connect, 'Connect'),
        (self.Clients, 'Clients'),
        (self.SelectClient, 'SelectClient'),
        (self.AddForward, 'AddForward'),
        (self.RemoveForward, 'RemoveForward'),
        (self.RemoveAllForward, 'RemoveAllForward'),
    ]
    for func, name in exports:
      self._server.register_function(func, name)

    pid = os.fork()
    if pid == 0:
      try:
        self._server.serve_forever()
      except Exception:
        logging.exception('ovl daemon: RPC server stopped unexpectedly')
      finally:
        # serve_forever() is not expected to return (nothing calls
        # shutdown()); never fall back into the caller's CLI code path here.
        os._exit(1)

    # Only the child serves requests; release the parent's copy of the socket.
    self._server.server_close()

  @staticmethod
  def GetRPCServer():
    """Returns the Overlord client daemon RPC server, or None if not running."""
    server = jsonrpclib.Server('http://%s:%d' %
                               _OVERLORD_CLIENT_DAEMON_RPC_ADDR)
    try:
      server.Ping()
    except Exception:
      return None
    return server

  def State(self):
    """Return the daemon's DaemonState."""
    return self._state

  def Ping(self):
    """Liveness check used by GetRPCServer()."""
    return True

  def GetPid(self):
    """Return the daemon's process id."""
    return os.getpid()

  def _UrlOpen(self, url):
    """Wrapper for urllib2.urlopen.

    It selects correct HTTP scheme according to self._state.ssl and add HTTP
    basic auth headers.
    """
    url = MakeRequestUrl(self._state, url)
    request = urllib2.Request(url)
    if self._state.username is not None and self._state.password is not None:
      request.add_header(*BasicAuthHeader(self._state.username,
                                          self._state.password))
    return urllib2.urlopen(request)

  def _GetJSON(self, path):
    """GET `path` on the connected server and decode the JSON response."""
    url = '%s:%d%s' % (self._state.host, self._state.port, path)
    return json.loads(self._UrlOpen(url).read())

  def _ProbeOverlord(self, url):
    """Return True if GET `url` answers with the Overlord landing page.

    HTTP errors propagate to the caller. Any other failure to get a response
    (e.g. the server does not speak the scheme currently selected) yields
    False.
    """
    try:
      h = self._UrlOpen(url)
    except urllib2.HTTPError:
      raise
    except Exception:
      return False
    try:
      return _OVERLORD_RESPONSE_KEYWORD in h.read()
    finally:
      h.close()

  def Connect(self, host, port=_OVERLORD_HTTP_PORT, ssh_pid=None,
              username=None, password=None, orig_host=None):
    """Record a new server connection and detect whether it uses HTTPS.

    Returns:
      True on success, the HTTP status code (int) when the server answered
      with an HTTP error, or the error message (str) for other failures.
    """
    self._state.username = username
    self._state.password = password
    self._state.host = host
    self._state.port = port
    self._state.ssl = False
    self._state.orig_host = orig_host
    self._state.ssh_pid = ssh_pid
    self._state.selected_mid = None
    # Do not serve a client listing cached from the previous server.
    self._state.listing = []
    self._state.last_list = 0

    url = '%s:%d' % (host, port)
    try:
      # Probe plain HTTP first. If the answer does not look like Overlord, or
      # the request failed outright (presumably what a TLS-only server does
      # to a plain HTTP request), try HTTPS.
      if not self._ProbeOverlord(url):
        self._state.ssl = True
        self._UrlOpen(url).close()
    except urllib2.HTTPError as e:
      logging.exception(e)
      return e.getcode()
    except Exception as e:
      logging.exception(e)
      return str(e)
    return True

  def Clients(self):
    """Return the sorted, de-duplicated list of client machine ids.

    The listing is cached for _LIST_CACHE_TIMEOUT seconds.
    """
    if time.time() - self._state.last_list <= _LIST_CACHE_TIMEOUT:
      return self._state.listing

    agents = self._GetJSON('/api/agents/list')
    self._state.listing = sorted(set(agent['mid'] for agent in agents))
    self._state.last_list = time.time()
    return self._state.listing

  def SelectClient(self, mid):
    """Store `mid` as the default client."""
    self._state.selected_mid = mid

  def AddForward(self, mid, remote, local, pid):
    """Record forwarder process `pid` relaying `local` to `mid`:`remote`."""
    self._state.forwards[local] = (mid, remote, pid)

  def RemoveForward(self, local_port):
    """Stop the forwarder for `local_port`; unknown ports are ignored."""
    entry = self._state.forwards.pop(local_port, None)
    if entry is not None:
      unused_mid, unused_remote, pid = entry
      KillGraceful(pid)  # Swallows OSError itself.

  def RemoveAllForward(self):
    """Stop every forwarder process and clear the forward table."""
    for unused_mid, unused_remote, pid in self._state.forwards.values():
      KillGraceful(pid)  # Swallows OSError itself.
    self._state.forwards = {}
337
338
class TerminalWebSocketClient(WebSocketBaseClient):
  """WebSocket client attaching the local terminal to a client's TTY.

  While connected, stdin is switched to raw, non-blocking mode and its input
  is sent to the remote end; binary messages received are written to stdout.
  Typing Enter, '~', '.' closes the session. Window size changes are sent as
  'resize' control messages framed by _CONTROL_START/_CONTROL_END.
  """
  def __init__(self, mid, *args, **kwargs):
    super(TerminalWebSocketClient, self).__init__(*args, **kwargs)
    self._mid = mid
    self._stdin_fd = sys.stdin.fileno()
    # stdin settings saved before they are modified so closed() can restore
    # them; None while still unmodified.
    self._old_termios = None
    self._old_stdin_flags = None

  def handshake_ok(self):
    pass

  def opened(self):
    # Last window size sent. The initial value is only a placeholder that
    # makes the first poll send the real size (unless the terminal happens to
    # be exactly 80 lines by 40 columns).
    nonlocals = {'size': (80, 40)}

    def _ResizeWindow():
      size = GetTerminalSize()
      if size == nonlocals['size']:  # Size not changed, ignore
        return
      control = {'command': 'resize', 'params': list(size)}
      payload = chr(_CONTROL_START) + json.dumps(control) + chr(_CONTROL_END)
      nonlocals['size'] = size
      try:
        self.send(payload, binary=True)
      except Exception:
        pass

    def _FeedInput():
      self._old_stdin_flags = fcntl.fcntl(self._stdin_fd, fcntl.F_GETFL)
      fcntl.fcntl(self._stdin_fd, fcntl.F_SETFL,
                  self._old_stdin_flags | os.O_NONBLOCK)

      self._old_termios = termios.tcgetattr(self._stdin_fd)
      tty.setraw(self._stdin_fd)

      READY, ENTER_PRESSED, ESCAPE_PRESSED = range(3)

      try:
        state = READY
        while True:
          rd, unused_w, unused_x = select.select([sys.stdin], [], [], 0.5)

          # We can't install a signal handler in the main thread since it'll
          # interrupt the read/write system call (ws4py performing send/recv).
          # Use polling instead (select's timeout is 0.5 seconds)
          _ResizeWindow()

          if sys.stdin in rd:
            data = sys.stdin.read()

            # Scan for the escape sequence Enter, '~', '.'. Every Enter
            # re-arms it, so e.g. Enter, Enter, '~', '.' is recognized too.
            for x in data:
              if x == chr(0x0d):
                state = ENTER_PRESSED
              elif state == ENTER_PRESSED and x == _ESCAPE:
                state = ESCAPE_PRESSED
              elif state == ESCAPE_PRESSED and x == '.':
                self.close()
                raise RuntimeError('quit')
              else:
                state = READY

            self.send(data)
      except (KeyboardInterrupt, RuntimeError):
        pass

    t = threading.Thread(target=_FeedInput)
    t.daemon = True
    t.start()

  def closed(self, code, reason=None):
    # Restore only what _FeedInput() actually changed; the connection may
    # close before it ran.
    if self._old_termios is not None:
      termios.tcsetattr(self._stdin_fd, termios.TCSANOW, self._old_termios)
    if self._old_stdin_flags is not None:
      # O_NONBLOCK lives on the open file description, which is shared with
      # other processes using this terminal (e.g. the parent shell).
      fcntl.fcntl(self._stdin_fd, fcntl.F_SETFL, self._old_stdin_flags)
    print('Connection to %s closed.' % self._mid)

  def received_message(self, msg):
    if msg.is_binary:
      sys.stdout.write(msg.data)
      sys.stdout.flush()
413 sys.stdout.flush()
414
415
class ShellWebSocketClient(WebSocketBaseClient):
  """WebSocket client copying a remote command's binary output to a file.

  Text (non-binary) messages are ignored.
  """
  def __init__(self, output, *args, **kwargs):
    """Constructor.

    Args:
      output: file-like object with write() and flush() that receives the
        remote output.
    """
    self.output = output
    super(ShellWebSocketClient, self).__init__(*args, **kwargs)

  def handshake_ok(self):
    """No additional handshake validation."""

  def opened(self):
    """Nothing to do when the connection opens."""

  def closed(self, code, reason=None):
    """Nothing to do when the connection closes."""

  def received_message(self, msg):
    if not msg.is_binary:
      return
    sink = self.output
    sink.write(msg.data)
    sink.flush()
438 self.output.flush()
439
440
class ForwarderWebSocketClient(WebSocketBaseClient):
  """Relays bytes between a local TCP socket and a remote client port.

  Data read from `sock` is sent as binary WebSocket messages; binary messages
  received are written back to `sock`.
  """
  def __init__(self, sock, *args, **kwargs):
    super(ForwarderWebSocketClient, self).__init__(*args, **kwargs)
    self._sock = sock
    # Set by closed() to stop the _FeedInput() thread.
    self._stop = threading.Event()

  def handshake_ok(self):
    pass

  def opened(self):
    def _FeedInput():
      try:
        self._sock.setblocking(False)
        while True:
          rd, unused_w, unused_x = select.select([self._sock], [], [], 0.5)
          if self._stop.is_set():
            break
          if self._sock in rd:
            data = self._sock.recv(_BUFSIZ)
            if len(data) == 0:  # Local peer closed the connection.
              break
            self.send(data, binary=True)
      except Exception:
        pass  # Best effort: any error simply ends the relay.
      finally:
        self._sock.close()
        self.close()

    t = threading.Thread(target=_FeedInput)
    t.daemon = True
    t.start()

  def closed(self, code, reason=None):
    self._stop.set()
    # Raises SystemExit in the calling thread; presumably the thread running
    # ws.run() in Forward's HandleConnection, which then ends.
    sys.exit(0)

  def _SendToLocal(self, data):
    """Write all of `data` to the local socket.

    opened() makes the socket non-blocking, so a single send() may accept
    only part of the data (or fail if the buffer is full). Wait until the
    socket is writable before each send() and loop until all is written.
    """
    while data:
      select.select([], [self._sock], [])
      sent = self._sock.send(data)
      data = data[sent:]

  def received_message(self, msg):
    if msg.is_binary:
      self._SendToLocal(msg.data)
479 self._sock.send(msg.data)
480
481
def Arg(*args, **kwargs):
  """Package argparse add_argument() parameters as an (args, kwargs) pair."""
  packed = (args, kwargs)
  return packed
484
485
def Command(command, help_msg=None, args=None):
  """Decorator for adding argparse parameter for a method.

  Args:
    command: sub-command name.
    help_msg: help text of the sub-command.
    args: list of Arg(...) tuples passed to add_argument(); defaults to [].

  The description is stored as the `__arg_attr` attribute of the returned
  wrapper, where ParseMethodSubCommands() collects it. functools.wraps keeps
  the wrapped method's name and docstring.
  """
  if args is None:
    args = []

  def WrapFunc(func):
    @functools.wraps(func)
    def Wrapped(*func_args, **func_kwargs):
      return func(*func_args, **func_kwargs)
    # pylint: disable=W0212
    Wrapped.__arg_attr = {'command': command, 'help': help_msg, 'args': args}
    return Wrapped
  return WrapFunc
497
498
def ParseMethodSubCommands(cls):
  """Decorator for a class using the @Command decorator.

  Collects the `__arg_attr` description of every class attribute decorated
  with @Command and appends it to the SUBCOMMANDS class variable, which is
  later used to construct the parser. Entries are appended sorted by command
  name so the generated help lists sub-commands in a stable order (class
  __dict__ order is arbitrary on Python 2).
  """
  attrs = [getattr(member, '__arg_attr')
           for member in cls.__dict__.values()
           if hasattr(member, '__arg_attr')]
  cls.SUBCOMMANDS.extend(sorted(attrs, key=lambda attr: attr['command']))
  return cls
509
510
@ParseMethodSubCommands
class OverlordCLIClient(object):
  """Overlord command line interface client.

  Parses the `ovl` command line, makes sure the client daemon is running and
  dispatches to the @Command sub-command methods. Connection state (server,
  credentials, selected client, forwards) lives in the daemon and is fetched
  into self._state by StartDaemon().
  """

  SUBCOMMANDS = []

  def __init__(self):
    self._parser = self._BuildParser()
    self._selected_mid = None
    self._server = None
    self._state = None

  def _BuildParser(self):
    """Build the argparse parser from the global options and SUBCOMMANDS."""
    root_parser = argparse.ArgumentParser(prog='ovl')
    subparsers = root_parser.add_subparsers(help='sub-command')

    root_parser.add_argument('-s', dest='selected_mid', action='store',
                             default=None,
                             help='select target to execute command on')
    root_parser.add_argument('-S', dest='select_mid_before_action',
                             action='store_true', default=False,
                             help='select target before executing command')

    for attr in self.SUBCOMMANDS:
      parser = subparsers.add_parser(attr['command'], help=attr['help'])
      parser.set_defaults(which=attr['command'])
      for arg in attr['args']:
        parser.add_argument(*arg[0], **arg[1])

    return root_parser

  def Main(self):
    """Parse sys.argv and run the requested sub-command."""
    # We want to pass the rest of arguments after shell command directly to the
    # function without parsing it.
    try:
      shell_index = sys.argv.index('shell')
    except ValueError:
      shell_index = None
      args = self._parser.parse_args()
    else:
      args = self._parser.parse_args(sys.argv[1:shell_index + 1])

    command = args.which
    self._selected_mid = args.selected_mid

    if command == 'kill-server':
      self.KillServer()
      return

    self.StartDaemon()
    if command == 'status':
      self.Status()
      return
    elif command == 'connect':
      self.Connect(args)
      return

    # The following command requires connection to the server
    self.CheckConnection()

    if args.select_mid_before_action:
      self.SelectClient(store=False)

    if command == 'select':
      self.SelectClient(args)
    elif command == 'ls':
      self.ListClients()
    elif command == 'shell':
      self.Shell(sys.argv[shell_index + 1:])
    elif command == 'push':
      self.Push(args)
    elif command == 'pull':
      self.Pull(args)
    elif command == 'forward':
      self.Forward(args)

  def _GetAuthHeaders(self):
    """Return the HTTP basic auth header list for the current connection."""
    if self._state.username is not None and self._state.password is not None:
      return [BasicAuthHeader(self._state.username, self._state.password)]
    return []

  def _WebSocketUrl(self, path):
    """Return the ws:// or wss:// URL of `path` on the connected server."""
    return 'ws%s://%s:%d%s' % ('s' if self._state.ssl else '',
                               self._state.host, self._state.port, path)

  def _UrlOpen(self, url):
    """urllib2.urlopen() with the connection's scheme and auth headers."""
    request = urllib2.Request(MakeRequestUrl(self._state, url))
    for header in self._GetAuthHeaders():
      request.add_header(*header)
    return urllib2.urlopen(request)

  def _HTTPPostFile(self, url, filename, progress=None, user=None, passwd=None):
    """Perform HTTP POST and upload file to Overlord.

    To minimize the external dependencies, we construct the HTTP post request
    by ourselves.

    Args:
      url: host:port/path of the upload endpoint; the scheme is added here.
      filename: local file to upload.
      progress: optional callback, called as progress(percentage, count)
        while sending and progress(100) at the end.
      user, passwd: optional HTTP basic auth credentials.

    Returns:
      True if the server replied with HTTP 200.
    """
    url = MakeRequestUrl(self._state, url)
    size = os.stat(filename).st_size
    boundary = '-----------%s' % _HTTP_BOUNDARY_MAGIC
    CRLF = '\r\n'
    parse = urlparse.urlparse(url)

    part_headers = [
        '--' + boundary,
        'Content-Disposition: form-data; name="file"; '
        'filename="%s"' % os.path.basename(filename),
        'Content-Type: application/octet-stream',
        '', ''
    ]
    part_header = CRLF.join(part_headers)
    end_part = CRLF + '--' + boundary + '--' + CRLF

    content_length = len(part_header) + size + len(end_part)
    if parse.scheme == 'http':
      h = httplib.HTTP(parse.netloc)
    else:
      h = httplib.HTTPS(parse.netloc)

    post_path = url[url.index(parse.netloc) + len(parse.netloc):]
    h.putrequest('POST', post_path)
    h.putheader('Content-Length', content_length)
    h.putheader('Content-Type', 'multipart/form-data; boundary=%s' % boundary)

    if user and passwd:
      h.putheader(*BasicAuthHeader(user, passwd))
    h.endheaders()
    h.send(part_header)

    count = 0
    with open(filename, 'rb') as f:
      # Never send more than the advertised Content-Length, even if the file
      # grows during the upload (that would corrupt the request).
      while count < size:
        data = f.read(min(_BUFSIZ, size - count))
        if not data:
          break
        count += len(data)
        if progress:
          progress(int(count * 100.0 / size), count)
        h.send(data)
      changed = count != size or bool(f.read(1))

    h.send(end_part)
    if progress:
      progress(100)

    if changed:
      logging.warning('file changed during upload, upload may be truncated.')

    errcode, unused_x, unused_y = h.getreply()
    return errcode == 200

  def StartDaemon(self):
    """Connect to the client daemon, starting or restarting it as needed.

    Raises:
      RuntimeError: if the daemon could not be started.
    """
    self._server = OverlordClientDaemon.GetRPCServer()
    if self._server is None:
      print('* daemon not running, starting it now on port %d ... *' %
            _OVERLORD_CLIENT_DAEMON_PORT)
      OverlordClientDaemon().Start()
      time.sleep(1)
      self._server = OverlordClientDaemon.GetRPCServer()
      if self._server is None:
        raise RuntimeError('failed to start the ovl daemon')
      print('* daemon started successfully *')

    self._state = self._server.State()
    sha1sum = GetVersionDigest()

    if sha1sum != self._state.version_sha1sum:
      print('ovl server is out of date. killing...')
      KillGraceful(self._server.GetPid())
      self.StartDaemon()

  def GetSSHControlFile(self, host):
    """Return the SSH control socket path used for tunnels to `host`."""
    return _SSH_CONTROL_SOCKET_PREFIX + host

  def SSHTunnel(self, user, host, port, forward_port=_OVERLORD_HTTP_PORT):
    """SSH forward the remote overlord server.

    Overlord server may not have its HTTP port open to the public network, in
    such case we can SSH forward the port to localhost.

    Args:
      user: SSH login name; empty for the SSH default.
      host: SSH server (the Overlord host).
      port: SSH server port.
      forward_port: Overlord HTTP port on the remote host, forwarded to the
        same local port. Defaults to _OVERLORD_HTTP_PORT (9000).

    Returns:
      The pid of the SSH master process.

    Raises:
      RuntimeError: if the SSH master connection can not be verified.
    """
    control_file = self.GetSSHControlFile(host)
    try:
      os.unlink(control_file)
    except OSError:
      pass  # No stale control socket to remove.

    subprocess.Popen([
        'ssh', '-Nf',
        '-M',  # Enable master mode
        '-S', control_file,
        '-L', '%d:localhost:%d' % (forward_port, forward_port),
        '-p', str(port),
        '%s%s' % (user + '@' if user else '', host)
    ]).wait()

    p = subprocess.Popen([
        'ssh',
        '-S', control_file,
        '-O', 'check', host,
    ], stderr=subprocess.PIPE)
    unused_stdout, stderr = p.communicate()

    s = re.search(r'pid=(\d+)', stderr)
    if s:
      return int(s.group(1))

    raise RuntimeError('can not establish ssh connection')

  def CheckConnection(self):
    """Raise RuntimeError unless the server (and SSH tunnel) are reachable."""
    if self._state.host is None:
      raise RuntimeError('not connected to any server, abort')

    try:
      self._server.Clients()
    except Exception:
      raise RuntimeError('remote server disconnected, abort')

    if self._state.ssh_pid is not None:
      ret = subprocess.Popen(['kill', '-0', str(self._state.ssh_pid)],
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE).wait()
      if ret != 0:
        raise RuntimeError('ssh tunnel disconnected, please re-connect')

  def CheckClient(self):
    """Resolve the target client and make sure it is still connected."""
    if self._selected_mid is None:
      if self._state.selected_mid is None:
        raise RuntimeError('No client is selected')
      self._selected_mid = self._state.selected_mid

    if self._selected_mid not in self._server.Clients():
      raise RuntimeError('client %s disappeared' % self._selected_mid)

  def CheckOutput(self, command):
    """Run shell `command` on the selected client and return its output."""
    sio = StringIO.StringIO()
    ws = ShellWebSocketClient(
        sio,
        self._WebSocketUrl('/api/agent/shell/%s?command=%s' %
                           (self._selected_mid, urllib2.quote(command))),
        headers=self._GetAuthHeaders())
    ws.connect()
    ws.run()
    return sio.getvalue()

  @Command('status', 'show Overlord connection status')
  def Status(self):
    if self._state.host is None:
      print('Not connected to any host.')
    else:
      if self._state.ssh_pid is not None:
        print('Connected to %s with SSH tunneling.' % self._state.orig_host)
      else:
        print('Connected to %s:%d.' % (self._state.host, self._state.port))

    if self._selected_mid is None:
      self._selected_mid = self._state.selected_mid

    if self._selected_mid is None:
      print('No client is selected.')
    else:
      print('Client %s selected.' % self._selected_mid)

  @Command('connect', 'connect to Overlord server', [
      Arg('host', metavar='HOST', type=str, default='localhost',
          help='Overlord hostname/IP'),
      Arg('port', metavar='PORT', type=int,
          default=_OVERLORD_HTTP_PORT, help='Overlord port'),
      Arg('-f', '--forward', dest='ssh_forward', default=False,
          action='store_true',
          help='connect with SSH forwarding to the host'),
      Arg('-p', '--ssh-port', dest='ssh_port', default=22,
          type=int, help='SSH server port for SSH forwarding'),
      Arg('-l', '--ssh-login', dest='ssh_login', default='',
          type=str, help='SSH server login name for SSH forwarding'),
      Arg('-u', '--user', dest='user', default=None,
          type=str, help='Overlord HTTP auth username'),
      Arg('-w', '--passwd', dest='passwd', default=None, type=str,
          help='Overlord HTTP auth password')])
  def Connect(self, args):
    ssh_pid = None
    host = args.host
    orig_host = args.host

    if args.ssh_forward:
      # Kill previous SSH tunnel
      self.KillSSHTunnel()

      # Forward the Overlord port the user asked for, not a fixed one.
      ssh_pid = self.SSHTunnel(args.ssh_login, args.host, args.ssh_port,
                               args.port)
      host = 'localhost'

    status = self._server.Connect(host, args.port, ssh_pid, args.user,
                                  args.passwd, orig_host)
    if status is not True:
      if ssh_pid is not None:
        # Do not leave the tunnel we just created running.
        KillGraceful(ssh_pid)
      if isinstance(status, int):
        if status == 401:
          msg = '401 Unauthorized'
        else:
          msg = 'HTTP %d' % status
      else:
        msg = status
      print('can not connect to %s: %s' % (host, msg))

  @Command('kill-server', 'kill overlord CLI client server')
  def KillServer(self):
    self._server = OverlordClientDaemon.GetRPCServer()
    if self._server is None:
      return

    # Main() calls this before StartDaemon(), so fetch the state here;
    # KillSSHTunnel() reads the tunnel pid from it.
    self._state = self._server.State()

    # Kill SSH Tunnel
    self.KillSSHTunnel()

    # Kill server daemon
    KillGraceful(self._server.GetPid())

  def KillSSHTunnel(self):
    """Stop the SSH tunnel recorded in the daemon state, if any."""
    if self._state.ssh_pid is not None:
      KillGraceful(self._state.ssh_pid)

  @Command('ls', 'list all clients')
  def ListClients(self):
    for client in self._server.Clients():
      print(client)

  @Command('select', 'select default client', [
      Arg('mid', metavar='mid', nargs='?', default=None)])
  def SelectClient(self, args=None, store=True):
    clients = self._server.Clients()

    mid = args.mid if args is not None else None
    if mid is None:
      print('Select from the following clients:')
      for i, client in enumerate(clients):
        print(' %d. %s' % (i + 1, client))

      print('\nSelection: ', end='')
      sys.stdout.flush()
      try:
        choice = int(raw_input()) - 1
      except ValueError:
        raise RuntimeError('select: invalid selection')
      # Check explicitly: a negative index would silently pick from the end.
      if not 0 <= choice < len(clients):
        raise RuntimeError('select: selection out of range')
      mid = clients[choice]
    else:
      if mid not in clients:
        raise RuntimeError('select: client %s does not exist' % mid)

    self._selected_mid = mid
    if store:
      self._server.SelectClient(mid)
      print('Client %s selected' % mid)

  @Command('shell', 'open a shell or execute a shell command', [
      Arg('command', metavar='CMD', nargs='?', help='command to execute')])
  def Shell(self, command=None):
    if command is None:
      command = []
    self.CheckClient()

    headers = self._GetAuthHeaders()
    if not command:
      ws = TerminalWebSocketClient(
          self._selected_mid,
          self._WebSocketUrl('/api/agent/tty/%s' % self._selected_mid),
          headers=headers)
    else:
      cmd = ' '.join(command)
      ws = ShellWebSocketClient(
          sys.stdout,
          self._WebSocketUrl('/api/agent/shell/%s?command=%s' %
                             (self._selected_mid, urllib2.quote(cmd))),
          headers=headers)
    ws.connect()
    ws.run()

  @Command('push', 'push a file or directory to remote', [
      Arg('src', metavar='SOURCE'),
      Arg('dst', metavar='DESTINATION')])
  def Push(self, args):
    self.CheckClient()

    if not os.path.exists(args.src):
      raise RuntimeError('push: can not stat "%s": no such file or directory'
                         % args.src)

    if not os.access(args.src, os.R_OK):
      raise RuntimeError('push: can not open "%s" for reading' % args.src)

    def _push(src, dst):
      src_base = os.path.basename(src)
      mode = '0%o' % (0x1FF & os.stat(src).st_mode)
      # Quote the paths so spaces, '&', '#', ... survive the query string.
      url = ('%s:%d/api/agent/upload/%s?dest=%s&perm=%s' %
             (self._state.host, self._state.port, self._selected_mid,
              urllib2.quote(dst), mode))
      try:
        # A GET including the file name is issued before the POST; the
        # server's JSON error, if any, is reported to the user.
        self._UrlOpen(url + '&filename=%s' % urllib2.quote(src_base)).close()
      except urllib2.HTTPError as e:
        msg = json.loads(e.read()).get('error', None)
        raise RuntimeError('push: %s' % msg)

      pbar = ProgressBar(src_base)
      uploaded = self._HTTPPostFile(url, src, pbar.SetProgress,
                                    self._state.username, self._state.password)
      pbar.End()
      if not uploaded:
        raise RuntimeError('push: failed to upload "%s"' % src)

    if os.path.isdir(args.src):
      dst_exists = ast.literal_eval(self.CheckOutput(
          'stat "%s" >/dev/null 2>&1 && echo True || echo False' % args.dst))
      # If destination directory does not exist, we should strip the first
      # layer of directory. For example: src_dir contains a single file 'A'
      #
      # push src_dir dest_dir
      #
      # If dest_dir exists, the resulting directory structure should be:
      # dest_dir/src_dir/A
      # If dest_dir does not exist, the resulting directory structure should
      # be:
      # dest_dir/A
      #
      # Remote paths are computed relative to `base_dir`, so an absolute or
      # '../'-prefixed SOURCE never leaks into the destination path.
      src_dir = os.path.normpath(args.src)
      if dst_exists:
        base_dir = os.path.dirname(src_dir) or os.curdir
      else:
        base_dir = src_dir
      for root, unused_x, files in os.walk(src_dir):
        dst_root = os.path.relpath(root, base_dir)
        if dst_root == os.curdir:
          dst_root = ''
        for name in files:
          _push(os.path.join(root, name),
                os.path.join(args.dst, dst_root, name))
    else:
      _push(args.src, args.dst)

  @Command('pull', 'pull a file or directory from remote', [
      Arg('src', metavar='SOURCE'),
      Arg('dst', metavar='DESTINATION', default='.', nargs='?')])
  def Pull(self, args):
    self.CheckClient()

    def _pull(src, dst, perm=0o644):
      url = ('%s:%d/api/agent/download/%s?filename=%s' %
             (self._state.host, self._state.port, self._selected_mid,
              urllib2.quote(src)))
      try:
        h = self._UrlOpen(url)
      except urllib2.HTTPError as e:
        msg = json.loads(e.read()).get('error', 'unknown error')
        raise RuntimeError('pull: %s' % msg)
      except KeyboardInterrupt:
        return

      try:
        os.makedirs(os.path.dirname(dst))
      except OSError:
        pass  # Directory exists, or dst has no directory part.

      pbar = ProgressBar(os.path.basename(src))
      try:
        content_length = h.headers.get('Content-Length')
        total_size = int(content_length) if content_length is not None else 0
        with open(dst, 'wb') as f:
          os.fchmod(f.fileno(), perm)
          downloaded_size = 0
          while True:
            data = h.read(_BUFSIZ)
            downloaded_size += len(data)
            # Empty files and responses without Content-Length have no
            # meaningful percentage; End() draws 100% once done.
            if total_size:
              percentage = float(downloaded_size) * 100 / total_size
            else:
              percentage = 0
            pbar.SetProgress(percentage, downloaded_size)
            if len(data) == 0:
              break
            f.write(data)
      finally:
        h.close()
      pbar.End()

    # Use find to get a listing of all files under a root directory. The 'stat'
    # command is used to retrieve the filename and it's filemode.
    output = self.CheckOutput(
        'cd $HOME; '
        'stat "%(src)s" >/dev/null && '
        'find "%(src)s" \'(\' -type f -o -type l \')\' -printf \'%%m\t%%p\n\''
        % {'src': args.src})

    # We got error from the stat command
    if output.startswith('stat: '):
      sys.stderr.write(output)
      return

    # Each entry is [octal mode, path]; split once so tabs in file names
    # stay part of the path.
    entries = [entry.split('\t', 1)
               for entry in output.strip().split('\n') if entry]
    if not entries:
      return  # Nothing to pull, e.g. SOURCE is an empty directory.

    if len(entries) == 1:
      perm, src_path = entries[0]
      if os.path.isdir(args.dst):
        dst = os.path.join(args.dst, os.path.basename(src_path))
      else:
        dst = args.dst
      _pull(src_path, dst, int(perm, base=8))
    else:
      common_prefix = os.path.dirname(args.src)
      if not os.path.exists(args.dst):
        common_prefix = args.src

      for perm, src_path in entries:
        rel_dst = src_path[len(common_prefix):].lstrip('/')
        _pull(src_path, os.path.join(args.dst, rel_dst), int(perm, base=8))

  @Command('forward', 'forward remote port to local port', [
      Arg('--list', dest='list_all', action='store_true', default=False,
          help='list all port forwarding sessions'),
      Arg('--remove', metavar='LOCAL_PORT', dest='remove', type=int,
          default=None,
          help='remove port forwarding for local port LOCAL_PORT'),
      Arg('--remove-all', dest='remove_all', action='store_true',
          default=False, help='remove all port forwarding'),
      Arg('remote', metavar='REMOTE_PORT', type=int, nargs='?'),
      Arg('local', metavar='LOCAL_PORT', type=int, nargs='?')])
  def Forward(self, args):
    if args.list_all:
      max_len = 10
      if self._state.forwards:
        max_len = max(len(v[0]) for v in self._state.forwards.values())

      print('%-*s %-8s %-8s' % (max_len, 'Client', 'Remote', 'Local'))
      # Keys are local ports; they may arrive as strings after the JSON-RPC
      # round trip, so sort them numerically.
      for local in sorted(self._state.forwards, key=int):
        value = self._state.forwards[local]
        print('%-*s %-8s %-8s' % (max_len, value[0], value[1], local))
      return

    if args.remove_all:
      self._server.RemoveAllForward()
      return

    if args.remove is not None:
      self._server.RemoveForward(args.remove)
      return

    self.CheckClient()

    if args.remote is None:
      raise RuntimeError('forward: REMOTE_PORT is required')
    remote = int(args.remote)
    local = int(args.local if args.local is not None else args.remote)

    def HandleConnection(conn):
      """Relay one accepted local connection to the remote port."""
      ws = ForwarderWebSocketClient(
          conn,
          self._WebSocketUrl('/api/agent/forward/%s?port=%d' %
                             (self._selected_mid, remote)),
          headers=self._GetAuthHeaders())
      try:
        ws.connect()
        ws.run()
      except Exception as e:
        print('error: %s' % e)
      finally:
        ws.close()

    # NOTE(review): binding 0.0.0.0 exposes the forwarded port to the whole
    # network, not just this machine; consider 127.0.0.1.
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    server.bind(('0.0.0.0', local))
    server.listen(5)

    pid = os.fork()
    if pid == 0:
      while True:
        conn, unused_addr = server.accept()
        t = threading.Thread(target=HandleConnection, args=(conn,))
        t.daemon = True
        t.start()
    else:
      # Only the forked child accepts connections.
      server.close()
      self._server.AddForward(self._selected_mid, remote, local, pid)
1080
1081
def main():
  """Run the ovl CLI.

  Errors are printed and the process exits with status 1, so scripts can
  detect a failed command.
  """
  logging.basicConfig(level=logging.INFO)

  # Add DaemonState to JSONRPC lib classes, presumably so the daemon's State()
  # result is (de)serialized as a DaemonState object rather than a plain dict.
  Config.instance().classes.add(DaemonState)

  ovl = OverlordCLIClient()
  try:
    ovl.Main()
  except KeyboardInterrupt:
    print('Ctrl-C received, abort')
    sys.exit(1)
  except Exception as e:
    print('error: %s' % e)
    sys.exit(1)


if __name__ == '__main__':
  main()