blob: df7eaef1fd8170567835ad1a007622755e1b79cb [file] [log] [blame]
#!/usr/bin/env python3
# Copyright 2015 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
import argparse
import ast
import base64
import codecs
import fcntl
import functools
import getpass
import hashlib
import http.client
from io import StringIO
import json
import logging
import os
import re
import select
import shutil
import signal
import socket
import ssl
import struct
import subprocess
import sys
import tempfile
import termios
import threading
import time
import tty
import unicodedata # required by pyinstaller, pylint: disable=unused-import
import urllib.error
import urllib.parse
import urllib.request

import jsonrpclib
from jsonrpclib import config
from jsonrpclib.SimpleJSONRPCServer import SimpleJSONRPCServer
from ws4py.client import WebSocketBaseClient
import yaml

from cros.factory.utils import net_utils
from cros.factory.utils import process_utils
# Directory holding the per-host TLS certificates saved by `ovl connect`.
_CERT_DIR = os.path.expanduser('~/.config/ovl')
# Default escape character for interactive terminals (overridable with -e).
_ESCAPE = '~'
# Chunk size in bytes for file uploads and socket forwarding reads.
_BUFSIZ = 8192
_OVERLORD_PORT = 4455
# Default HTTP port of the Overlord server.
_OVERLORD_HTTP_PORT = 9000
# Local JSON-RPC endpoint of the ovl client daemon.
_OVERLORD_CLIENT_DAEMON_PORT = 4488
_OVERLORD_CLIENT_DAEMON_RPC_ADDR = ('127.0.0.1', _OVERLORD_CLIENT_DAEMON_PORT)
# Socket timeout (seconds) used when probing the server's TLS setup.
_CONNECT_TIMEOUT = 3
# Timeout (seconds) of HTTP requests made through UrlOpen.
_DEFAULT_HTTP_TIMEOUT = 30
# Seconds during which the daemon reuses its cached client list.
_LIST_CACHE_TIMEOUT = 2
# Width used when the terminal width is reported as 0.
_DEFAULT_TERMINAL_WIDTH = 80
# Attempts made for each file by push/pull (through AutoRetry).
_RETRY_TIMES = 3
# echo -n overlord | md5sum
_HTTP_BOUNDARY_MAGIC = '9246f080c855a69012707ab53489b921'
# Terminal resize control: a JSON control message sent over the terminal
# WebSocket is framed by these two characters.
_CONTROL_START = 128
_CONTROL_END = 129
# Stream control: sent (doubled) by ShellWebSocketClient once local stdin
# reaches EOF.
_STDIN_CLOSED = '##STDIN_CLOSED##'
# Path prefix of SSH control sockets; the SSH host name is appended.
_SSH_CONTROL_SOCKET_PREFIX = os.path.join(tempfile.gettempdir(),
                                          'ovl-ssh-control-')
# Printed, with the failure reason substituted, when TLS verification of the
# server fails during `ovl connect`.
_TLS_CERT_FAILED_WARNING = """
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@ WARNING: REMOTE HOST VERIFICATION HAS FAILED! @
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
Failed Reason: %s.
Please use -c option to specify path of root CA certificate.
This root CA certificate should be the one that signed the certificate used by
overlord server."""
def GetVersionDigest():
  """Compute the SHA-1 digest that identifies this build of ovl.

  A PyInstaller-frozen binary hashes its own executable; otherwise the Python
  source file of this module is hashed.

  Returns:
    The lowercase hexadecimal SHA-1 digest of that file.
  """
  # See: https://pyinstaller.readthedocs.io/en/stable/runtime-information.html
  is_frozen = getattr(sys, 'frozen', False)
  target = sys.executable if is_frozen else __file__
  digest = hashlib.sha1()
  with open(target, 'rb') as stream:
    digest.update(stream.read())
  return digest.hexdigest()
def GetTLSCertPath(host):
  """Return the path of the TLS certificate saved for `host` under _CERT_DIR."""
  cert_name = '{}.cert'.format(host)
  return os.path.join(_CERT_DIR, cert_name)
def UrlOpen(state, url):
  """Open `url` on the Overlord server described by `state`.

  Wrapper for urllib.request.urlopen that:
    * picks http:// or https:// according to `state.ssl`,
    * adds an HTTP basic auth header when both `state.username` and
      `state.password` are set, and
    * uses `state.ssl_context` as the SSL context.

  Args:
    state: connection state exposing ssl, username, password and ssl_context
      (e.g. a DaemonState).
    url: address without scheme, e.g. 'host:port/path'.

  Returns:
    The response object returned by urllib.request.urlopen.
  """
  req = urllib.request.Request(MakeRequestUrl(state, url))
  has_credentials = state.username is not None and state.password is not None
  if has_credentials:
    header_name, header_value = BasicAuthHeader(state.username, state.password)
    req.add_header(header_name, header_value)
  return urllib.request.urlopen(req, timeout=_DEFAULT_HTTP_TIMEOUT,
                                context=state.ssl_context)
def KillGraceful(pid, wait_secs=1):
  """Terminate process `pid`, escalating from SIGTERM to SIGKILL.

  SIGTERM is sent first. If that succeeds, the function waits `wait_secs`
  seconds and then sends SIGKILL to make sure the process is gone. Any
  OSError from os.kill (e.g. the process no longer exists) is ignored.
  """
  try:
    os.kill(pid, signal.SIGTERM)
  except OSError:
    return
  time.sleep(wait_secs)
  try:
    os.kill(pid, signal.SIGKILL)
  except OSError:
    pass
def AutoRetry(action_name, retries):
  """Decorator for retry function call.

  The wrapped function is called up to `retries` times until one call
  returns without raising. Each failure prints an error naming the first
  positional argument (the function name if there is none) and the exception.
  If every attempt fails, a final error is printed and None is returned; the
  exception is not re-raised.

  Args:
    action_name: verb used in the final failure message, e.g. 'push'.
    retries: maximum number of attempts.

  Returns:
    A decorator. The decorated function returns the result of the first
    successful attempt, or None if all attempts failed.
  """
  def Wrap(func):
    @functools.wraps(func)
    def Loop(*args, **kwargs):
      # Label for messages; tolerate calls without positional arguments.
      target = args[0] if args else func.__name__
      for unused_i in range(retries):
        try:
          return func(*args, **kwargs)
        except Exception as e:
          print('error: %s: %s: retrying ...' % (target, e))
      print('error: failed to %s %s' % (action_name, target))
      return None
    return Loop
  return Wrap
def BasicAuthHeader(user, password):
  """Build an HTTP basic authentication header.

  Args:
    user: user name (str).
    password: password (str).

  Returns:
    A (header_name, header_value) tuple, e.g.
    ('Authorization', 'Basic dXNlcjpwYXNz') for 'user' / 'pass'.
  """
  raw_credential = b':'.join([user.encode('utf-8'), password.encode('utf-8')])
  token = base64.b64encode(raw_credential).decode('utf-8')
  return ('Authorization', 'Basic ' + token)
def GetTerminalSize():
  """Retrieve terminal window size.

  Returns:
    A (lines, columns) tuple of ints for the terminal attached to stdin, or
    (0, 0) when stdin is not a terminal (e.g. redirected input), so callers
    can fall back to a default size instead of crashing.
  """
  winsize = struct.pack('HHHH', 0, 0, 0, 0)
  try:
    winsize = fcntl.ioctl(0, termios.TIOCGWINSZ, winsize)
  except OSError:
    # ENOTTY / EBADF: no terminal to query. ProgressBar treats a width of 0
    # as "use _DEFAULT_TERMINAL_WIDTH".
    return 0, 0
  lines, columns, unused_x, unused_y = struct.unpack('HHHH', winsize)
  return lines, columns
def MakeRequestUrl(state, url):
  """Prefix `url` ('host:port/path') with https:// or http:// per state.ssl."""
  scheme = 'https' if state.ssl else 'http'
  return '{}://{}'.format(scheme, url)
class ProgressBar:
  """One-line console progress bar for file transfers.

  Each update rewrites the current line (it ends with a carriage return) with
  the possibly truncated name, bytes transferred, average speed, elapsed time
  and, when the terminal is wide enough, a bar followed by a percentage.
  """
  SIZE_WIDTH = 11
  SPEED_WIDTH = 10
  DURATION_WIDTH = 6
  PERCENTAGE_WIDTH = 8
  # Unit labels for successive powers of 1024.
  _SIZE_UNITS = ('B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB')
  _SPEED_UNITS = ('B', 'K', 'M', 'G', 'T', 'P')

  def __init__(self, name):
    self._start_time = time.time()
    self._name = name
    self._size = 0
    self._width = 0
    self._name_width = 0
    self._name_max = 0
    self._stat_width = 0
    self._max = 0
    self._CalculateSize()
    self.SetProgress(0)

  def _CalculateSize(self):
    """Recompute the column layout from the current terminal width."""
    self._width = GetTerminalSize()[1] or _DEFAULT_TERMINAL_WIDTH
    # 30% of the line is reserved for the name.
    self._name_width = int(self._width * 0.3)
    self._name_max = self._name_width
    self._stat_width = self.SIZE_WIDTH + self.SPEED_WIDTH + self.DURATION_WIDTH
    # Whatever is left is the width of the bar itself.
    self._max = (self._width - self._name_width - self._stat_width -
                 self.PERCENTAGE_WIDTH)

  @staticmethod
  def _Scale(value, units):
    """Return (scaled value, unit) for the largest power of 1024 <= value.

    Values below 1024 are returned unchanged with units[0]; values beyond the
    last unit are expressed in the last unit.
    """
    exponent = 0
    while exponent < len(units) - 1 and value >= 1024 ** (exponent + 1):
      exponent += 1
    if exponent:
      value = value / (1024 ** exponent)
    return value, units[exponent]

  def _SizeToHuman(self, size_in_bytes):
    """Format a byte count as an 11-character field, e.g. '    1.5 KiB'."""
    value, unit = self._Scale(size_in_bytes, self._SIZE_UNITS)
    return ' %6.1f %3s' % (value, unit)

  def _SpeedToHuman(self, speed_in_bs):
    """Format a speed in bytes/second, e.g. '    2.0K/s'."""
    value, unit = self._Scale(speed_in_bs, self._SPEED_UNITS)
    return ' %6.1f%s/s' % (value, unit)

  def _DurationToClock(self, duration):
    """Format seconds as ' MM:SS'."""
    return ' %02d:%02d' % (duration // 60, duration % 60)

  def SetProgress(self, percentage, size=None):
    """Redraw the progress line.

    Args:
      percentage: completion in percent; values outside 0-100 are clamped
        when drawing the bar.
      size: bytes transferred so far; the previous value is kept when None.
    """
    current_width = GetTerminalSize()[1] or _DEFAULT_TERMINAL_WIDTH
    if self._width != current_width:
      self._CalculateSize()
    if size is not None:
      self._size = size
    elapse_time = time.time() - self._start_time
    # A coarse clock may report no elapsed time right after construction.
    speed = self._size / elapse_time if elapse_time > 0 else 0
    size_str = self._SizeToHuman(self._size)
    speed_str = self._SpeedToHuman(speed)
    elapse_str = self._DurationToClock(elapse_time)
    # Clamp so that an out-of-range percentage cannot overflow the bar.
    width = min(max(int(self._max * percentage / 100.0), 0), self._max)
    sys.stdout.write(
        '%*s' % (- self._name_max,
                 self._name if len(self._name) <= self._name_max else
                 self._name[:self._name_max - 4] + ' ...') +
        size_str + speed_str + elapse_str +
        ((' [' + '#' * width + ' ' * (self._max - width) + ']' +
          '%4d%%' % int(percentage)) if self._max > 2 else '') + '\r')
    sys.stdout.flush()

  def End(self):
    """Draw the bar at 100% and move to a new line."""
    self.SetProgress(100.0)
    sys.stdout.write('\n')
    sys.stdout.flush()
class DaemonState:
  """DaemonState is used for storing Overlord state info.

  The ovl client daemon owns one instance and returns it to CLI invocations
  through its State() RPC.
  """
  def __init__(self):
    # Digest of the ovl build that created this state; the CLI compares it
    # with its own digest to detect an out-of-date daemon.
    self.version_sha1sum = GetVersionDigest()
    # Server address the daemon talks to ('localhost' when SSH forwarding).
    self.host = None
    self.port = None
    # Whether the server speaks TLS.
    self.ssl = False
    # True once a locally saved certificate was loaded for verifying the
    # server; WebSocket clients then use that certificate as their CA file.
    self.ssl_self_signed = False
    # Replaced by the context that passed verification during Connect.
    self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
    self.ssh = False
    # Host name given by the user, before SSH forwarding replaced it.
    self.orig_host = None
    # PID of the SSH tunnel master process, if any.
    self.ssh_pid = None
    # HTTP basic auth credentials.
    self.username = None
    self.password = None
    # mid of the client chosen with `ovl select`.
    self.selected_mid = None
    # Maps local port -> (mid, remote, forwarder pid).
    self.forwards = {}
    # Cached result of /api/agents/list and the time it was fetched.
    self.listing = []
    self.last_list = 0
class OverlordClientDaemon:
  """Overlord Client Daemon.

  Keeps the connection state (a DaemonState) in a background process and
  exposes it to `ovl` CLI invocations through a JSON-RPC server listening on
  _OVERLORD_CLIENT_DAEMON_RPC_ADDR.
  """
  def __init__(self):
    # Use full module path for jsonrpclib to resolve.
    import cros.factory.tools.ovl
    self._state = cros.factory.tools.ovl.DaemonState()
    self._server = None

  def Start(self):
    """Start serving RPC requests in a forked background process."""
    self.StartRPCServer()

  def StartRPCServer(self):
    """Create the RPC server, register its methods and fork a serving child.

    The child closes stdin/stdout/stderr and serves until it dies; the parent
    returns right after the fork.
    """
    self._server = SimpleJSONRPCServer(_OVERLORD_CLIENT_DAEMON_RPC_ADDR,
                                       logRequests=False)
    exports = [
        (self.State, 'State'),
        (self.Ping, 'Ping'),
        (self.GetPid, 'GetPid'),
        (self.Connect, 'Connect'),
        (self.Clients, 'Clients'),
        (self.SelectClient, 'SelectClient'),
        (self.AddForward, 'AddForward'),
        (self.RemoveForward, 'RemoveForward'),
        (self.RemoveAllForward, 'RemoveAllForward'),
    ]
    for func, name in exports:
      self._server.register_function(func, name)
    pid = os.fork()
    if pid == 0:
      for fd in range(3):
        os.close(fd)
      try:
        self._server.serve_forever()
      except BaseException:
        os._exit(1)
      # Never return into the caller's CLI code from the child process.
      os._exit(0)
    # Parent: the child owns the listening socket, so release this copy.
    self._server.server_close()

  @staticmethod
  def GetRPCServer():
    """Returns the Overlord client daemon RPC server, or None if it is down."""
    server = jsonrpclib.Server('http://%s:%d' %
                               _OVERLORD_CLIENT_DAEMON_RPC_ADDR)
    try:
      server.Ping()
    except Exception:
      return None
    return server

  def State(self):
    return self._state

  def Ping(self):
    return True

  def GetPid(self):
    return os.getpid()

  def _GetJSON(self, path):
    """GET `path` from the Overlord server and decode the JSON body."""
    url = '%s:%d%s' % (self._state.host, self._state.port, path)
    with UrlOpen(self._state, url) as response:
      return json.loads(response.read())

  def _TLSEnabled(self):
    """Determine if TLS is enabled on the server address in self._state.

    Returns:
      True if a TLS handshake succeeds, False if it fails at the TLS layer.

    Raises:
      socket.error: if the TCP connection is refused or times out.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
      # Allow any certificate since we only want to check if server talks TLS.
      context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
      context.verify_mode = ssl.CERT_NONE
      sock = context.wrap_socket(sock, server_hostname=self._state.host)
      sock.settimeout(_CONNECT_TIMEOUT)
      sock.connect((self._state.host, self._state.port))
      return True
    except ssl.SSLError:
      return False
    except socket.error:  # Connect refused or timeout
      raise
    except Exception:
      return False  # For whatever reason above failed, assume False
    finally:
      # The probe connection is never reused; always release it.
      sock.close()

  def _CheckTLSCertificate(self, check_hostname=True):
    """Check that the server's TLS certificate can be verified.

    The system CA store is tried first, then the certificate saved for the
    host (if any). The SSLContext that succeeds is stored in
    self._state.ssl_context.

    Args:
      check_hostname: whether the saved-certificate context verifies the
        server host name.

    Returns:
      True if either context verified the server, False otherwise.
    """
    def _DoConnect(context):
      sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
      try:
        sock.settimeout(_CONNECT_TIMEOUT)
        sock = context.wrap_socket(sock, server_hostname=self._state.host)
        sock.connect((self._state.host, self._state.port))
      except ssl.SSLError:
        return False
      finally:
        sock.close()
      # Save SSLContext for future use.
      self._state.ssl_context = context
      return True

    # First try connect with built-in certificates
    tls_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
    if _DoConnect(tls_context):
      return True
    # Try with already saved certificate, if any.
    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
    tls_context.verify_mode = ssl.CERT_REQUIRED
    tls_context.check_hostname = check_hostname
    tls_cert_path = GetTLSCertPath(self._state.host)
    if os.path.exists(tls_cert_path):
      tls_context.load_verify_locations(tls_cert_path)
      self._state.ssl_self_signed = True
    return _DoConnect(tls_context)

  def Connect(self, host, port=_OVERLORD_HTTP_PORT, ssh_pid=None,
              username=None, password=None, orig_host=None,
              check_hostname=True):
    """Connect the daemon to an Overlord server.

    Returns:
      True on success; 'SSLCertificateNotExisted' or 'SSLVerifyFailed' on TLS
      problems; ('HTTPError', code, message, body) for HTTP errors; or the
      string form of any other exception.
    """
    self._state.username = username
    self._state.password = password
    self._state.host = host
    self._state.port = port
    self._state.ssl = False
    self._state.ssl_self_signed = False
    self._state.orig_host = orig_host
    self._state.ssh_pid = ssh_pid
    self._state.selected_mid = None
    tls_enabled = self._TLSEnabled()
    if tls_enabled:
      if not os.path.exists(GetTLSCertPath(host)):
        return 'SSLCertificateNotExisted'
      if not self._CheckTLSCertificate(check_hostname):
        return 'SSLVerifyFailed'
    try:
      self._state.ssl = tls_enabled
      UrlOpen(self._state, '%s:%d' % (host, port)).close()
    except urllib.error.HTTPError as e:
      return ('HTTPError', e.getcode(), str(e),
              e.read().strip().decode('utf-8'))
    except Exception as e:
      return str(e)
    else:
      return True

  def Clients(self):
    """Return the server's client list, cached for _LIST_CACHE_TIMEOUT s."""
    if time.time() - self._state.last_list <= _LIST_CACHE_TIMEOUT:
      return self._state.listing
    self._state.listing = self._GetJSON('/api/agents/list')
    self._state.last_list = time.time()
    return self._state.listing

  def SelectClient(self, mid):
    self._state.selected_mid = mid

  def AddForward(self, mid, remote, local, pid):
    """Record a forwarder process `pid` bound to local port `local`."""
    self._state.forwards[local] = (mid, remote, pid)

  def RemoveForward(self, local_port):
    """Kill and forget the forwarder on `local_port`; unknown ports are ignored."""
    try:
      unused_mid, unused_remote, pid = self._state.forwards[local_port]
      KillGraceful(pid)
      del self._state.forwards[local_port]
    except (KeyError, OSError):
      pass

  def RemoveAllForward(self):
    """Kill every recorded forwarder process and clear the table."""
    for unused_mid, unused_remote, pid in self._state.forwards.values():
      try:
        KillGraceful(pid)
      except OSError:
        pass
    self._state.forwards = {}
class SSLEnabledWebSocketBaseClient(WebSocketBaseClient):
  """ws4py client that requires a verified server certificate for wss://.

  The CA file is the certificate saved for `state.host` when
  `state.ssl_self_signed` is set; otherwise it is OpenSSL's default CA file,
  or '/etc/ssl/certs/ca-certificates.crt' when that file does not exist.
  """
  def __init__(self, state, *args, **kwargs):
    system_ca = ssl.get_default_verify_paths().openssl_cafile
    # Some systems/distributions ship a Python that cannot locate the system
    # CA file; fall back to the common default location.
    if not os.path.exists(system_ca):
      system_ca = '/etc/ssl/certs/ca-certificates.crt'
    chosen_ca = GetTLSCertPath(state.host) if state.ssl_self_signed else system_ca
    # ws4py does not take an SSLContext; these options are forwarded to
    # ssl.wrap_socket instead.
    options = dict(cert_reqs=ssl.CERT_REQUIRED, ca_certs=chosen_ca)
    super().__init__(*args, ssl_options=options, **kwargs)
class TerminalWebSocketClient(SSLEnabledWebSocketBaseClient):
  """Interactive terminal session with an Overlord client over a WebSocket.

  Local stdin is switched to raw mode and forwarded one character at a time.
  Typing Enter, the escape character and '.' closes the session. Binary
  messages from the server are decoded as UTF-8 and written to stdout.
  """
  def __init__(self, state, mid, escape, *args, **kwargs):
    super(TerminalWebSocketClient, self).__init__(state, *args, **kwargs)
    self._mid = mid
    self._escape = escape
    self._stdin_fd = sys.stdin.fileno()
    self._old_termios = None
    # Nothing guarantees message boundaries fall on UTF-8 character
    # boundaries; an incremental decoder keeps split characters intact.
    self._decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')

  def handshake_ok(self):
    pass

  def opened(self):
    nonlocals = {'size': (80, 40)}

    def _ResizeWindow():
      size = GetTerminalSize()
      if size != nonlocals['size']:  # Only report actual size changes.
        control = {'command': 'resize', 'params': list(size)}
        payload = chr(_CONTROL_START) + json.dumps(control) + chr(_CONTROL_END)
        nonlocals['size'] = size
        try:
          self.send(payload, binary=True)
        except Exception:
          pass

    def _FeedInput():
      self._old_termios = termios.tcgetattr(self._stdin_fd)
      tty.setraw(self._stdin_fd)
      READY, ENTER_PRESSED, ESCAPE_PRESSED = range(3)
      try:
        state = READY
        while True:
          # Check if terminal is resized
          _ResizeWindow()
          ch = sys.stdin.read(1)
          # Scan for the escape sequence: Enter, escape char, '.'.
          if self._escape:
            if state == ESCAPE_PRESSED and ch == '.':
              self.close()
              break
            if state == ENTER_PRESSED and ch == self._escape:
              state = ESCAPE_PRESSED
            elif ch == chr(0x0d):
              # Any Enter (including repeated ones) arms the sequence.
              state = ENTER_PRESSED
            else:
              state = READY
          self.send(ch)
      except (KeyboardInterrupt, RuntimeError):
        pass

    t = threading.Thread(target=_FeedInput)
    t.daemon = True
    t.start()

  def closed(self, code, reason=None):
    del code, reason  # Unused.
    # Flush any bytes left from an incomplete UTF-8 sequence.
    tail = self._decoder.decode(b'', final=True)
    if tail:
      sys.stdout.write(tail)
      sys.stdout.flush()
    # The input thread may not have saved the terminal mode yet; only restore
    # it if it was changed.
    if self._old_termios is not None:
      termios.tcsetattr(self._stdin_fd, termios.TCSANOW, self._old_termios)
    print('Connection to %s closed.' % self._mid)

  def received_message(self, message):
    if message.is_binary:
      sys.stdout.write(self._decoder.decode(message.data))
      sys.stdout.flush()
class ShellWebSocketClient(SSLEnabledWebSocketBaseClient):
  """Runs a command through the shell WebSocket, relaying stdin and output.

  Local stdin is streamed byte by byte; EOF is signalled with _STDIN_CLOSED
  sent twice. Binary messages from the server are decoded as UTF-8 and
  written to `output`.
  """
  def __init__(self, state, output, *args, **kwargs):
    """Constructor.

    Args:
      state: connection state passed to SSLEnabledWebSocketBaseClient.
      output: output file object (text) receiving the command's output.
    """
    self.output = output
    # Nothing guarantees message boundaries fall on UTF-8 character
    # boundaries; an incremental decoder keeps split characters intact.
    self._decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
    super(ShellWebSocketClient, self).__init__(state, *args, **kwargs)

  def handshake_ok(self):
    pass

  def opened(self):
    def _FeedInput():
      try:
        while True:
          data = sys.stdin.buffer.read(1)
          if not data:
            self.send(_STDIN_CLOSED * 2)
            break
          self.send(data, binary=True)
      except (KeyboardInterrupt, RuntimeError):
        pass

    t = threading.Thread(target=_FeedInput)
    t.daemon = True
    t.start()

  def closed(self, code, reason=None):
    del code, reason  # Unused.
    # Emit anything still buffered from an incomplete UTF-8 sequence.
    tail = self._decoder.decode(b'', final=True)
    if tail:
      self.output.write(tail)
      self.output.flush()

  def received_message(self, message):
    if message.is_binary:
      self.output.write(self._decoder.decode(message.data))
      self.output.flush()
class ForwarderWebSocketClient(SSLEnabledWebSocketBaseClient):
  """Relays bytes between a local socket and a WebSocket (port forwarding).

  Data read from `sock` is sent as binary messages; binary messages from the
  server are written back to `sock`.
  """
  def __init__(self, state, sock, *args, **kwargs):
    super(ForwarderWebSocketClient, self).__init__(state, *args, **kwargs)
    self._sock = sock
    self._stop = threading.Event()

  def handshake_ok(self):
    pass

  def opened(self):
    def _FeedInput():
      # The socket is kept in blocking mode: select() below already ensures
      # recv() returns immediately, and received_message() uses sendall(),
      # which can fail with BlockingIOError on a non-blocking socket.
      try:
        while True:
          # Poll with a timeout so that _stop (set by closed()) is noticed.
          rd, unused_w, unused_x = select.select([self._sock], [], [], 0.5)
          if self._stop.is_set():
            break
          if self._sock in rd:
            data = self._sock.recv(_BUFSIZ)
            if not data:
              # The local peer closed its end; close the WebSocket as well.
              self.close()
              break
            self.send(data, binary=True)
      except Exception:
        # Best effort: any socket/WebSocket error simply ends forwarding.
        pass
      finally:
        self._sock.close()

    t = threading.Thread(target=_FeedInput)
    t.daemon = True
    t.start()

  def closed(self, code, reason=None):
    del code, reason  # Unused.
    self._stop.set()
    sys.exit(0)

  def received_message(self, message):
    if message.is_binary:
      # send() may write only part of the data; sendall() writes all of it.
      self._sock.sendall(message.data)
def Arg(*args, **kwargs):
  """Capture arguments for a later parser.add_argument() call.

  Returns:
    An (args, kwargs) tuple; used in the `args` list given to Command().
  """
  captured = (args, kwargs)
  return captured
def Command(command, help_msg=None, args=None):
  """Decorator for adding argparse parameter for a method.

  The decorated function is replaced by a transparent wrapper whose
  `__arg_attr` attribute holds {'command', 'help', 'args'}; the metadata is
  collected by ParseMethodSubCommands.

  Args:
    command: sub-command name.
    help_msg: help text of the sub-command.
    args: list of Arg(...) tuples for the sub-command's arguments.
  """
  arg_specs = [] if args is None else args

  def WrapFunc(func):
    @functools.wraps(func)
    def Wrapped(*call_args, **call_kwargs):
      return func(*call_args, **call_kwargs)
    # pylint: disable=protected-access
    setattr(Wrapped, '__arg_attr',
            {'command': command, 'help': help_msg, 'args': arg_specs})
    return Wrapped
  return WrapFunc
def ParseMethodSubCommands(cls):
  """Decorator for a class using the @Command decorator.

  Every attribute defined directly on `cls` that carries `__arg_attr`
  metadata (set by Command) has that metadata appended, in definition order,
  to the class-level SUBCOMMANDS list, which is later used to construct the
  parser.

  Returns:
    `cls` itself.
  """
  # pylint: disable=protected-access
  for member in vars(cls).values():
    if not hasattr(member, '__arg_attr'):
      continue
    cls.SUBCOMMANDS.append(getattr(member, '__arg_attr'))
  return cls
@ParseMethodSubCommands
class OverlordCLIClient:
"""Overlord command line interface client."""
SUBCOMMANDS = []
def __init__(self):
  """Create the CLI client; daemon proxy and state are attached later."""
  self._server = None
  self._state = None
  self._selected_mid = None
  self._escape = None
  self._parser = self._BuildParser()
def _BuildParser(self):
  """Create the argparse parser for the `ovl` command line.

  Global options go on the root parser. Each entry of self.SUBCOMMANDS
  becomes a sub-command, and its name is also stored in the parsed namespace
  as `which`.

  Returns:
    The root argparse.ArgumentParser.
  """
  parser = argparse.ArgumentParser(prog='ovl')
  subcommand_parsers = parser.add_subparsers(title='subcommands',
                                             dest='subcommand')
  subcommand_parsers.required = True
  global_options = [
      (('-s',), dict(dest='selected_mid', action='store', default=None,
                     help='select target to execute command on')),
      (('-S',), dict(dest='select_mid_before_action', action='store_true',
                     default=False,
                     help='select target before executing command')),
      (('-e',), dict(dest='escape', metavar='ESCAPE_CHAR', action='store',
                     default=_ESCAPE, type=str,
                     help='set shell escape character, \'none\' to '
                     'disable escape completely')),
  ]
  for flags, options in global_options:
    parser.add_argument(*flags, **options)
  for spec in self.SUBCOMMANDS:
    sub_parser = subcommand_parsers.add_parser(spec['command'],
                                               help=spec['help'])
    sub_parser.set_defaults(which=spec['command'])
    for arg_spec in spec['args']:
      sub_parser.add_argument(*arg_spec[0], **arg_spec[1])
  return parser
def Main(self):
  """Parse the command line and dispatch to the requested sub-command.

  Everything after `shell` is handed to Shell() unparsed. start-server and
  kill-server run without the daemon check; status and connect need only the
  daemon; every other sub-command also requires a live server connection.
  """
  # We want to pass the rest of arguments after shell command directly to the
  # function without parsing it.
  if 'shell' in sys.argv:
    args = self._parser.parse_args(sys.argv[1:sys.argv.index('shell') + 1])
  else:
    args = self._parser.parse_args()

  command = args.which
  self._selected_mid = args.selected_mid
  if args.escape and args.escape != 'none':
    self._escape = args.escape[0]

  daemon_commands = {'start-server': 'StartServer',
                     'kill-server': 'KillServer'}
  if command in daemon_commands:
    getattr(self, daemon_commands[command])()
    return

  self.CheckDaemon()
  if command == 'status':
    self.Status()
    return
  if command == 'connect':
    self.Connect(args)
    return

  # The following commands require connection to the server.
  self.CheckConnection()
  if args.select_mid_before_action:
    self.SelectClient(store=False)

  if command == 'shell':
    self.Shell(sys.argv[sys.argv.index('shell') + 1:])
    return
  handler_name = {
      'select': 'SelectClient',
      'ls': 'ListClients',
      'push': 'Push',
      'pull': 'Pull',
      'forward': 'Forward',
  }.get(command)
  if handler_name is not None:
    getattr(self, handler_name)(args)
def _HTTPPostFile(self, url, filename, progress=None, user=None, passwd=None):
  """Perform HTTP POST and upload file to Overlord.

  To minimize the external dependencies, we construct the multipart/form-data
  request by ourselves and stream the file body in _BUFSIZ-sized chunks.

  Args:
    url: request address without scheme, e.g. 'host:port/path?query'.
    filename: path of the local file to upload.
    progress: optional callback, called as progress(percentage, bytes_sent)
      after each chunk and as progress(100) once the body has been sent.
    user: optional HTTP basic auth user name.
    passwd: optional HTTP basic auth password.

  Returns:
    True if the server answered with HTTP status 200, False otherwise.
  """
  url = MakeRequestUrl(self._state, url)
  size = os.stat(filename).st_size
  boundary = '-----------%s' % _HTTP_BOUNDARY_MAGIC
  CRLF = '\r\n'
  parse = urllib.parse.urlparse(url)

  part_headers = [
      '--' + boundary,
      'Content-Disposition: form-data; name="file"; '
      'filename="%s"' % os.path.basename(filename),
      'Content-Type: application/octet-stream',
      '', ''
  ]
  # Encode up front: Content-Length counts bytes, and a non-ASCII file name
  # takes more bytes than characters in UTF-8.
  part_header = CRLF.join(part_headers).encode('utf-8')
  end_part = (CRLF + '--' + boundary + '--' + CRLF).encode('utf-8')
  content_length = len(part_header) + size + len(end_part)

  if parse.scheme == 'http':
    h = http.client.HTTPConnection(parse.netloc)
  else:
    h = http.client.HTTPSConnection(parse.netloc,
                                    context=self._state.ssl_context)
  try:
    post_path = url[url.index(parse.netloc) + len(parse.netloc):]
    h.putrequest('POST', post_path)
    h.putheader('Content-Length', content_length)
    h.putheader('Content-Type', 'multipart/form-data; boundary=%s' % boundary)
    if user and passwd:
      h.putheader(*BasicAuthHeader(user, passwd))
    h.endheaders()
    h.send(part_header)

    count = 0
    with open(filename, 'rb') as f:
      while True:
        data = f.read(_BUFSIZ)
        if not data:
          break
        count += len(data)
        if progress:
          # A file that was empty when stat()ed but grew would divide by 0.
          progress(count * 100 // size if size else 100, count)
        h.send(data)
    h.send(end_part)
    if progress:
      progress(100)
    if count != size:
      logging.warning('file changed during upload, upload may be truncated.')
    resp = h.getresponse()
    return resp.status == 200
  finally:
    h.close()
def CheckDaemon(self):
  """Attach to the ovl client daemon, starting or upgrading it when needed.

  The daemon is started if none is answering, and restarted when its version
  digest differs from this program's. Afterwards self._server is the RPC
  proxy and self._state is the running daemon's state.

  Raises:
    RuntimeError: if the daemon cannot be reached after (re)starting it.
  """
  self._server = OverlordClientDaemon.GetRPCServer()
  if self._server is None:
    print('* daemon not running, starting it now on port %d ... *' %
          _OVERLORD_CLIENT_DAEMON_PORT)
    self.StartServer()
    if self._server is None:
      raise RuntimeError('failed to start ovl daemon')
  self._state = self._server.State()

  sha1sum = GetVersionDigest()
  if sha1sum != self._state.version_sha1sum:
    print('ovl server is out of date. killing...')
    KillGraceful(self._server.GetPid())
    self.StartServer()
    if self._server is None:
      raise RuntimeError('failed to restart ovl daemon')
    # The new daemon has its own fresh state; drop the killed one's copy.
    self._state = self._server.State()
def GetSSHControlFile(self, host):
  """Return the SSH control socket path used for tunnels to `host`."""
  return ''.join([_SSH_CONTROL_SOCKET_PREFIX, host])
def SSHTunnel(self, user, host, port):
  """SSH forward the remote overlord server.

  The Overlord server may not have port 9000 open to the public network; in
  that case a background SSH master connection forwards localhost:9000 to
  port 9000 on the remote host.

  Args:
    user: SSH login name; when empty the SSH default is used.
    host: SSH server host name.
    port: SSH server port.

  Returns:
    PID of the SSH master process, parsed from `ssh -O check` stderr.

  Raises:
    RuntimeError: if no PID can be found, i.e. the tunnel is not up.
  """
  control_file = self.GetSSHControlFile(host)
  # Drop any stale control socket left by a previous tunnel.
  try:
    os.unlink(control_file)
  except Exception:
    pass

  destination = '%s%s' % ((user + '@') if user else '', host)
  tunnel_cmd = ['ssh', '-Nf', '-M', '-S', control_file,
                '-L', '9000:localhost:9000', '-p', str(port), destination]
  subprocess.Popen(tunnel_cmd).wait()

  check_cmd = ['ssh', '-S', control_file, '-O', 'check', host]
  checker = process_utils.Spawn(check_cmd, read_stderr=True,
                                ignore_stdout=True)
  match = re.search(r'pid=(\d+)', checker.stderr_data)
  if not match:
    raise RuntimeError('can not establish ssh connection')
  return int(match.group(1))
def CheckConnection(self):
  """Verify that the daemon is connected to a live Overlord server.

  Raises:
    RuntimeError: if no server is configured, the server does not answer,
      or the SSH tunnel process (when one is used) is gone.
  """
  if self._state.host is None:
    raise RuntimeError('not connected to any server, abort')
  try:
    self._server.Clients()
  except Exception as e:
    raise RuntimeError('remote server disconnected, abort') from e
  if self._state.ssh_pid is not None:
    # Signal 0 only checks that the process exists and may be signalled,
    # like `kill -0`, without spawning an external command.
    try:
      os.kill(self._state.ssh_pid, 0)
    except (OSError, OverflowError):
      raise RuntimeError('ssh tunnel disconnected, please re-connect') from None
def CheckClient(self):
  """Ensure a client is selected and still listed by the server.

  Falls back to the selection stored in the daemon when none was given on
  the command line.

  Raises:
    RuntimeError: if no client is selected or it is no longer listed.
  """
  if self._selected_mid is None:
    stored_mid = self._state.selected_mid
    if stored_mid is None:
      raise RuntimeError('No client is selected')
    self._selected_mid = stored_mid
  listed_mids = (entry['mid'] for entry in self._server.Clients())
  if self._selected_mid not in listed_mids:
    raise RuntimeError('client %s disappeared' % self._selected_mid)
def CheckOutput(self, command):
  """Run a shell command on the selected client and return its output.

  Args:
    command: shell command string executed through the agent shell endpoint.

  Returns:
    The text received from the shell WebSocket.
  """
  state = self._state
  headers = []
  if state.username is not None and state.password is not None:
    headers.append(BasicAuthHeader(state.username, state.password))
  endpoint = '%s://%s:%d/api/agent/shell/%s?command=%s' % (
      'wss' if state.ssl else 'ws', state.host, state.port,
      urllib.parse.quote(self._selected_mid), urllib.parse.quote(command))
  collected = StringIO()
  client = ShellWebSocketClient(state, collected, endpoint, headers=headers)
  client.connect()
  client.run()
  return collected.getvalue()
@Command('status', 'show Overlord connection status')
def Status(self):
  """Print the server connection and the currently selected client."""
  state = self._state
  if state.host is None:
    print('Not connected to any host.')
  elif state.ssh_pid is not None:
    print('Connected to %s with SSH tunneling.' % state.orig_host)
  else:
    print('Connected to %s:%d.' % (state.host, state.port))

  if self._selected_mid is None:
    self._selected_mid = state.selected_mid
  if self._selected_mid is None:
    print('No client is selected.')
  else:
    print('Client %s selected.' % self._selected_mid)
@Command('connect', 'connect to Overlord server', [
    Arg('host', metavar='HOST', type=str, default='localhost',
        help='Overlord hostname/IP'),
    Arg('port', metavar='PORT', type=int, default=_OVERLORD_HTTP_PORT,
        help='Overlord port'),
    Arg('-f', '--forward', dest='ssh_forward', default=False,
        action='store_true', help='connect with SSH forwarding to the host'),
    Arg('-p', '--ssh-port', dest='ssh_port', default=22, type=int,
        help='SSH server port for SSH forwarding'),
    Arg('-l', '--ssh-login', dest='ssh_login', default='', type=str,
        help='SSH server login name for SSH forwarding'),
    Arg('-u', '--user', dest='user', default=None, type=str,
        help='Overlord HTTP auth username'),
    Arg('-w', '--passwd', dest='passwd', default=None, type=str,
        help='Overlord HTTP auth password'),
    Arg('-c', '--root-CA', dest='cert', default=None, type=str,
        help='Path to root CA certificate, only assign at the first time'),
    Arg('-i', '--no-check-hostname', dest='check_hostname', default=True,
        action='store_false', help='Ignore SSL cert hostname check'),
    Arg('-b', '--certificate-dir', dest='certificate_dir', default=None,
        type=str, help='Path to overlord certificate directory')
])
def Connect(self, args):
  """Ask the daemon to connect to an Overlord server.

  Optionally loads the root CA and password from --certificate-dir, saves the
  root CA given with -c for the host, and sets up an SSH tunnel with -f. Up
  to three connection attempts are made; on HTTP 401 the user is prompted
  for whichever credentials were not supplied on the command line.

  Args:
    args: parsed `connect` sub-command arguments.
  """
  ssh_pid = None
  host = args.host
  orig_host = args.host
  if args.certificate_dir:
    # --certificate-dir supplies the root CA, the password and user 'ovl'.
    args.cert = os.path.join(args.certificate_dir, 'rootCA.pem')
    ovl_password_file = os.path.join(args.certificate_dir, 'ovl_password')
    with open(ovl_password_file, 'r') as f:
      args.passwd = f.read().strip()
    args.user = 'ovl'
  if args.cert and os.path.exists(args.cert):
    # Saved under the user-given host name.
    # NOTE(review): with -f the daemon connects to 'localhost' and looks up
    # the certificate saved for that name instead — confirm this is intended.
    os.makedirs(_CERT_DIR, exist_ok=True)
    shutil.copy(args.cert, os.path.join(_CERT_DIR, '%s.cert' % host))
  if args.ssh_forward:
    # Kill previous SSH tunnel
    self.KillSSHTunnel()
    ssh_pid = self.SSHTunnel(args.ssh_login, args.host, args.ssh_port)
    host = 'localhost'
  username_provided = args.user is not None
  password_provided = args.passwd is not None
  prompt = False
  for unused_i in range(3):  # pylint: disable=too-many-nested-blocks
    try:
      if prompt:
        if not username_provided:
          args.user = input('Username: ')
        if not password_provided:
          args.passwd = getpass.getpass('Password: ')
      ret = self._server.Connect(host, args.port, ssh_pid, args.user,
                                 args.passwd, orig_host,
                                 args.check_hostname)
      # The daemon's ('HTTPError', code, message, body) tuple arrives here
      # as a list over JSON-RPC.
      if isinstance(ret, list):
        if ret[0] == 'HTTPError':
          code, except_str, body = ret[1:]
          if code == 401:
            print('connect: %s' % body)
            prompt = True
            # Retry with prompting only if some credential can be asked for;
            # otherwise the supplied credentials are wrong, so give up.
            if not username_provided or not password_provided:
              continue
            break
          logging.error('%s; %s', except_str, body)
      if ret in ('SSLCertificateNotExisted', 'SSLVerifyFailed'):
        print(_TLS_CERT_FAILED_WARNING % ret)
        return
      if ret is not True:
        print('can not connect to %s: %s' % (host, ret))
      else:
        print('connection to %s:%d established.' % (host, args.port))
    except Exception as e:
      # Exceptions are logged and the attempt is retried.
      logging.exception(e)
    else:
      # Any attempt that finished without an exception (and did not
      # `continue`) ends the loop, whether it succeeded or not.
      break
@Command('start-server', 'start overlord CLI client server')
def StartServer(self):
  """Start the ovl client daemon unless one is already answering."""
  self._server = OverlordClientDaemon.GetRPCServer()
  if self._server is not None:
    return
  OverlordClientDaemon().Start()
  # Give the forked daemon a moment to start listening.
  time.sleep(1)
  self._server = OverlordClientDaemon.GetRPCServer()
  if self._server is not None:
    print('* daemon started successfully *\n')
@Command('kill-server', 'kill overlord CLI client server')
def KillServer(self):
  """Stop the ovl client daemon and the SSH tunnel recorded in its state."""
  daemon = OverlordClientDaemon.GetRPCServer()
  self._server = daemon
  if daemon is None:
    return
  self._state = daemon.State()
  # The tunnel PID lives in the daemon state, so stop it first.
  self.KillSSHTunnel()
  KillGraceful(daemon.GetPid())
def KillSSHTunnel(self):
  """Terminate the SSH tunnel process recorded in self._state, if any."""
  tunnel_pid = self._state.ssh_pid
  if tunnel_pid is None:
    return
  KillGraceful(tunnel_pid)
def _FilterClients(self, clients, prop_filters, mid=None):
  """Return the clients matching every property filter and, optionally, mid.

  Args:
    clients: list of client dicts, each with a 'mid' and a 'properties' dict.
    prop_filters: list of 'key=regex' strings. A client passes a filter when
      re.search(regex, properties[key]) matches; a missing key fails it.
      All filters must pass.
    mid: if given, an exact mid match is returned alone; otherwise every
      remaining client whose mid starts with `mid` is returned.

  Returns:
    The list of matching client dicts.

  Raises:
    ValueError: if a filter does not contain '='.
  """
  def _ClientPropertiesMatch(client, key, regex):
    try:
      return bool(re.search(regex, client['properties'][key]))
    except KeyError:
      return False

  for prop_filter in prop_filters:
    key, sep, regex = prop_filter.partition('=')
    if not sep:
      # The filter doesn't contain '='.
      raise ValueError('Invalid filter condition %r' % prop_filter)
    clients = [c for c in clients if _ClientPropertiesMatch(c, key, regex)]

  if mid is not None:
    client = next((c for c in clients if c['mid'] == mid), None)
    if client:
      return [client]
    clients = [c for c in clients if c['mid'].startswith(mid)]
  return clients
@Command('ls', 'list clients', [
    Arg(
        '-f', '--filter', default=[], dest='filters', action='append',
        help=('Conditions to filter clients by properties. '
              'Should be in form "key=regex", where regex is the regular '
              'expression that should be found in the value. '
              'Multiple --filter arguments would be ANDed.')),
    Arg('-m', '--mid-only', default=False, action='store_true',
        help='Print mid only.'),
    Arg('-v', '--verbose', default=False, action='store_true',
        help='Print properties of each client.')
])
def ListClients(self, args):
  """Print the clients that pass the --filter conditions.

  With -v each client is dumped as YAML; with -m only mids are printed;
  otherwise a table of selected columns is shown.
  """
  clients = self._FilterClients(self._server.Clients(), args.filters)

  if args.verbose:
    for client in clients:
      print(yaml.safe_dump(client, default_flow_style=False))
    return

  # Used in station_setup to check if there is duplicate mid.
  if args.mid_only:
    for client in clients:
      print(client['mid'])
    return

  columns = ['mid', 'serial', 'status', 'pytest', 'model', 'ip',
             'track_connection']
  widths = {
      column: max([len(column)] +
                  [len(str(client[column])) for client in clients])
      for column in columns
  }

  def _PrintCell(column, text):
    # Right-aligned, two characters wider than the widest entry.
    print('%*s' % (widths[column] + 2, text), end='|')

  for column in columns:
    _PrintCell(column, column)
  print()
  for client in clients:
    for column in columns:
      _PrintCell(column, str(client[column]))
    print()
@Command('select', 'select default client', [
    Arg('-f', '--filter', default=[], dest='filters', action='append',
        help=('Conditions to filter clients by properties. '
              'Should be in form "key=regex", where regex is the regular '
              'expression that should be found in the value. '
              'Multiple --filter arguments would be ANDed.')),
    Arg('mid', metavar='mid', nargs='?', default=None)])
def SelectClient(self, args=None, store=True):
  """Select the client that later commands operate on.

  When several clients match, the user picks one interactively (1-based).

  Args:
    args: parsed arguments providing `mid` and `filters`, or None to choose
      among all clients.
    store: also record the selection in the daemon for later invocations.

  Raises:
    RuntimeError: if no client matches or the interactive choice is invalid.
  """
  mid = args.mid if args is not None else None
  filters = args.filters if args is not None else []
  clients = self._FilterClients(self._server.Clients(), filters, mid=mid)

  if not clients:
    raise RuntimeError('select: client not found')
  if len(clients) == 1:
    mid = clients[0]['mid']
  else:
    # Several matches: no mid given, or `mid` is a prefix of several mids.
    print('Select from the following clients:')
    for i, client in enumerate(clients):
      print(' %d. %s' % (i + 1, client['mid']))

    print('\nSelection: ', end='')
    try:
      choice = int(input()) - 1
    except ValueError:
      raise RuntimeError('select: invalid selection') from None
    # Reject 0 and negative numbers, which would otherwise index from the end.
    if not 0 <= choice < len(clients):
      raise RuntimeError('select: selection out of range')
    mid = clients[choice]['mid']

  self._selected_mid = mid
  if store:
    self._server.SelectClient(mid)
    print('Client %s selected' % mid)
@Command('shell', 'open a shell or execute a shell command', [
    Arg('command', metavar='CMD', nargs='?', help='command to execute')])
def Shell(self, command=None):
  """Run a command on the selected client, or open an interactive terminal.

  Args:
    command: words of the command to run; None or empty opens an interactive
      terminal session instead.
  """
  self.CheckClient()
  state = self._state
  headers = []
  if state.username is not None and state.password is not None:
    headers.append(BasicAuthHeader(state.username, state.password))
  base_url = '%s://%s:%d' % ('wss' if state.ssl else 'ws', state.host,
                             state.port)
  quoted_mid = urllib.parse.quote(self._selected_mid)

  if command:
    remote_cmd = urllib.parse.quote(' '.join(command))
    ws = ShellWebSocketClient(
        state, sys.stdout,
        '%s/api/agent/shell/%s?command=%s' % (base_url, quoted_mid,
                                              remote_cmd),
        headers=headers)
  else:
    ws = TerminalWebSocketClient(
        state, self._selected_mid, self._escape,
        '%s/api/agent/tty/%s' % (base_url, quoted_mid),
        headers=headers)

  try:
    ws.connect()
    ws.run()
  except socket.error as e:
    # A broken pipe (errno 32) is treated as a normal end of the session.
    if e.errno != 32:
      raise
@Command('push', 'push a file or directory to remote', [
    Arg('srcs', nargs='+', metavar='SOURCE'),
    Arg('dst', metavar='DESTINATION')])
def Push(self, args):
  """Copy local files or directories to the selected client.

  With several sources the destination must be an existing remote
  directory. A source directory is copied into DESTINATION/<dir name> when
  DESTINATION exists, otherwise DESTINATION takes the directory's place.
  Symbolic links are recreated on the remote side rather than followed.

  Raises:
    RuntimeError: if a source is missing or unreadable, or the destination
      is unsuitable for several sources.
  """
  self.CheckClient()

  @AutoRetry('push', _RETRY_TIMES)
  def _push(src, dst):
    src_base = os.path.basename(src)

    # Local file is a link: recreate the link remotely.
    if os.path.islink(src):
      pbar = ProgressBar(src_base)
      link_path = os.readlink(src)
      self.CheckOutput('mkdir -p %(dirname)s; '
                       'if [ -d "%(dst)s" ]; then '
                       'ln -sf "%(link_path)s" "%(dst)s/%(link_name)s"; '
                       'else ln -sf "%(link_path)s" "%(dst)s"; fi' %
                       dict(dirname=os.path.dirname(dst),
                            link_path=link_path, dst=dst,
                            link_name=src_base))
      pbar.End()
      return

    mode = '0%o' % (0x1FF & os.stat(src).st_mode)
    # Quote query values so paths containing '&', '#', '?', '%' or spaces
    # reach the server intact.
    url = ('%s:%d/api/agent/upload/%s?dest=%s&perm=%s' %
           (self._state.host, self._state.port,
            urllib.parse.quote(self._selected_mid),
            urllib.parse.quote(dst), mode))
    try:
      # Preflight request; on failure the server's JSON 'error' is reported.
      UrlOpen(self._state,
              url + '&filename=%s' % urllib.parse.quote(src_base)).close()
    except urllib.error.HTTPError as e:
      msg = json.loads(e.read()).get('error', None)
      raise RuntimeError('push: %s' % msg)

    pbar = ProgressBar(src_base)
    self._HTTPPostFile(url, src, pbar.SetProgress,
                       self._state.username, self._state.password)
    pbar.End()

  def _push_single_target(src, dst):
    if not os.path.isdir(src):
      _push(src, dst)
      return
    dst_exists = ast.literal_eval(self.CheckOutput(
        'stat %s >/dev/null 2>&1 && echo True || echo False' % dst))
    # For example, src_dir contains a single file 'A':
    #   push src_dir dest_dir
    # If dest_dir exists the result is dest_dir/src_dir/A, otherwise
    # dest_dir/A. Paths are computed relative to src so that absolute or
    # nested source paths never leak into the remote destination.
    src_name = os.path.basename(os.path.normpath(src))
    for root, unused_x, files in os.walk(src):
      rel_root = os.path.relpath(root, src)
      if rel_root == os.curdir:
        rel_root = ''
      dst_root = os.path.join(src_name, rel_root) if dst_exists else rel_root
      for name in files:
        _push(os.path.join(root, name),
              os.path.join(dst, dst_root, name))

  if len(args.srcs) > 1:
    dst_type = self.CheckOutput('stat \'%s\' --printf \'%%F\' '
                                '2>/dev/null' % args.dst).strip()
    if not dst_type:
      raise RuntimeError('push: %s: No such file or directory' % args.dst)
    if dst_type != 'directory':
      raise RuntimeError('push: %s: Not a directory' % args.dst)

  for src in args.srcs:
    if not os.path.exists(src):
      raise RuntimeError('push: can not stat "%s": no such file or directory'
                         % src)
    if not os.access(src, os.R_OK):
      raise RuntimeError('push: can not open "%s" for reading' % src)
    _push_single_target(src, args.dst)
@Command('pull', 'pull a file or directory from remote', [
    Arg('src', metavar='SOURCE'),
    Arg('dst', metavar='DESTINATION', default='.', nargs='?')])
def Pull(self, args):
  """Download a remote file, symlink or directory tree to the local host.

  Args:
    args: parsed arguments. ``args.src`` is the remote path (the listing
      command runs after ``cd $HOME``, so relative paths resolve there) and
      ``args.dst`` is the local destination, defaulting to '.'.

  Raises:
    RuntimeError: the server rejects a download.
  """
  self.CheckClient()

  def _error_message(e):
    """Return the server's 'error' field from an HTTPError body if possible."""
    try:
      return json.loads(e.read()).get('error', 'unknown error')
    except (ValueError, AttributeError):
      # Body was not a JSON object; fall back to the HTTP status text.
      return str(e)

  @AutoRetry('pull', _RETRY_TIMES)
  def _pull(src, dst, ftype, perm=0o644, link=None):
    """Fetch one remote entry to local path dst.

    Args:
      src: remote path.
      dst: local path to create.
      ftype: find's %y type letter; 'l' means symbolic link.
      perm: permission bits applied to the created file.
      link: symlink target when ftype is 'l'.
    """
    dst_dir = os.path.dirname(dst)
    if dst_dir:
      os.makedirs(dst_dir, exist_ok=True)
    src_base = os.path.basename(src)
    # Remote file is a link: recreate the link locally.
    if ftype == 'l':
      pbar = ProgressBar(src_base)
      # lexists() also matches a dangling symlink, which exists() reports as
      # absent and which would make os.symlink() fail.
      if os.path.lexists(dst):
        os.remove(dst)
      os.symlink(link, dst)
      pbar.End()
      return
    url = ('%s:%d/api/agent/download/%s?filename=%s' %
           (self._state.host, self._state.port,
            urllib.parse.quote(self._selected_mid), urllib.parse.quote(src)))
    try:
      h = UrlOpen(self._state, url)
    except urllib.error.HTTPError as e:
      raise RuntimeError('pull: %s' % _error_message(e)) from e
    except KeyboardInterrupt:
      return
    pbar = ProgressBar(src_base)
    with open(dst, 'wb') as f:
      os.fchmod(f.fileno(), perm)
      total_size = int(h.headers.get('Content-Length'))
      downloaded_size = 0
      while True:
        data = h.read(_BUFSIZ)
        if not data:
          break
        downloaded_size += len(data)
        pbar.SetProgress(downloaded_size * 100 / total_size,
                         downloaded_size)
        f.write(data)
    pbar.End()

  # 'stat' only verifies that SOURCE exists (its error text is detected
  # below). 'find' then lists every regular file and symlink under SOURCE,
  # one line per entry: octal mode, path, type letter, link target.
  output = self.CheckOutput(
      'cd $HOME; '
      'stat "%(src)s" >/dev/null && '
      'find "%(src)s" \'(\' -type f -o -type l \')\' '
      '-printf \'%%m\t%%p\t%%y\t%%l\n\''
      % {'src': args.src})

  # We got error from the stat command
  if output.startswith('stat: '):
    sys.stderr.write(output)
    return

  entries = []
  for line in output.split('\n'):
    if not line:
      continue
    fields = line.split('\t')
    if len(fields) != 4:
      # Not a listing line (e.g. a diagnostic such as a permission error);
      # report it instead of failing to unpack it.
      sys.stderr.write(line + '\n')
      continue
    entries.append(fields)

  # e.g. an empty remote directory: nothing to transfer.
  if not entries:
    return

  common_prefix = os.path.dirname(args.src)
  if len(entries) == 1:
    perm, src_path, ftype, link = entries[0]
    if os.path.isdir(args.dst):
      dst = os.path.join(args.dst, os.path.basename(src_path))
    else:
      dst = args.dst
    _pull(src_path, dst, ftype, int(perm, base=8), link)
    return

  # Several entries: a new local destination mirrors SOURCE's contents,
  # while an existing one receives SOURCE's directory as a subdirectory.
  if not os.path.exists(args.dst):
    common_prefix = args.src
  for perm, src_path, ftype, link in entries:
    rel_dst = src_path[len(common_prefix):].lstrip('/')
    _pull(src_path, os.path.join(args.dst, rel_dst), ftype,
          int(perm, base=8), link)
@Command('forward', 'forward remote port to local port', [
    Arg('--list', dest='list_all', action='store_true', default=False,
        help='list all port forwarding sessions'),
    Arg('--remove', metavar='LOCAL_PORT', dest='remove', type=int,
        default=None,
        help='remove port forwarding for local port LOCAL_PORT'),
    Arg('--remove-all', dest='remove_all', action='store_true',
        default=False, help='remove all port forwarding'),
    Arg('remote', metavar='REMOTE_PORT', type=int, nargs='?'),
    Arg('local', metavar='LOCAL_PORT', type=int, nargs='?')])
def Forward(self, args):
  """List, remove, or create local-to-remote port forwards.

  Creating a forward binds LOCAL_PORT (an unused port if omitted) on all
  interfaces, forks a child that accepts connections and relays each one
  over a websocket to REMOTE_PORT on the selected client, and registers the
  child's pid with the client daemon.

  Args:
    args: parsed arguments (see the Command declaration above).

  Raises:
    RuntimeError: REMOTE_PORT is missing when creating a forward.
  """
  if args.list_all:
    max_len = 10
    if self._state.forwards:
      max_len = max(len(v[0]) for v in self._state.forwards.values())
    print('%-*s %-8s %-8s' % (max_len, 'Client', 'Remote', 'Local'))
    for local in sorted(self._state.forwards.keys()):
      value = self._state.forwards[local]
      print('%-*s %-8s %-8s' % (max_len, value[0], value[1], local))
    return

  if args.remove_all:
    self._server.RemoveAllForward()
    return

  # Compare against None explicitly: the option defaults to None, and a
  # falsy port value must not fall through to creating a new forward.
  if args.remove is not None:
    self._server.RemoveForward(args.remove)
    return

  self.CheckClient()
  if args.remote is None:
    raise RuntimeError('remote port not specified')

  if args.local is None:
    args.local = net_utils.FindUnusedPort()
  remote = int(args.remote)
  local = int(args.local)

  def HandleConnection(conn):
    """Relay one accepted local connection to the remote port."""
    headers = []
    if self._state.username is not None and self._state.password is not None:
      headers.append(BasicAuthHeader(self._state.username,
                                     self._state.password))

    scheme = 'ws%s://' % ('s' if self._state.ssl else '')
    ws = ForwarderWebSocketClient(
        self._state, conn,
        scheme + '%s:%d/api/agent/forward/%s?port=%d' % (
            self._state.host, self._state.port,
            urllib.parse.quote(self._selected_mid), remote),
        headers=headers)
    try:
      ws.connect()
      ws.run()
    except Exception as e:
      print('error: %s' % e)
    finally:
      ws.close()

  server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  server.bind(('0.0.0.0', local))
  server.listen(5)

  pid = os.fork()
  if pid == 0:
    # Child: accept connections forever, one daemon thread per connection.
    while True:
      conn, unused_addr = server.accept()
      t = threading.Thread(target=HandleConnection, args=(conn,))
      t.daemon = True
      t.start()
  else:
    # Parent: only the child serves the port, so release our copy of the
    # listening socket.
    server.close()
    print('ovl_forward_port: http://localhost:%d' % local)
    self._server.AddForward(self._selected_mid, remote, local, pid)
def main():
  """Entry point of the ovl command-line tool.

  Configures root logging, registers DaemonState with jsonrpclib and runs the
  CLI. An unexpected exception is logged with its traceback and the process
  exits with status 1, so callers and scripts can detect the failure.
  """
  # Setup logging format
  logger = logging.getLogger()
  logger.setLevel(logging.DEBUG)
  handler = logging.StreamHandler()
  formatter = logging.Formatter('%(asctime)s %(message)s', '%Y/%m/%d %H:%M:%S')
  handler.setFormatter(formatter)
  logger.addHandler(handler)

  # Add DaemonState to JSONRPC lib classes
  config.DEFAULT.classes.add(DaemonState)

  ovl = OverlordCLIClient()
  try:
    ovl.Main()
  except KeyboardInterrupt:
    print('Ctrl-C received, abort')
  except Exception as e:
    logging.exception('exit with error [%s]', e)
    # Previously the process exited with status 0 even on failure.
    sys.exit(1)


if __name__ == '__main__':
  main()