diff --git a/.gitignore b/.gitignore index 13f8832..e03df32 100644 --- a/.gitignore +++ b/.gitignore @@ -5,5 +5,7 @@ .cache .coverage .tox +.envrc build dist +venv diff --git a/geofrontcli/cli.py b/geofrontcli/cli.py index 9bb6db9..6d70672 100644 --- a/geofrontcli/cli.py +++ b/geofrontcli/cli.py @@ -8,8 +8,10 @@ import logging import os import os.path +import pprint import subprocess import sys +import time import webbrowser from dirspec.basedir import load_config_paths, save_config_path @@ -21,6 +23,7 @@ NoTokenIdError, ProtocolVersionError, RemoteError, TokenIdError, UnfinishedAuthenticationError) from .key import PublicKey +from .utils import resolve_cmdarg_template from .version import VERSION @@ -54,6 +57,8 @@ version='%(prog)s ' + VERSION) subparsers = parser.add_subparsers() +logger = logging.getLogger('geofrontcli') + def get_server_url(): for path in load_config_paths(CONFIG_RESOURCE): @@ -120,13 +125,17 @@ def authenticate(args): """Authenticate to Geofront server.""" client = get_client() while True: - with client.authenticate() as url: - if args.open_browser: - print('Continue to authenticate in your web browser...') - webbrowser.open(url) - else: - print('Continue to authenticate in your web browser:') - print(url) + try: + with client.authenticate() as url: + if args.open_browser: + print('Continue to authenticate in your web browser...') + webbrowser.open(url) + else: + print('Continue to authenticate in your web browser:') + print(url) + except Exception as e: + # exception info is already provided in client.Client.authenticate() + return input('Press return to continue') try: client.identity @@ -205,18 +214,32 @@ def masterkey(args): def align_remote_list(remotes): - maxlength = max(map(len, remotes)) if remotes else 0 + if remotes: + maxlen_alias = max(map(len, remotes.keys())) + maxlen_user = max(map(lambda v: len(v['user']), remotes.values())) + maxlen_host = max(map(lambda v: len(v['host']), remotes.values())) + else: + maxlen_alias = 1 + 
maxlen_user = 1 + maxlen_host = 1 for alias, remote in sorted(remotes.items()): - if remote.endswith(':22'): - remote = remote[:-3] - yield '{0:{1}} {2}'.format(alias, maxlength, remote) + yield '{0:{1}} {2:{3}} @ {4:{5}} : {6}'.format( + alias, maxlen_alias, + remote['user'], maxlen_user, + remote['host'], maxlen_host, + remote['port']) @subparser def remotes(args): """List available remotes.""" client = get_client() - remotes = client.remotes + try: + remotes = client.remotes + except Exception: + # exception info is already provided in client.Client.remotes() + return + time.sleep(0.11) if args.alias: for alias in sorted(remotes): print(alias) @@ -233,6 +256,21 @@ def remotes(args): ) +@subparser +def remote(args): + """Get the information of a specific remote.""" + client = get_client() + remote = client.remote(args.remote) + time.sleep(0.11) + pprint.pprint(remote) + + +remote.add_argument( + 'remote', + help='the remote alias that you want to get information about' +) + + @subparser def authorize(args, alias=None): """Temporarily authorize you to access the given remote. @@ -261,30 +299,14 @@ def authorize(args, alias=None): ) -def get_ssh_options(remote): +def mangle_ssh_args(remote): """Translate the given ``remote`` to a corresponding :program:`ssh` - options. For example, it returns the following list for ``'user@host'``:: - - ['-l', 'user', 'host'] - - The remote can contain the port number or omit the user login as well - e.g. 
``'host:22'``:: - - ['-p', '22', 'host'] - - """ - remote_match = REMOTE_PATTERN.match(remote) - if not remote_match: - raise ValueError('invalid remote format: ' + str(remote)) - options = [] - user = remote_match.group('user') - if user: - options.extend(['-l', user]) - port = remote_match.group('port') - if port: - options.extend(['-p', port]) - options.append(remote_match.group('host')) - return options + arguments including the login name and the port number explicitly.""" + return [ + '-l', remote['user'], + '-p', str(remote['port']), + remote['host'], + ] @subparser @@ -295,15 +317,15 @@ def colonize(args): """ client = get_client() - remote = client.remotes.get(args.remote, args.remote) try: - options = get_ssh_options(remote) - except ValueError as e: - colonize.error(str(e)) + remote = client.remotes.get(args.remote, args.remote) + except: + # exception info is already provided in client.Client.remote() + return cmd = [args.ssh] if args.identity_file: cmd.extend(['-i', args.identity_file]) - cmd.extend(options) + cmd.extend(mangle_ssh_args(remote)) cmd.extend([ 'mkdir', '~/.ssh', '&>', '/dev/null', '||', 'true', ';', 'echo', repr(str(client.master_key)), @@ -322,62 +344,115 @@ def colonize(args): @subparser -def ssh(args, alias=None): +def ssh(args): """SSH to the remote through Geofront's temporary authorization.""" - remote = authorize.call(args, alias=alias) + if args.tunnel and sys.version_info < (3, 6): + logger.error('To use the SSH proxy, you need to run geofront-cli on ' + 'Python 3.6 or higher.', + extra={'user_waiting': False}) + return + remote_match = REMOTE_PATTERN.match(args.remote) + if not remote_match: + raise ValueError('invalid remote format: ' + str(args.remote)) + alias = remote_match.group('host') + user = remote_match.group('user') + # port from remote_match is ignored + client = get_client() try: - options = get_ssh_options(remote) - except ValueError as e: - ssh.error(str(e)) - subprocess.call([args.ssh] + options) + remote = 
client.remote(alias, quiet=True) + except Exception: + # exception info is already provided in client.Client.remote() + return + if user and user != remote['user']: + remote['user'] = user # override username + else: + remote = authorize.call(args, alias=alias) + template = [ + args.ssh, + '-l', '$user', + '-p', '$port', + ] + if args.identity: + template.extend(['-i', args.identity]) + if args.dynamic_port: + template.extend(['-D', args.dynamic_port]) + template.append('$host') + if args.tunnel: + client.ssh_proxy(template, remote, alias or args.remote) + else: + cmdargs = resolve_cmdarg_template(template, remote) + subprocess.call(cmdargs) ssh.add_argument('remote', help='the remote alias to ssh') +ssh.add_argument('-i', '--identity', + help='alternative SSH identity (private key)') +ssh.add_argument('-D', '--dynamic-port', + help='port number to use for dynamic TCP forwarding') +ssh.add_argument('-t', '--tunnel', action='store_true', default=False, + help='use SSH tunneling via HTTPS WebSockets to access' + 'servers inside remote private networks') def parse_scp_path(path, args): """Parse remote:path format.""" if ':' not in path: return None, path - alias, path = path.split(':', 1) - remote = authorize.call(args, alias=alias) - return remote, path + host, path = path.split(':', 1) + return host, path @subparser def scp(args): - options = [] - src_remote, src_path = parse_scp_path(args.source, args) - dst_remote, dst_path = parse_scp_path(args.destination, args) - if src_remote and dst_remote: + """SCP from/to the remote through Geofront's temporary authorization.""" + if args.tunnel and sys.version_info < (3, 6): + logger.error('To use the SSH proxy, you need to run geofront-cli on ' + 'Python 3.6 or higher.', + extra={'user_waiting': False}) + template = [args.scp] + src_host, src_path = parse_scp_path(args.source, args) + dst_host, dst_path = parse_scp_path(args.destination, args) + if src_host and dst_host: scp.error('source and destination cannot be both ' 
'remote paths at a time') - elif not (src_remote or dst_remote): + elif not (src_host or dst_host): scp.error('one of source and destination has to be a remote path') if args.ssh: - options.extend(['-S', args.ssh]) + template.extend(['-S', args.ssh]) if args.recursive: - options.append('-r') - remote = src_remote or dst_remote - remote_match = REMOTE_PATTERN.match(remote) - if not remote_match: - raise ValueError('invalid remote format: ' + str(remote)) - port = remote_match.group('port') - if port: - options.extend(['-P', port]) - host = remote_match.group('host') - user = remote_match.group('user') - if user: - host = user + '@' + host - if src_remote: - options.append(host + ':' + src_path) + template.append('-r') + if args.identity: + template.extend(['-i', args.identity]) + host = src_host or dst_host + host_match = REMOTE_PATTERN.match(host) + if not host_match: + raise ValueError('invalid remote format: ' + str(host)) + alias = host_match.group('host') + user = host_match.group('user') + # port from host_match is ignored + template.extend(['-P', '$port']) + if src_host: + template.append('$user@$host:' + src_path) else: - options.append(src_path) - if dst_remote: - options.append(host + ':' + dst_path) + template.append(src_path) + if dst_host: + template.append('$user@$host:' + dst_path) else: - options.append(dst_path) - subprocess.call([args.scp] + options) + template.append(dst_path) + client = get_client() + remote = client.remote(alias, quiet=True) + if user and user != remote_info['user']: + remote['user'] = user # override username + else: + remote = authorize.call(args, alias=alias) + if args.tunnel: + client.ssh_proxy(template, remote, alias) + else: + subprocess.call(resolve_cmdarg_template(template, { + 'host': remote['host'], + 'user': remote['user'], + 'port': remote['port'], + })) scp.add_argument( @@ -388,9 +463,14 @@ def scp(args): ) scp.add_argument( '-r', '-R', '--recursive', - action='store_true', + action='store_true', default=False, 
help='recursively copy entire directories' ) +scp.add_argument('-i', '--identity', + help='alternative SSH identity (private key)') +scp.add_argument('-t', '--tunnel', action='store_true', default=False, + help='use SSH tunneling via HTTPS WebSockets to access' + 'servers inside remote private networks') scp.add_argument('source', help='the source path to copy') scp.add_argument('destination', help='the destination path') @@ -399,7 +479,11 @@ def scp(args): def go(args): """Select a remote and SSH to it at once (in interactive way).""" client = get_client() - remotes = client.remotes + try: + remotes = client.remotes + except Exception: + # exception info is already provided in client.Client.remotes() + return chosen = iterfzf(align_remote_list(remotes)) if chosen is None: return @@ -485,6 +569,8 @@ def main(args=None): parser.exit('geofront-cli seems incompatible with the server.\n' 'Try `pip install --upgrade geofront-cli` command.\n' 'The server version is {0}.'.format(e.server_version)) + except KeyboardInterrupt: + parser.exit('Aborted.') else: parser.print_usage() diff --git a/geofrontcli/client.py b/geofrontcli/client.py index 95af5d2..e7fa50b 100644 --- a/geofrontcli/client.py +++ b/geofrontcli/client.py @@ -13,13 +13,15 @@ from keyring import get_password, set_password from six import string_types -from six.moves.urllib.error import HTTPError +from six.moves.urllib.error import HTTPError, URLError from six.moves.urllib.parse import urljoin from six.moves.urllib.request import OpenerDirector, Request, build_opener from .key import PublicKey from .ssl import create_urllib_https_handler from .version import MIN_PROTOCOL_VERSION, MAX_PROTOCOL_VERSION, VERSION +if sys.version_info >= (3, 6): # pragma: no cover + from .proxy import start_ssh_proxy __all__ = ('REMOTE_PATTERN', 'BufferedResponse', 'Client', 'ExpiredTokenIdError', @@ -82,8 +84,12 @@ def request(self, method, url, data=None, headers={}): try: response = self.opener.open(request) except HTTPError as 
e: - logger.exception(e) - response = e + logger.error('{0}: returned {1} {2}'.format(url, e.code, e.reason)) + raise + except URLError as e: + logger.error('{0}: errored {1}'.format(url, e.reason)) + logger.error('Maybe you are not connected to the Internet!') + raise server_version = response.headers.get('X-Geofront-Version') if server_version: try: @@ -167,7 +173,7 @@ def identity(self): @property def master_key(self): """(:class:`~.key.PublicKey`) The current master key.""" - path = ('tokens', self.token_id, 'masterkey') + path = ('masterkey',) headers = {'Accept': 'text/plain'} with self.request('GET', path, headers=headers) as r: if r.code == 200: @@ -192,16 +198,38 @@ def remotes(self): mimetype, _ = parse_mimetype(r.headers['Content-Type']) assert mimetype == 'application/json' result = json.loads(r.read().decode('utf-8')) - fmt = '{0[user]}@{0[host]}:{0[port]}'.format logger.info('Total %d remotes.', len(result), extra={'user_waiting': False}) - return dict((alias, fmt(remote)) - for alias, remote in result.items()) - except: + return result + except Exception: logger.info('Failed to fetch the list of remotes.', extra={'user_waiting': False}) raise + def remote(self, alias, quiet=False): + """(:class:`dict`) The remote information including user, host, and + port. 
+ + """ + logger = self.logger.getChild('remote') + if not quiet: + logger.info('Loading the remote information from the Geofront ' + 'server...', extra={'user_waiting': True}) + try: + path = ('tokens', self.token_id, 'remotes', alias) + with self.request('GET', path) as r: + assert r.code == 200 + mimetype, _ = parse_mimetype(r.headers['Content-Type']) + assert mimetype == 'application/json' + result = json.loads(r.read().decode('utf-8')) + if not quiet: + logger.info('Done.', extra={'user_waiting': False}) + except Exception: + logger.info('Failed to fetch the remote information.', + extra={'user_waiting': False}) + raise + return result['remote'] + def authorize(self, alias): """Temporarily authorize you to access the given remote ``alias``. A made authorization keeps alive in a minute, and then will be expired. @@ -235,7 +263,20 @@ def authorize(self, alias): logger.info('Access to %s has authorized! The access will be ' 'available only for a time.', alias, extra={'user_waiting': False}) - return '{0[user]}@{0[host]}:{0[port]}'.format(result['remote']) + return result['remote'] + + if sys.version_info >= (3, 6): # pragma: no cover + def ssh_proxy(self, cmd_template, remote, alias): + logger = self.logger.getChild('ssh_proxy') + try: + path = ('ws', 'tokens', self.token_id, 'remotes', alias, 'ssh') + url = './{0}/'.format('/'.join(path)) + url = urljoin(self.server_url, url) + except TokenIdError: + logger.info('Authentication is required.', + extra={'user_waiting': False}) + raise + start_ssh_proxy(cmd_template, url, remote) def __repr__(self): return '{0.__module__}.{0.__name__}({1!r})'.format( diff --git a/geofrontcli/proxy.py b/geofrontcli/proxy.py new file mode 100644 index 0000000..de24d11 --- /dev/null +++ b/geofrontcli/proxy.py @@ -0,0 +1,195 @@ +""":mod:`geofrontcli.proxy` --- Local SSH proxy over HTTPS/WebSocket +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +""" +import asyncio +import contextlib +import csv +import logging 
+import os +import pathlib +import signal +import socket +import sys +import traceback + +from aiohttp import ClientError, ClientSession, WSMsgType +from aiotools import actxmgr, start_server +from dirspec.basedir import load_config_paths, save_config_path + +from .utils import resolve_cmdarg_template +from .version import VERSION + +__all__ = ('start_ssh_proxy', ) + + +CONFIG_RESOURCE = 'geofront-cli' +PROXY_PORT_MAP_FILENAME = 'proxyports.csv' + +logger = logging.getLogger(__name__) + + +def get_unused_port(): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + with contextlib.closing(s): + s.bind(('localhost', 0)) + return s.getsockname()[1] + + +def load_proxy_port_map(): + data = dict() + for path in load_config_paths(CONFIG_RESOURCE): + path = pathlib.Path(path.decode()) / PROXY_PORT_MAP_FILENAME + if path.is_file(): + with open(path) as f: + for row in csv.reader(f): + data[row[0]] = int(row[1]) + return data + + +def save_proxy_port_map(data): + config_path = (pathlib.Path(save_config_path(CONFIG_RESOURCE).decode()) + / PROXY_PORT_MAP_FILENAME) + with open(config_path, 'w') as f: + writer = csv.writer(f) + for key, val in data.items(): + writer.writerow((key, val)) + logger.info(f'To modify port-host mapping, check out {config_path}.', + extra={'user_waiting': False}) + + +def get_port_for_remote(host): + data = load_proxy_port_map() + if host in data: + return data[host] + else: + port = get_unused_port() + data[host] = port + logger.info(f'Mapped port {port} with host {host}.', + extra={'user_waiting': False}) + save_proxy_port_map(data) + return port + + +async def pipe(cmd_tpl, url, remote): + """The main task that proxies the incoming SSH traffic via WebSockets.""" + loop = asyncio.get_event_loop() + headers = { + 'User-Agent': 'geofront-cli/{0} (Python-asyncio/{1})'.format( + VERSION, sys.version[:3] + ), + } + + async def handle_ssh_sock(ws, ssh_sock): + """A sub-task that proxies the outgoing SSH traffic via WebSocket.""" + while True: + try: 
+ data = await loop.sock_recv(ssh_sock, 4096) + except asyncio.CancelledError: + break + if not data: + break + ws.send_bytes(data) + + async def handle_subproc(cmd, pipe_task): + """Launch the local SSH agent and wait until it terminates.""" + try: + proc = await asyncio.create_subprocess_exec( + *cmd, + stdin=None, stdout=None, stderr=None, # inherit + close_fds=True, + ) + await proc.wait() + pipe_task.cancel() # signal to terminate + except: + logger.error('Unexpected error!', extra={'user_waiting': False}) + traceback.print_exc() + + local_sock = None + ssh_sock = None + ssh_reader_task = None + subproc_task = None + + logger.info(f"Making a local SSH proxy to {remote['host']}...", + extra={'user_waiting': True}) + + # TODO: response header version check? + session = ClientSession() + try: + sock_type = socket.SOCK_STREAM + if hasattr(socket, 'SOCK_NONBLOCK'): # only for Linux + sock_type |= socket.SOCK_NONBLOCK + local_sock = socket.socket(socket.AF_INET, sock_type) + local_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + bind_port = get_port_for_remote(f"{remote['host']}:{remote['port']}") + try: + local_sock.bind(('localhost', bind_port)) + except OSError: + logger.error(f'Cannot bind to port {bind_port}!', + extra={'user_waiting': False}) + return + local_sock.listen(1) + logger.info(f'Connecting to local SSH proxy at port {bind_port}...', + extra={'user_waiting': False}) + async with session.ws_connect(url, headers=headers) as ws: + cmdargs = resolve_cmdarg_template(cmd_tpl, { + 'host': 'localhost', + 'user': remote['user'], + 'port': str(bind_port), + }) + subproc_task = loop.create_task( + handle_subproc(cmdargs, asyncio.Task.current_task())) + await asyncio.sleep(0) # required! 
+ ssh_sock, _ = await loop.sock_accept(local_sock) + ssh_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + local_sock.close() # used only once + ssh_reader_task = loop.create_task(handle_ssh_sock(ws, ssh_sock)) + async for msg in ws: + if msg.type == WSMsgType.BINARY: + await loop.sock_sendall(ssh_sock, msg.data) + elif msg.type == WSMsgType.CLOSED: + break + elif msg.type == WSMsgType.ERROR: + logger.error('Server disconnected unexpectedly.', + extra={'user_waiting': False}) + break + except ClientError: + logger.error('Connection error!', extra={'user_waiting': False}) + raise + except asyncio.CancelledError: + pass + except: + logger.error('Unexpected error!', extra={'user_waiting': False}) + traceback.print_exc() + finally: + if subproc_task and not subproc_task.done(): + subproc_task.cancel() + await subproc_task + if ssh_reader_task and not ssh_reader_task.done(): + ssh_reader_task.cancel() + await ssh_reader_task + if ssh_sock: + ssh_sock.close() + session.close() + # inform the main that we finished + os.kill(0, signal.SIGINT) + + +@actxmgr +async def serve_proxy(loop, pidx, args): + """The initialization and shutdown routines for the local SSH proxy.""" + pipe_task = None + try: + pipe_task = loop.create_task(pipe(*args)) + yield + finally: + if pipe_task and not pipe_task.done(): + pipe_task.cancel() + await pipe_task + + +def start_ssh_proxy(cmd_tpl, url, remote): + start_server(serve_proxy, + args=(cmd_tpl, url, remote), + use_threading=True, + num_workers=1) diff --git a/geofrontcli/utils.py b/geofrontcli/utils.py new file mode 100644 index 0000000..74a15ec --- /dev/null +++ b/geofrontcli/utils.py @@ -0,0 +1,27 @@ +""":mod:`geofrontcli.utils` --- Utility functions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +""" +import re + +__all__ = ('resolve_cmdarg_template', ) + + +CMDARG_VAR_PATTERN = re.compile(r'\$(?P<name>[A-Za-z]\w*)') + + +def resolve_cmdarg_template(template, vars): + resolved = template[:] + + def resolve_var(matchobj): + name = 
matchobj.group('name') + return str(vars[name]) + + for idx, piece in enumerate(resolved): + if isinstance(piece, bytes): + continue + new_piece, num_replaced = CMDARG_VAR_PATTERN.subn(resolve_var, piece) + if num_replaced: + resolved[idx] = new_piece + + return resolved diff --git a/geofrontcli/version.py b/geofrontcli/version.py index eb0e158..dc80e44 100644 --- a/geofrontcli/version.py +++ b/geofrontcli/version.py @@ -6,16 +6,16 @@ #: (:class:`tuple`) The triple of version numbers e.g. ``(1, 2, 3)``. -VERSION_INFO = (0, 4, 1) +VERSION_INFO = (0, 5, 1) #: (:class:`str`) The version string e.g. ``'1.2.3'``. VERSION = '{0}.{1}.{2}'.format(*VERSION_INFO) #: (:class:`tuple`) The minimum compatible version of server protocol. -MIN_PROTOCOL_VERSION = (0, 2, 0) +MIN_PROTOCOL_VERSION = (0, 5, 0) #: (:class:`tuple`) The maximum compatible version of server protocol. -MAX_PROTOCOL_VERSION = (0, 4, 999) +MAX_PROTOCOL_VERSION = (0, 6, 999) if __name__ == '__main__': diff --git a/setup.py b/setup.py index 30c3e37..4c92772 100644 --- a/setup.py +++ b/setup.py @@ -32,6 +32,11 @@ def readme(): 'enum34', } +py36_or_higher_requires = { + 'aiohttp ~= 2.3.0', + 'aiotools ~= 0.5.0', +} + win32_requires = { 'pypiwin32', } @@ -39,6 +44,9 @@ def readme(): if sys.version_info < (3, 4): install_requires.update(below_py34_requires) +if sys.version_info >= (3, 6): + install_requires.update(py36_or_higher_requires) + if sys.platform == 'win32': install_requires.update(win32_requires) @@ -55,14 +63,16 @@ def readme(): maintainer_email='dev' '@' 'spoqa.com', license='GPLv3 or later', packages=find_packages(exclude=['tests']), - entry_points=''' - [console_scripts] - geofront-cli = geofrontcli.cli:main - gfg = geofrontcli.cli:main_go - ''', + entry_points={ + 'console_scripts': [ + 'geofront-cli = geofrontcli.cli:main', + 'gfg = geofrontcli.cli:main_go', + ], + }, install_requires=list(install_requires), extras_require={ ":python_version<'3.4'": list(below_py34_requires), + 
":python_version>='3.6'": list(py36_or_higher_requires), ":sys_platform=='win32'": list(win32_requires), }, classifiers=[