From 498093f2f6b404903578277224a912e131b7313b Mon Sep 17 00:00:00 2001 From: Taylor Buchanan Date: Sat, 18 Sep 2021 14:00:20 -0500 Subject: [PATCH] Add Windows native SSH support --- libagent/device/ui.py | 6 +- libagent/ssh/__init__.py | 110 +++++++++++++++------- libagent/win_server.py | 191 +++++++++++++++++++++++++++++++++++++++ setup.py | 1 + tox.ini | 1 + 5 files changed, 275 insertions(+), 34 deletions(-) create mode 100644 libagent/win_server.py diff --git a/libagent/device/ui.py b/libagent/device/ui.py index 00486262..ebc88995 100644 --- a/libagent/device/ui.py +++ b/libagent/device/ui.py @@ -78,7 +78,8 @@ def button_request(self, _code=None): def create_default_options_getter(): """Return current TTY and DISPLAY settings for GnuPG pinentry.""" options = [] - if sys.stdin.isatty(): # short-circuit calling `tty` + # Windows reports that it has a TTY but throws FileNotFoundError + if sys.platform != 'win32' and sys.stdin.isatty(): # short-circuit calling `tty` try: ttyname = subprocess.check_output(args=['tty']).strip() options.append(b'ttyname=' + ttyname) @@ -88,7 +89,8 @@ def create_default_options_getter(): display = os.environ.get('DISPLAY') if display is not None: options.append('display={}'.format(display).encode('ascii')) - else: + # Windows likely doesn't support this anyway + elif sys.platform != 'win32': log.warning('DISPLAY not defined') log.info('using %s for pinentry options', options) diff --git a/libagent/ssh/__init__.py b/libagent/ssh/__init__.py index 614404d1..3f265abe 100644 --- a/libagent/ssh/__init__.py +++ b/libagent/ssh/__init__.py @@ -4,24 +4,32 @@ import io import logging import os +import random import re import signal +import string import subprocess import sys import tempfile import threading import configargparse -import daemon +try: + # TODO: Not supported on Windows. Use daemoniker instead? + import daemon +except ImportError: + daemon = None import pkg_resources -from .. import device, formats, server, util +from .. import device, formats, server, util, win_server from . import client, protocol log = logging.getLogger(__name__) UNIX_SOCKET_TIMEOUT = 0.1 - +WIN_PIPE_TIMEOUT = 0.1 +DEFAULT_TIMEOUT = WIN_PIPE_TIMEOUT if sys.platform == 'win32' else UNIX_SOCKET_TIMEOUT +SOCK_TYPE = 'Windows named pipe' if sys.platform == 'win32' else 'UNIX domain socket' def ssh_args(conn): """Create SSH command for connecting specified server.""" @@ -35,7 +43,7 @@ def ssh_args(conn): if 'user' in identity: args += ['-l', identity['user']] - args += ['-o', 'IdentityFile={}'.format(pubkey_tempfile.name)] + args += ['-o', 'IdentityFile={}'.format(pubkey_tempfile)] args += ['-o', 'IdentitiesOnly=true'] return args + [identity['host']] @@ -83,14 +91,14 @@ def create_agent_parser(device_type): default=formats.CURVE_NIST256, help='specify ECDSA curve name: ' + curve_names) p.add_argument('--timeout', - default=UNIX_SOCKET_TIMEOUT, type=float, + default=DEFAULT_TIMEOUT, type=float, help='timeout for accepting SSH client connections') p.add_argument('--debug', default=False, action='store_true', help='log SSH protocol messages for debugging.') p.add_argument('--log-file', type=str, help='Path to the log file (to be written by the agent).') p.add_argument('--sock-path', type=str, - help='Path to the UNIX domain socket of the agent.') + help='Path to the ' + SOCK_TYPE + ' of the agent.') p.add_argument('--pin-entry-binary', type=str, default='pinentry', help='Path to PIN entry UI helper.') @@ -100,17 +108,20 @@ def create_agent_parser(device_type): help='Expire passphrase from cache after this duration.') g = p.add_mutually_exclusive_group() - g.add_argument('-d', '--daemonize', default=False, action='store_true', - help='Daemonize the agent and print its UNIX socket path') + if daemon: + g.add_argument('-d', '--daemonize', default=False, action='store_true', + help='Daemonize the agent and print its ' + SOCK_TYPE) g.add_argument('-f', '--foreground', default=False, action='store_true', - help='Run agent in foreground with specified UNIX socket path') + help='Run agent in foreground with specified ' + SOCK_TYPE) g.add_argument('-s', '--shell', default=False, action='store_true', help=('run ${SHELL} as subprocess under SSH agent, allowing ' 'regular SSH-based tools to be used in the shell')) g.add_argument('-c', '--connect', default=False, action='store_true', help='connect to specified host via SSH') - g.add_argument('--mosh', default=False, action='store_true', - help='connect to specified host via using Mosh') + # Windows doesn't have native mosh + if sys.platform != 'win32': + g.add_argument('--mosh', default=False, action='store_true', + help='connect to specified host via using Mosh') p.add_argument('identity', type=_to_unicode, default=None, help='proto://[user@]host[:port][/path]') @@ -119,18 +130,48 @@ def create_agent_parser(device_type): return p +def get_ssh_env(sock_path): + ssh_version = subprocess.check_output(['ssh', '-V'], + stderr=subprocess.STDOUT) + log.debug('local SSH version: %r', ssh_version) + return {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())} + + +# Windows doesn't support AF_UNIX yet +# https://bugs.python.org/issue33408 +@contextlib.contextmanager +def serve_win(handler, sock_path, timeout=WIN_PIPE_TIMEOUT): + """ + Start the ssh-agent server on a Windows named pipe. + """ + environ = get_ssh_env(sock_path) + device_mutex = threading.Lock() + quit_event = threading.Event() + handle_conn = functools.partial(win_server.handle_connection, + handler=handler, + mutex=device_mutex, + quit_event=quit_event) + kwargs = dict(pipe_name=sock_path, + handle_conn=handle_conn, + quit_event=quit_event, + timeout=timeout) + with server.spawn(win_server.server_thread, kwargs): + try: + yield environ + finally: + log.debug('closing server') + quit_event.set() + + @contextlib.contextmanager -def serve(handler, sock_path, timeout=UNIX_SOCKET_TIMEOUT): +def serve_unix(handler, sock_path, timeout=UNIX_SOCKET_TIMEOUT): """ Start the ssh-agent server on a UNIX-domain socket. If no connection is made during the specified timeout, retry until the context is over. """ - ssh_version = subprocess.check_output(['ssh', '-V'], - stderr=subprocess.STDOUT) - log.debug('local SSH version: %r', ssh_version) - environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())} + environ = get_ssh_env(sock_path) device_mutex = threading.Lock() with server.unix_domain_socket_server(sock_path) as sock: sock.settimeout(timeout) @@ -154,12 +195,15 @@ def run_server(conn, command, sock_path, debug, timeout): ret = 0 try: handler = protocol.Handler(conn=conn, debug=debug) - with serve(handler=handler, sock_path=sock_path, - timeout=timeout) as env: + serve_platform = serve_win if sys.platform == 'win32' else serve_unix + with serve_platform(handler=handler, sock_path=sock_path, timeout=timeout) as env: if command: ret = server.run_process(command=command, environ=env) else: - signal.pause() # wait for signal (e.g. SIGINT) + try: + signal.pause() # wait for signal (e.g. SIGINT) + except AttributeError: + sys.stdin.read() # Windows doesn't support signal.pause except KeyboardInterrupt: log.info('server stopped') return ret @@ -221,10 +265,9 @@ def public_keys_as_files(self): """Store public keys as temporary SSH identity files.""" if not self.public_keys_tempfiles: for pk in self.public_keys(): - f = tempfile.NamedTemporaryFile(prefix='trezor-ssh-pubkey-', mode='w') - f.write(pk) - f.flush() - self.public_keys_tempfiles.append(f) + with tempfile.NamedTemporaryFile(prefix='trezor-ssh-pubkey-', mode='w', delete=False, newline='') as f: + f.write(pk) + self.public_keys_tempfiles.append(f.name) return self.public_keys_tempfiles @@ -241,13 +284,16 @@ def _dummy_context(): def _get_sock_path(args): sock_path = args.sock_path - if not sock_path: - if args.foreground: - log.error('running in foreground mode requires specifying UNIX socket path') - sys.exit(1) - else: - sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-') - return sock_path + if sock_path: + return sock_path + elif args.foreground: + log.error('running in foreground mode requires specifying ' + SOCK_TYPE) + sys.exit(1) + elif sys.platform == 'win32': + suffix = random.choices(string.ascii_letters, k=10) + return '\\\\.\pipe\\trezor-ssh-agent-' + ''.join(suffix) + else: + return tempfile.mktemp(prefix='trezor-ssh-agent-') @handle_connection_error @@ -286,7 +332,7 @@ def main(device_type): command = ['ssh'] + ssh_args(conn) + args.command elif args.mosh: command = ['mosh'] + mosh_args(conn) + args.command - elif args.daemonize: + elif daemon and args.daemonize: out = 'SSH_AUTH_SOCK={0}; export SSH_AUTH_SOCK;\n'.format(sock_path) sys.stdout.write(out) sys.stdout.flush() @@ -300,7 +346,7 @@ def main(device_type): command = os.environ['SHELL'] sys.stdin.close() - if command or args.daemonize or args.foreground: + if command or (daemon and args.daemonize) or args.foreground: with context: return run_server(conn=conn, command=command, sock_path=sock_path, debug=args.debug, timeout=args.timeout) diff --git a/libagent/win_server.py b/libagent/win_server.py new file mode 100644 index 00000000..2bccffdd --- /dev/null +++ b/libagent/win_server.py @@ -0,0 +1,191 @@ +"""Windows named pipe server for ssh-agent implementation.""" +import logging +import pywintypes +import struct +import threading +import win32api +import win32event +import win32pipe +import win32file +import winerror + +from . import util + +log = logging.getLogger(__name__) + +PIPE_BUFFER_SIZE = 64 * 1024 + +# Make MemoryView look like a buffer to reuse util.recv +class MvBuffer: + def __init__(self, mv): + self.mv = mv + def read(self, n): + return self.mv[0:n] + +# Based loosely on https://docs.microsoft.com/en-us/windows/win32/ipc/multithreaded-pipe-server +class NamedPipe: + __frame_size_size = struct.calcsize('>L') + + def __close(handle): + """Closes a named pipe handle.""" + if handle == win32file.INVALID_HANDLE_VALUE: + return + win32file.FlushFileBuffers(handle) + win32pipe.DisconnectNamedPipe(handle) + win32api.CloseHandle(handle) + + def open(name): + """Opens a named pipe server for receiving connections.""" + handle = win32pipe.CreateNamedPipe( + name, + win32pipe.PIPE_ACCESS_DUPLEX | win32file.FILE_FLAG_OVERLAPPED, + win32pipe.PIPE_TYPE_MESSAGE | win32pipe.PIPE_READMODE_MESSAGE | win32pipe.PIPE_WAIT, + win32pipe.PIPE_UNLIMITED_INSTANCES, + PIPE_BUFFER_SIZE, + PIPE_BUFFER_SIZE, + 0, + None) # Default security attributes + + if handle == win32file.INVALID_HANDLE_VALUE: + log.error("CreateNamedPipe failed (%d)", win32api.GetLastError()) + return None + + try: + pending_io = False + overlapped = win32file.OVERLAPPED() + overlapped.hEvent = win32event.CreateEvent(None, True, True, None) + error_code = win32pipe.ConnectNamedPipe(handle, overlapped) + if error_code == winerror.ERROR_IO_PENDING: + pending_io = True + elif error_code != winerror.ERROR_PIPE_CONNECTED or not win32event.SetEvent(overlapped.hEvent): + log.error('ConnectNamedPipe failed (%d)', error_code) + return None + log.debug('waiting for connection on %s', name) + return NamedPipe(name, handle, overlapped, pending_io) + except: + NamedPipe.__close(handle) + raise + + def __init__(self, name, handle, overlapped, pending_io): + self.name = name + self.handle = handle + self.overlapped = overlapped + self.pending_io = pending_io + + def close(self): + """Close the named pipe.""" + NamedPipe.__close(self.handle) + + def connect(self, timeout): + """Connect to an SSH client with the specified timeout.""" + waitHandle = win32event.WaitForSingleObject( + self.overlapped.hEvent, + timeout) + if waitHandle == win32event.WAIT_TIMEOUT: + return False + if not self.pending_io: + return True + win32pipe.GetOverlappedResult( + self.handle, + self.overlapped, + False) + error_code = win32api.GetLastError() + if error_code == winerror.NO_ERROR: + return True + log.error('GetOverlappedResult failed (%d)', error_code) + return False + + def read_frame(self, quit_event): + """Read the request frame from the SSH client.""" + request_size = None + remaining = None + buf = MvBuffer(win32file.AllocateReadBuffer(PIPE_BUFFER_SIZE)) + while True: + if quit_event.is_set(): + return None + error_code, _ = win32file.ReadFile(self.handle, buf.mv, self.overlapped) + if error_code not in (winerror.NO_ERROR, winerror.ERROR_IO_PENDING, winerror.ERROR_MORE_DATA): + log.error('ReadFile failed (%d)', error_code) + return None + win32event.WaitForSingleObject(self.overlapped.hEvent, win32event.INFINITE) + chunk_size = win32pipe.GetOverlappedResult(self.handle, self.overlapped, False) + error_code = win32api.GetLastError() + if error_code != winerror.NO_ERROR: + log.error('GetOverlappedResult failed (%d)', error_code) + return None + if request_size: + remaining -= chunk_size + else: + request_size, = util.recv(buf, '>L') + remaining = request_size - (chunk_size - NamedPipe.__frame_size_size) + if remaining <= 0: + break + return util.recv(buf, request_size) + + def send(self, reply): + """Send the specified reply to the SSH client.""" + error_code, _ = win32file.WriteFile(self.handle, reply) + if error_code == winerror.NO_ERROR: + return True + log.error('WriteFile failed (%d)', error_code) + return False + + +def handle_connection(pipe, handler, mutex, quit_event): + """ + Handle a single connection using the specified protocol handler in a loop. + + Since this function may be called concurrently from server_thread, + the specified mutex is used to synchronize the device handling. + """ + log.debug('welcome agent') + + try: + while True: + if quit_event.is_set(): + return + msg = pipe.read_frame(quit_event) + if not msg: + return + with mutex: + reply = handler.handle(msg=msg) + if not pipe.send(reply): + return + except pywintypes.error as e: + # Surface errors that aren't related to the client disconnecting + if e.args[0] == winerror.ERROR_BROKEN_PIPE: + log.debug('goodbye agent') + else: + raise + except Exception as e: # pylint: disable=broad-except + log.warning('error: %s', e, exc_info=True) + finally: + pipe.close() + + +def server_thread(pipe_name, handle_conn, quit_event, timeout): + """Run a Windows server on the specified pipe.""" + log.debug('server thread started') + + while True: + if quit_event.is_set(): + break + # A new pipe instance is necessary for each client + pipe = NamedPipe.open(pipe_name) + if not pipe: + break + try: + # Poll for a new client connection + while True: + if quit_event.is_set(): + break + if pipe.connect(timeout * 1000): + # Handle connections from SSH concurrently. + threading.Thread(target=handle_conn, + kwargs=dict(pipe=pipe)).start() + break + except: + pipe.close() + raise + + log.debug('server thread stopped') diff --git a/setup.py b/setup.py index 97b184a4..8f8b5cb4 100755 --- a/setup.py +++ b/setup.py @@ -27,6 +27,7 @@ 'pymsgbox>=1.0.6', 'semver>=2.2', 'unidecode>=0.4.20', + 'pypiwin32' ], platforms=['POSIX'], classifiers=[ diff --git a/tox.ini b/tox.ini index 3d6cc1a0..f4473213 100644 --- a/tox.ini +++ b/tox.ini @@ -14,6 +14,7 @@ deps= semver pydocstyle isort<5 + pypiwin32 commands= pycodestyle libagent isort --skip-glob .tox -c -rc libagent