From d38c788785c745399402a9d2c448b19353abd863 Mon Sep 17 00:00:00 2001 From: SlugFiller <5435495+SlugFiller@users.noreply.github.com> Date: Fri, 22 Sep 2023 07:46:06 +0300 Subject: [PATCH] Port to trio async library --- libagent/age/__init__.py | 92 ++++---- libagent/age/client.py | 18 +- libagent/device/interface.py | 3 +- libagent/device/trezor.py | 1 - libagent/device/ui.py | 335 +++++++++++++++++++++------- libagent/gpg/__init__.py | 213 +++++++++--------- libagent/gpg/agent.py | 142 ++++++------ libagent/gpg/client.py | 26 +-- libagent/gpg/encode.py | 12 +- libagent/gpg/keyring.py | 124 +++++----- libagent/gpg/protocol.py | 6 +- libagent/gpg/tests/test_keyring.py | 39 ++-- libagent/gpg/tests/test_protocol.py | 7 +- libagent/server.py | 168 +++++++------- libagent/signify/__init__.py | 65 +++--- libagent/ssh/__init__.py | 199 ++++++++--------- libagent/ssh/client.py | 20 +- libagent/ssh/protocol.py | 18 +- libagent/ssh/tests/test_client.py | 48 ++-- libagent/ssh/tests/test_protocol.py | 74 +++--- libagent/tests/test_server.py | 132 ++++++----- libagent/tests/test_util.py | 55 ++--- libagent/util.py | 207 +++++++++++++++-- libagent/win_server.py | 334 ++++++++++++--------------- setup.py | 2 + tox.ini | 3 + 26 files changed, 1328 insertions(+), 1015 deletions(-) diff --git a/libagent/age/__init__.py b/libagent/age/__init__.py index dd2fbe66..9556c19b 100644 --- a/libagent/age/__init__.py +++ b/libagent/age/__init__.py @@ -16,6 +16,7 @@ import bech32 import pkg_resources +import trio from cryptography.exceptions import InvalidTag from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 @@ -37,22 +38,23 @@ def bech32_encode(prefix, data): return bech32.bech32_encode(prefix, bech32.convertbits(bytes(data), 8, 5)) -def run_pubkey(device_type, args): +async def run_pubkey(device_type, args): """Initialize hardware-based GnuPG identity.""" log.warning('This AGE tool is still in EXPERIMENTAL mode, ' 'so please note that the API and features may ' 'change without backwards compatibility!') - c = client.Client(device=device_type()) - pubkey = c.pubkey(identity=client.create_identity(args.identity), ecdh=True) - recipient = bech32_encode(prefix="age", data=pubkey) - print(f"# recipient: {recipient}") - print(f"# SLIP-0017: {args.identity}") - data = args.identity.encode() - encoded = bech32_encode(prefix="age-plugin-trezor-", data=data).upper() - decoded = bech32_decode(prefix="age-plugin-trezor-", encoded=encoded) - assert decoded.startswith(data) - print(encoded) + async with await device.ui.UI.create(device_type=device_type, config=vars(args)) as ui: + c = client.Client(ui=ui) + pubkey = await c.pubkey(identity=client.create_identity(args.identity), ecdh=True) + recipient = bech32_encode(prefix="age", data=pubkey) + print(f"# recipient: {recipient}") + print(f"# SLIP-0017: {args.identity}") + data = args.identity.encode() + encoded = bech32_encode(prefix="age-plugin-trezor-", data=data).upper() + decoded = bech32_decode(prefix="age-plugin-trezor-", encoded=encoded) + assert decoded.startswith(data) + print(encoded) def base64_decode(encoded: str) -> bytes: @@ -86,48 +88,48 @@ def decrypt(key, encrypted): return None -def run_decrypt(device_type, args): +async def run_decrypt(device_type, args): """Unlock hardware device (for future interaction).""" # pylint: disable=too-many-locals - c = client.Client(device=device_type()) + async with await device.ui.UI.create(device_type=device_type, config=vars(args)) as ui: + c = client.Client(ui=ui) - lines = (line.strip() for line in sys.stdin) # strip whitespace - lines = (line for line in lines if line) # skip empty lines + lines = (line.strip() for line in sys.stdin) # strip whitespace + lines = (line for line in lines if line) # skip empty lines - identities = [] - stanza_map = {} + identities = [] + stanza_map = {} - for line in lines: - log.debug("got %r", line) - if line == "-> done": - break + for line in lines: + log.debug("got %r", line) + if line == "-> done": + break - if line.startswith("-> add-identity "): - encoded = line.split(" ")[-1].lower() - data = bech32_decode("age-plugin-trezor-", encoded) - identity = client.create_identity(data.decode()) - identities.append(identity) + if line.startswith("-> add-identity "): + encoded = line.split(" ")[-1].lower() + data = bech32_decode("age-plugin-trezor-", encoded) + identity = client.create_identity(data.decode()) + identities.append(identity) - elif line.startswith("-> recipient-stanza "): - file_index, tag, *args = line.split(" ")[2:] - body = next(lines) - if tag != "X25519": - continue + elif line.startswith("-> recipient-stanza "): + file_index, tag, *args = line.split(" ")[2:] + body = next(lines) + if tag != "X25519": + continue - peer_pubkey = base64_decode(args[0]) - encrypted = base64_decode(body) - stanza_map.setdefault(file_index, []).append((peer_pubkey, encrypted)) + peer_pubkey = base64_decode(args[0]) + encrypted = base64_decode(body) + stanza_map.setdefault(file_index, []).append((peer_pubkey, encrypted)) - for file_index, stanzas in stanza_map.items(): - _handle_single_file(file_index, stanzas, identities, c) + for file_index, stanzas in stanza_map.items(): + await _handle_single_file(file_index, stanzas, identities, c, ui.get_device_name()) - sys.stdout.buffer.write('-> done\n\n'.encode()) - sys.stdout.flush() - sys.stdout.close() + sys.stdout.buffer.write('-> done\n\n'.encode()) + sys.stdout.flush() + sys.stdout.close() -def _handle_single_file(file_index, stanzas, identities, c): - d = c.device.__class__.__name__ +async def _handle_single_file(file_index, stanzas, identities, c, d): for peer_pubkey, encrypted in stanzas: for identity in identities: id_str = identity.to_string() @@ -135,7 +137,7 @@ def _handle_single_file(file_index, stanzas, identities, c): sys.stdout.buffer.write(f'-> msg\n{msg}\n'.encode()) sys.stdout.flush() - key = c.ecdh(identity=identity, peer_pubkey=peer_pubkey) + key = await c.ecdh(identity=identity, peer_pubkey=peer_pubkey) result = decrypt(key=key, encrypted=encrypted) if not result: continue @@ -167,13 +169,11 @@ def main(device_type): log.debug("starting age plugin: %s", args) - device_type.ui = device.ui.UI(device_type=device_type, config=vars(args)) - try: if args.identity: - run_pubkey(device_type=device_type, args=args) + trio.run(run_pubkey, device_type, args) elif args.age_plugin == 'identity-v1': - run_decrypt(device_type=device_type, args=args) + trio.run(run_decrypt, device_type, args) else: log.error("Unsupported state machine: %r", args.age_plugin) except Exception as e: # pylint: disable=broad-except diff --git a/libagent/age/client.py b/libagent/age/client.py index a3695915..12dfbdea 100644 --- a/libagent/age/client.py +++ b/libagent/age/client.py @@ -20,24 +20,24 @@ def create_identity(user_id): class Client: """Sign messages and get public keys from a hardware device.""" - def __init__(self, device): + def __init__(self, ui): """C-tor.""" - self.device = device + self.ui = ui - def pubkey(self, identity, ecdh=False): + async def pubkey(self, identity, ecdh=False): """Return public key as VerifyingKey object.""" - with self.device: - pubkey = bytes(self.device.pubkey(ecdh=ecdh, identity=identity)) + async with self.ui.device() as device: + pubkey = bytes(await device.pubkey(ecdh=ecdh, identity=identity)) assert len(pubkey) == 32 return pubkey - def ecdh(self, identity, peer_pubkey): + async def ecdh(self, identity, peer_pubkey): """Derive shared secret using ECDH from peer public key.""" log.info('please confirm AGE decryption on %s for "%s"...', - self.device, identity.to_string()) - with self.device: + self.ui.get_device_name(), identity.to_string()) + async with self.ui.device() as device: assert len(peer_pubkey) == 32 - result, self_pubkey = self.device.ecdh_with_pubkey( + result, self_pubkey = await device.ecdh_with_pubkey( pubkey=(b"\x40" + peer_pubkey), identity=identity) assert result[:1] == b"\x04" hkdf = HKDF( diff --git a/libagent/device/interface.py b/libagent/device/interface.py index a21aad77..54b2f0df 100644 --- a/libagent/device/interface.py +++ b/libagent/device/interface.py @@ -105,9 +105,10 @@ def get_curve_name(self, ecdh=False): class Device: """Abstract cryptographic hardware device interface.""" - def __init__(self): + def __init__(self, ui): """C-tor.""" self.conn = None + self.ui = ui def connect(self): """Connect to device, otherwise raise NotFoundError.""" diff --git a/libagent/device/trezor.py b/libagent/device/trezor.py index 65978b39..1abdfc35 100644 --- a/libagent/device/trezor.py +++ b/libagent/device/trezor.py @@ -26,7 +26,6 @@ def _defs(self): required_version = '>=1.4.0' - ui = None # can be overridden by device's users cached_session_id = None def _verify_version(self, connection): diff --git a/libagent/device/ui.py b/libagent/device/ui.py index 2cf0f130..6819a7bc 100644 --- a/libagent/device/ui.py +++ b/libagent/device/ui.py @@ -1,10 +1,16 @@ """UIs for PIN/passphrase entry.""" +import contextlib +import functools +import io import logging import os import subprocess import sys +import trio +import trio_util + from .. import util from ..gpg import keyring @@ -17,72 +23,218 @@ log = logging.getLogger(__name__) +class _UISync: + def __init__(self, ui): + self.ui = ui + + def get_pin(self, code=None): + return trio.from_thread.run(self.ui.get_pin, code) + + def get_passphrase(self, prompt='Passphrase:', available_on_device=False): + return trio.from_thread.run(self.ui.get_passphrase, prompt, available_on_device) + + def button_request(self, br=None): + return trio.from_thread.run(self.ui.button_request, br) + + +class _DeviceOnThread: + def __init__(self, runner, runner_immediate, proxy, button_scope): + self.runner = runner + self.runner_immediate = runner_immediate + self.proxy = proxy + self.button_scope = button_scope + + async def connect(self): + return await self.runner(self.proxy.connect) + + async def close(self): + return await self.runner(self.proxy.close) + + async def __aenter__(self): + async with self.button_scope(): # May request a pin unlock + await self.runner(self.proxy.__enter__) + return self + + async def __aexit__(self, *args): + # Try to close the device immediately + # If a device request is in progress, this will prevent the program from being stuck + return await self.runner_immediate(self.proxy.__exit__, *args) + + async def pubkey(self, identity, ecdh=False): + async with self.button_scope(): + return await self.runner(self.proxy.pubkey, identity, ecdh) + + async def sign(self, identity, blob): + async with self.button_scope(): + return await self.runner(self.proxy.sign, identity, blob) + + async def sign_with_pubkey(self, identity, blob): + async with self.button_scope(): + return await self.runner(self.proxy.sign_with_pubkey, identity, blob) + + async def ecdh(self, identity, pubkey): + async with self.button_scope(): + return await self.runner(self.proxy.ecdh, identity, pubkey) + + async def ecdh_with_pubkey(self, identity, pubkey): + async with self.button_scope(): + return await self.runner(self.proxy.ecdh_with_pubkey, identity, pubkey) + + def __str__(self): + return self.proxy.__str__() + + +# pylint: disable=too-many-instance-attributes class UI: """UI for PIN/passphrase entry (for TREZOR devices).""" - def __init__(self, device_type, config=None): + @classmethod + async def create(cls, device_type, config=None): + """Asynchronously create a UI object, fiilling in default options.""" + # by default, use GnuPG pinentry tool + default_pinentry = await keyring.get_pinentry_binary() + options_getter = await create_default_options_getter() + return cls(device_type, default_pinentry, options_getter, config) + + def __init__(self, device_type, default_pinentry, options_getter, config=None): """C-tor.""" - default_pinentry = keyring.get_pinentry_binary() # by default, use GnuPG pinentry tool + self.run_on_thread = None + self.run_command_on_thread = None + self.run_command_on_thread_immediate = None + self.quit_event = None + self.button_nursery = None if config is None: config = {} self.pin_entry_binary = config.get('pin_entry_binary', default_pinentry) self.passphrase_entry_binary = config.get('passphrase_entry_binary', default_pinentry) - self.options_getter = create_default_options_getter() - self.device_name = device_type.__name__ + self.options_getter = options_getter + self.device_lock = trio.Lock() + self.device_instance = None + self.device_type = device_type self.cached_passphrase_ack = util.ExpiringCache( seconds=float(config.get('cache_expiry_seconds', 'inf'))) - def get_pin(self, _code=None): + async def __aenter__(self): + """Start a thread for accepting device commands.""" + assert self.run_on_thread is None + self.device_instance = self.device_type(_UISync(self)) + self.run_on_thread = util.run_on_thread() + self.quit_event = trio.Event() + self.run_command_on_thread, self.run_command_on_thread_immediate = await ( + type(self.run_on_thread).__aenter__(self.run_on_thread)) + return self + + async def __aexit__(self, *args): + """Close the thread and wait for it to complete.""" + if self.quit_event is not None: + self.quit_event.set() + if self.run_on_thread is not None: + run_on_thread = self.run_on_thread + self.run_on_thread = None + self.device_instance = None + return await type(run_on_thread).__aexit__(run_on_thread, *args) + + async def get_pin(self, _code=None): """Ask the user for (scrambled) PIN.""" - description = ( - 'Use the numeric keypad to describe number positions.\n' - 'The layout is:\n' - ' 7 8 9\n' - ' 4 5 6\n' - ' 1 2 3') - return interact( - title='{} PIN'.format(self.device_name), - prompt='PIN:', - description=description, - binary=self.pin_entry_binary, - options=self.options_getter()) + assert self.quit_event is not None + async with trio_util.move_on_when(self.quit_event.wait): + description = ( + 'Use the numeric keypad to describe number positions.\n' + 'The layout is:\n' + ' 7 8 9\n' + ' 4 5 6\n' + ' 1 2 3') + return await interact( + title='{} PIN'.format(self.device_type.__name__), + prompt='PIN:', + description=description, + binary=self.pin_entry_binary, + options=self.options_getter()) + raise RuntimeError('UI scope exited') - def get_passphrase(self, prompt='Passphrase:', available_on_device=False): + async def get_passphrase(self, prompt='Passphrase:', available_on_device=False): """Ask the user for passphrase.""" - passphrase = None - if self.cached_passphrase_ack: - passphrase = self.cached_passphrase_ack.get() - if passphrase is None: - env_passphrase = os.environ.get("TREZOR_PASSPHRASE") - if env_passphrase is not None: - passphrase = env_passphrase - elif available_on_device: - passphrase = PASSPHRASE_ON_DEVICE - else: - passphrase = interact( - title='{} passphrase'.format(self.device_name), - prompt=prompt, - description=None, - binary=self.passphrase_entry_binary, - options=self.options_getter()) - if self.cached_passphrase_ack: - self.cached_passphrase_ack.set(passphrase) - return passphrase - - def button_request(self, _code=None): + assert self.quit_event is not None + async with trio_util.move_on_when(self.quit_event.wait): + passphrase = None + if self.cached_passphrase_ack: + passphrase = self.cached_passphrase_ack.get(prompt) + if passphrase is None: + env_passphrase = os.environ.get("TREZOR_PASSPHRASE") + if env_passphrase is not None: + passphrase = env_passphrase + elif available_on_device: + passphrase = PASSPHRASE_ON_DEVICE + else: + passphrase = await interact( + title='{} passphrase'.format(self.device_type.__name__), + prompt=prompt, + description=None, + binary=self.passphrase_entry_binary, + options=self.options_getter()) + if self.cached_passphrase_ack: + self.cached_passphrase_ack.set(prompt, passphrase) + return passphrase + raise RuntimeError('UI scope exited') + + async def button_request(self, _code=None): """Called by TrezorClient when device interaction is required.""" - # XXX: show notification to the user? + if self.button_nursery is None: + # We don't have a clear scope for the operation + # Better to show nothing than to show a window that would not automatically close + return + self.button_nursery.start_soon(self._button_request) + + async def _button_request(self, _code=None): + try: + await interact( + title='{} interact'.format(self.device_type.__name__), + prompt=None, + description='Please follow the instructions\n' + 'on your {} device\'s screen'.format(self.device_type.__name__), + binary=self.passphrase_entry_binary, + options=self.options_getter(), + is_message=True) + except Exception as e: # pylint: disable=broad-except + log.exception('Failed to show an interaction dialog: %s', e) + + def get_device_name(self): + """Human-readable representation.""" + return self.device_type.__name__ + @contextlib.asynccontextmanager + async def device(self): + """Acquire access to the device.""" + assert self.device_instance + async with self.device_lock: # Only allow one connection at a time + async with _DeviceOnThread(self.run_command_on_thread, + self.run_command_on_thread_immediate, + self.device_instance, + self._button_scope) as dot: + yield dot + self.button_nursery = None -def create_default_options_getter(): + @contextlib.asynccontextmanager + async def _button_scope(self): + async with trio.open_nursery() as nursery: + self.button_nursery = nursery + try: + yield + finally: + if self.button_nursery == nursery: + self.button_nursery = None + nursery.cancel_scope.cancel() + + +async def create_default_options_getter(): """Return current TTY and DISPLAY settings for GnuPG pinentry.""" options = [] # Windows reports that it has a TTY but throws FileNotFoundError if sys.platform != 'win32' and sys.stdin.isatty(): # short-circuit calling `tty` try: - ttyname = subprocess.check_output(args=['tty']).strip() + ttyname = (await trio.run_process(['tty'], capture_stdout=True)).stdout.strip() options.append(b'ttyname=' + ttyname) except subprocess.CalledProcessError as e: log.warning('no TTY found: %s', e) @@ -98,20 +250,27 @@ def create_default_options_getter(): return lambda: options -def write(p, line): - """Send and flush a single line to the subprocess' stdin.""" +async def write(p, line): + """Send a single line to the subprocess' stdin.""" log.debug('%s <- %r', p.args, line) - p.stdin.write(line) - p.stdin.flush() + await p.stdin.send_all(line) class UnexpectedError(Exception): """Unexpected response.""" -def expect(p, prefixes, confidential=False): +async def expect(p, prefixes, confidential=False): """Read a line and return it without required prefix.""" - resp = p.stdout.readline() + resp = io.BytesIO() + while True: + c = await p.stdout.receive_some(1) + if not c: + raise IOError('Program abruptly closed after receiving: ' + str(resp.getvalue())) + if c == b'\n': + break + resp.write(c) + resp = resp.getvalue() log.debug('%s -> %r', p.args, resp if not confidential else '********') for prefix in prefixes: if resp.startswith(prefix): @@ -119,41 +278,47 @@ def expect(p, prefixes, confidential=False): raise UnexpectedError(resp) -def interact(title, description, prompt, binary, options): +async def interact(title, description, prompt, binary, options, is_message=False): """Use GPG pinentry program to interact with the user.""" - args = [binary] - p = subprocess.Popen(args=args, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - env=os.environ) - p.args = args # TODO: remove after Python 2 deprecation. - expect(p, [b'OK']) - - title = util.assuan_serialize(title.encode('ascii')) - write(p, b'SETTITLE ' + title + b'\n') - expect(p, [b'OK']) - - if description: - description = util.assuan_serialize(description.encode('ascii')) - write(p, b'SETDESC ' + description + b'\n') - expect(p, [b'OK']) - - if prompt: - prompt = util.assuan_serialize(prompt.encode('ascii')) - write(p, b'SETPROMPT ' + prompt + b'\n') - expect(p, [b'OK']) - - log.debug('setting %d options', len(options)) - for opt in options: - write(p, b'OPTION ' + opt + b'\n') - expect(p, [b'OK', b'ERR']) - - write(p, b'GETPIN\n') - pin = expect(p, [b'OK', b'D '], confidential=True) - - p.communicate() # close stdin and wait for the process to exit - exit_code = p.wait() - if exit_code: - raise subprocess.CalledProcessError(exit_code, binary) - - return pin.decode('ascii').strip() + # pylint: disable=too-many-arguments + async with trio.open_nursery() as nursery: + p = await nursery.start(functools.partial(trio.run_process, [binary], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + env=os.environ)) + await expect(p, [b'OK']) + + title = util.assuan_serialize(title.encode('ascii')) + await write(p, b'SETTITLE ' + title + b'\n') + await expect(p, [b'OK']) + + if description: + description = util.assuan_serialize(description.encode('ascii')) + await write(p, b'SETDESC ' + description + b'\n') + await expect(p, [b'OK']) + + if prompt: + prompt = util.assuan_serialize(prompt.encode('ascii')) + await write(p, b'SETPROMPT ' + prompt + b'\n') + await expect(p, [b'OK']) + + log.debug('setting %d options', len(options)) + for opt in options: + await write(p, b'OPTION ' + opt + b'\n') + await expect(p, [b'OK', b'ERR']) + + if is_message: + await write(p, b'MESSAGE\n') + else: + await write(p, b'GETPIN\n') + pin = await expect(p, [b'OK', b'D '], confidential=True) + + # close stdin and wait for the process to exit + await p.stdin.aclose() + async for _ in p.stdout: + pass + exit_code = await p.wait() + if exit_code: + raise subprocess.CalledProcessError(exit_code, binary) + + return pin.decode('ascii').strip() diff --git a/libagent/gpg/__init__.py b/libagent/gpg/__init__.py index 6bad4f65..7a04e14a 100644 --- a/libagent/gpg/__init__.py +++ b/libagent/gpg/__init__.py @@ -18,8 +18,10 @@ import subprocess import sys +import trio + try: - # TODO: Not supported on Windows. Use daemoniker instead? + # Not supported on Windows. Should be manually installed as a service instead. import daemon except ImportError: daemon = None @@ -32,66 +34,67 @@ log = logging.getLogger(__name__) -def export_public_key(device_type, args): +async def export_public_key(device_type, args): """Generate a new pubkey for a new/existing GPG identity.""" log.warning('NOTE: in order to re-generate the exact same GPG key later, ' 'run this command with "--time=%d" commandline flag (to set ' 'the timestamp of the GPG key manually).', args.time) - c = client.Client(device=device_type()) - identity = client.create_identity(user_id=args.user_id, - curve_name=args.ecdsa_curve) - verifying_key = c.pubkey(identity=identity, ecdh=False) - decryption_key = c.pubkey(identity=identity, ecdh=True) - signer_func = functools.partial(c.sign, identity=identity) - fingerprints = [] - - if args.subkey: # add as subkey - log.info('adding %s GPG subkey for "%s" to existing key', - args.ecdsa_curve, args.user_id) - # subkey for signing - signing_key = protocol.PublicKey( - curve_name=args.ecdsa_curve, created=args.time, - verifying_key=verifying_key, ecdh=False) - fingerprints.append(util.hexlify(signing_key.fingerprint())) - # subkey for encryption - encryption_key = protocol.PublicKey( - curve_name=formats.get_ecdh_curve_name(args.ecdsa_curve), - created=args.time, verifying_key=decryption_key, ecdh=True) - fingerprints.append(util.hexlify(encryption_key.fingerprint())) - primary_bytes = keyring.export_public_key(args.user_id) - result = encode.create_subkey(primary_bytes=primary_bytes, - subkey=signing_key, - signer_func=signer_func) - result = encode.create_subkey(primary_bytes=result, - subkey=encryption_key, - signer_func=signer_func) - else: # add as primary - log.info('creating new %s GPG primary key for "%s"', - args.ecdsa_curve, args.user_id) - # primary key for signing - primary = protocol.PublicKey( - curve_name=args.ecdsa_curve, created=args.time, - verifying_key=verifying_key, ecdh=False) - fingerprints.append(util.hexlify(primary.fingerprint())) - # subkey for encryption - subkey = protocol.PublicKey( - curve_name=formats.get_ecdh_curve_name(args.ecdsa_curve), - created=args.time, verifying_key=decryption_key, ecdh=True) - fingerprints.append(util.hexlify(subkey.fingerprint())) - - result = encode.create_primary(user_id=args.user_id, - pubkey=primary, - signer_func=signer_func) - result = encode.create_subkey(primary_bytes=result, - subkey=subkey, - signer_func=signer_func) - - return (fingerprints, protocol.armor(result, 'PUBLIC KEY BLOCK')) - - -def verify_gpg_version(): + async with await device.ui.UI.create(device_type=device_type, config=vars(args)) as ui: + c = client.Client(ui=ui) + identity = client.create_identity(user_id=args.user_id, + curve_name=args.ecdsa_curve) + verifying_key = await c.pubkey(identity=identity, ecdh=False) + decryption_key = await c.pubkey(identity=identity, ecdh=True) + signer_func = functools.partial(c.sign, identity=identity) + fingerprints = [] + + if args.subkey: # add as subkey + log.info('adding %s GPG subkey for "%s" to existing key', + args.ecdsa_curve, args.user_id) + # subkey for signing + signing_key = protocol.PublicKey( + curve_name=args.ecdsa_curve, created=args.time, + verifying_key=verifying_key, ecdh=False) + fingerprints.append(util.hexlify(signing_key.fingerprint())) + # subkey for encryption + encryption_key = protocol.PublicKey( + curve_name=formats.get_ecdh_curve_name(args.ecdsa_curve), + created=args.time, verifying_key=decryption_key, ecdh=True) + fingerprints.append(util.hexlify(encryption_key.fingerprint())) + primary_bytes = await keyring.export_public_key(args.user_id) + result = await encode.create_subkey(primary_bytes=primary_bytes, + subkey=signing_key, + signer_func=signer_func) + result = await encode.create_subkey(primary_bytes=result, + subkey=encryption_key, + signer_func=signer_func) + else: # add as primary + log.info('creating new %s GPG primary key for "%s"', + args.ecdsa_curve, args.user_id) + # primary key for signing + primary = protocol.PublicKey( + curve_name=args.ecdsa_curve, created=args.time, + verifying_key=verifying_key, ecdh=False) + fingerprints.append(util.hexlify(primary.fingerprint())) + # subkey for encryption + subkey = protocol.PublicKey( + curve_name=formats.get_ecdh_curve_name(args.ecdsa_curve), + created=args.time, verifying_key=decryption_key, ecdh=True) + fingerprints.append(util.hexlify(subkey.fingerprint())) + + result = await encode.create_primary(user_id=args.user_id, + pubkey=primary, + signer_func=signer_func) + result = await encode.create_subkey(primary_bytes=result, + subkey=subkey, + signer_func=signer_func) + + return (fingerprints, protocol.armor(result, 'PUBLIC KEY BLOCK')) + + +async def verify_gpg_version(): """Make sure that the installed GnuPG is not too old.""" - existing_gpg = keyring.gpg_version().decode('ascii') + existing_gpg = (await keyring.gpg_version()).decode('ascii') required_gpg = '>=2.1.11' msg = 'Existing GnuPG has version "{}" ({} required)'.format(existing_gpg, required_gpg) @@ -121,14 +124,14 @@ def write_file(path, data): return f -def run_init(device_type, args): +async def run_init(device_type, args): """Initialize hardware-based GnuPG identity.""" util.setup_logging(verbosity=args.verbose) log.warning('This GPG tool is still in EXPERIMENTAL mode, ' 'so please note that the API and features may ' 'change without backwards compatibility!') - verify_gpg_version() + await verify_gpg_version() # Prepare new GPG home directory for hardware-based identity device_name = device_type.package_name().rsplit('-', 1)[0] @@ -145,11 +148,11 @@ def run_init(device_type, args): sys.exit(1) # Prepare the key before making any changes - fingerprints, public_key_bytes = export_public_key(device_type, args) + fingerprints, public_key_bytes = await export_public_key(device_type, args) os.makedirs(homedir, mode=0o700) - agent_path = util.which('{}-gpg-agent'.format(device_name)) + agent_path = await util.which('{}-gpg-agent'.format(device_name)) # Prepare GPG agent invocation script (to pass the PATH from environment). with open(os.path.join(homedir, ('run-agent.sh' @@ -200,28 +203,29 @@ def run_init(device_type, args): # Generate new GPG identity and import into GPG keyring verbosity = ('-' + ('v' * args.verbose)) if args.verbose else '--quiet' - check_call(keyring.gpg_command(['--homedir', homedir, verbosity, - '--import']), + check_call(await keyring.gpg_command(['--homedir', homedir, verbosity, + '--import']), input_bytes=public_key_bytes.encode()) # Make new GPG identity with "ultimate" trust (via its fingerprint) - check_call(keyring.gpg_command(['--homedir', homedir, - '--import-ownertrust']), + check_call(await keyring.gpg_command(['--homedir', homedir, + '--import-ownertrust']), input_bytes=(fingerprints[0] + ':6\n').encode()) # Load agent and make sure it responds with the new identity - check_call(keyring.gpg_command(['--homedir', homedir, - '--list-secret-keys', args.user_id])) + check_call(await keyring.gpg_command(['--homedir', homedir, + '--list-secret-keys', args.user_id])) -def run_unlock(device_type, args): +async def run_unlock(device_type, args): """Unlock hardware device (for future interaction).""" util.setup_logging(verbosity=args.verbose) - with device_type() as d: - log.info('unlocked %s device', d) + async with await device.ui.UI.create(device_type=device_type, config=vars(args)) as ui: + async with ui.device(): + log.info('unlocked %s device', ui.get_device_name()) -def _server_from_assuan_fd(env): +async def _server_from_assuan_fd(env): fd = env.get('_assuan_connection_fd') if fd is None: return None @@ -229,8 +233,8 @@ def _server_from_assuan_fd(env): return server.unix_domain_socket_server_from_fd(int(fd)) -def _server_from_sock_path(env): - sock_path = keyring.get_agent_sock_path(env=env) +async def _server_from_sock_path(env): + sock_path = await keyring.get_agent_sock_path(env=env) return server.unix_domain_socket_server(sock_path) @@ -256,12 +260,27 @@ def run_agent(device_type): if daemon and args.daemon: with daemon.DaemonContext(): - run_agent_internal(args, device_type) + trio.run(run_agent_internal, args, device_type) else: - run_agent_internal(args, device_type) + trio.run(run_agent_internal, args, device_type) + + +async def handle_connection(conn, handler, quit_event): + """Handle a single connection to the agent.""" + try: + await handler.handle(conn) + except agent.AgentStop: + log.info('stopping gpg-agent') + quit_event.set() + return + except IOError as e: + log.info('connection closed: %s', e) + return + except Exception as e: # pylint: disable=broad-except + log.exception('handler failed: %s', e) -def run_agent_internal(args, device_type): +async def run_agent_internal(args, device_type): """Actually run the server.""" assert args.homedir @@ -273,29 +292,24 @@ def run_agent_internal(args, device_type): log.debug('pid: %d, parent pid: %d', os.getpid(), os.getppid()) try: env = {'GNUPGHOME': args.homedir, 'PATH': os.environ['PATH']} - pubkey_bytes = keyring.export_public_keys(env=env) - device_type.ui = device.ui.UI(device_type=device_type, - config=vars(args)) - handler = agent.Handler(device=device_type(), - pubkey_bytes=pubkey_bytes) - - sock_server = _server_from_assuan_fd(os.environ) - if sock_server is None: - sock_server = _server_from_sock_path(env) - - with sock_server as sock: - for conn in agent.yield_connections(sock): - with contextlib.closing(conn): - try: - handler.handle(conn) - except agent.AgentStop: - log.info('stopping gpg-agent') - return - except IOError as e: - log.info('connection closed: %s', e) - return - except Exception as e: # pylint: disable=broad-except - log.exception('handler failed: %s', e) + pubkey_bytes = await keyring.export_public_keys(env=env) + async with await device.ui.UI.create(device_type=device_type, config=vars(args)) as ui: + handler = agent.Handler(ui=ui, pubkey_bytes=pubkey_bytes) + + sock_server = await _server_from_assuan_fd(os.environ) + if sock_server is None: + sock_server = await _server_from_sock_path(env) + + async with sock_server as sock: + quit_event = trio.Event() + handle_conn = functools.partial(handle_connection, + handler=handler, + quit_event=quit_event) + try: + await server.server_thread(sock, handle_conn, quit_event) + finally: + log.debug('closing server') + quit_event.set() except Exception as e: # pylint: disable=broad-except log.exception('gpg-agent failed: %s', e) @@ -342,6 +356,5 @@ def main(device_type): p.set_defaults(func=run_unlock) args = parser.parse_args() - device_type.ui = device.ui.UI(device_type=device_type, config=vars(args)) - return args.func(device_type=device_type, args=args) + return trio.run(args.func, device_type, args) diff --git a/libagent/gpg/agent.py b/libagent/gpg/agent.py index 15c93643..8b08ecaa 100644 --- a/libagent/gpg/agent.py +++ b/libagent/gpg/agent.py @@ -8,19 +8,6 @@ log = logging.getLogger(__name__) -def yield_connections(sock): - """Run a server on the specified socket.""" - while True: - log.debug('waiting for connection on %s', sock.getsockname()) - try: - conn, _ = sock.accept() - except KeyboardInterrupt: - return - conn.settimeout(None) - log.debug('accepted connection on %s', sock.getsockname()) - yield conn - - def sig_encode(r, s): """Serialize ECDSA signature data into GPG S-expression.""" r = util.assuan_serialize(util.num2bytes(r, 32)) @@ -49,7 +36,7 @@ def parse_ecdh(line): return dict(items)[b'e'] -def _key_info(conn, args): +async def _key_info(conn, keygrip, *_): """ Dummy reply (mainly for 'gpg --edit' to succeed). @@ -57,8 +44,7 @@ def _key_info(conn, args): https://git.gnupg.org/cgi-bin/gitweb.cgi?p=gnupg.git;a=blob;f=agent/command.c;h=c8b34e9882076b1b724346787781f657cac75499;hb=refs/heads/master#l1082 """ fmt = 'S KEYINFO {0} X - - - - - - -' - keygrip, = args - keyring.sendline(conn, fmt.format(keygrip).encode('ascii')) + await keyring.sendline(conn, fmt.format(keygrip).encode('ascii')) class AgentError(Exception): @@ -76,83 +62,93 @@ class Handler: def _get_options(self): return self.options - def __init__(self, device, pubkey_bytes): + def __init__(self, ui, pubkey_bytes): """C-tor.""" - self.reset() + self.keygrip = None + self.digest = None + self.algo = None self.options = [] - device.ui.options_getter = self._get_options - self.client = client.Client(device=device) + ui.options_getter = self._get_options + self.ui = ui + self.client = client.Client(ui=ui) # Cache public keys from GnuPG self.pubkey_bytes = pubkey_bytes - # "Clone" existing GPG version - self.version = keyring.gpg_version() self.handlers = { - b'RESET': lambda *_: self.reset(), - b'OPTION': lambda _, args: self.handle_option(*args), + b'RESET': self.reset, + b'OPTION': self.handle_option, b'SETKEYDESC': None, b'NOP': None, b'GETINFO': self.handle_getinfo, - b'AGENT_ID': lambda conn, _: keyring.sendline(conn, b'D TREZOR'), # "Fake" agent ID - b'SIGKEY': lambda _, args: self.set_key(*args), - b'SETKEY': lambda _, args: self.set_key(*args), - b'SETHASH': lambda _, args: self.set_hash(*args), - b'PKSIGN': lambda conn, _: self.pksign(conn), - b'PKDECRYPT': lambda conn, _: self.pkdecrypt(conn), - b'HAVEKEY': lambda conn, args: self.have_key(conn, *args), + b'AGENT_ID': self.handle_agent_id, + b'SIGKEY': self.set_key, + b'SETKEY': self.set_key, + b'SETHASH': self.set_hash, + b'PKSIGN': self.pksign, + b'PKDECRYPT': self.pkdecrypt, + b'HAVEKEY': self.have_key, b'KEYINFO': _key_info, b'SCD': self.handle_scd, b'GET_PASSPHRASE': self.handle_get_passphrase, } - def reset(self): + @util.memoize_method + async def get_version(self): + """Clone existing GPG version.""" + return await keyring.gpg_version() + + async def reset(self, *_): """Reset agent's state variables.""" self.keygrip = None self.digest = None self.algo = None - def handle_option(self, opt): + async def handle_option(self, _conn, opt, *_): """Store GPG agent-related options (e.g. for pinentry).""" self.options.append(opt) log.debug('options: %s', self.options) - def handle_get_passphrase(self, conn, _): + async def handle_get_passphrase(self, conn, *_): """Allow simple GPG symmetric encryption (using a passphrase).""" - p1 = self.client.device.ui.get_passphrase('Symmetric encryption:') - p2 = self.client.device.ui.get_passphrase('Re-enter encryption:') + p1 = await self.ui.get_passphrase('Symmetric encryption:') + p2 = await self.ui.get_passphrase('Re-enter encryption:') if p1 == p2: result = b'D ' + util.assuan_serialize(p1.encode('ascii')) - keyring.sendline(conn, result, confidential=True) + await keyring.sendline(conn, result, confidential=True) else: log.warning('Passphrase does not match!') - def handle_getinfo(self, conn, args): + async def handle_agent_id(self, conn, *_): + """Send fake agent ID.""" + await keyring.sendline(conn, b'D TREZOR') + + async def handle_getinfo(self, conn, cmd, *_): """Handle some of the GETINFO messages.""" result = None - if args[0] == b'version': - result = self.version - elif args[0] == b's2k_count': + if cmd == b'version': + result = await self.get_version() + elif cmd == b's2k_count': # Use highest number of S2K iterations. # https://www.gnupg.org/documentation/manuals/gnupg/OpenPGP-Options.html # https://tools.ietf.org/html/rfc4880#section-3.7.1.3 result = '{}'.format(64 << 20).encode('ascii') else: - log.warning('Unknown GETINFO command: %s', args) + log.warning('Unknown GETINFO command: %s', cmd) if result: - keyring.sendline(conn, b'D ' + result) + await keyring.sendline(conn, b'D ' + result) - def handle_scd(self, conn, args): + async def handle_scd(self, conn, *args): """No support for smart-card device protocol.""" reply = { - (b'GETINFO', b'version'): self.version, + (b'GETINFO', b'version'): await self.get_version(), }.get(args) if reply is None: raise AgentError(b'ERR 100696144 No such device ') - keyring.sendline(conn, b'D ' + reply) + await keyring.sendline(conn, b'D ' + reply) @util.memoize_method # global cache for key grips - def get_identity(self, keygrip): + async def get_identity(self, keygrip): """ Returns device.interface.Identity that matches specified keygrip. @@ -167,7 +163,7 @@ def get_identity(self, keygrip): ecdh = pubkey_dict['algo'] == protocol.ECDH_ALGO_ID identity = client.create_identity(user_id=user_id, curve_name=curve_name) - verifying_key = self.client.pubkey(identity=identity, ecdh=ecdh) + verifying_key = await self.client.pubkey(identity=identity, ecdh=ecdh) pubkey = protocol.PublicKey( curve_name=curve_name, created=pubkey_dict['created'], verifying_key=verifying_key, ecdh=ecdh) @@ -175,81 +171,83 @@ def get_identity(self, keygrip): assert pubkey.keygrip() == keygrip_bytes return identity - def pksign(self, conn): + async def pksign(self, conn, *_): """Sign a message digest using a private EC key.""" log.debug('signing %r digest (algo #%s)', self.digest, self.algo) - identity = self.get_identity(keygrip=self.keygrip) - r, s = self.client.sign(identity=identity, - digest=binascii.unhexlify(self.digest)) + identity = await self.get_identity(keygrip=self.keygrip) + r, s = await self.client.sign(identity=identity, + digest=binascii.unhexlify(self.digest)) result = sig_encode(r, s) log.debug('result: %r', result) - keyring.sendline(conn, b'D ' + result) + await keyring.sendline(conn, b'D ' + result) - def pkdecrypt(self, conn): + async def pkdecrypt(self, conn, *_): """Handle decryption using ECDH.""" for msg in [b'S INQUIRE_MAXLEN 4096', b'INQUIRE CIPHERTEXT']: - keyring.sendline(conn, msg) + await keyring.sendline(conn, msg) - line = keyring.recvline(conn) - assert keyring.recvline(conn) == b'END' + line = await keyring.recvline(conn) + assert await keyring.recvline(conn) == b'END' remote_pubkey = parse_ecdh(line) - identity = self.get_identity(keygrip=self.keygrip) - ec_point = self.client.ecdh(identity=identity, pubkey=remote_pubkey) - keyring.sendline(conn, b'D ' + _serialize_point(ec_point)) + identity = await self.get_identity(keygrip=self.keygrip) + ec_point = await self.client.ecdh(identity=identity, pubkey=remote_pubkey) + await keyring.sendline(conn, b'D ' + _serialize_point(ec_point)) - def have_key(self, conn, *keygrips): + async def have_key(self, conn, *keygrips): """Check if any keygrip corresponds to a TREZOR-based key.""" if len(keygrips) == 1 and keygrips[0].startswith(b"--list="): # Support "fast-path" key listing: # https://dev.gnupg.org/rG40da61b89b62dcb77847dc79eb159e885f52f817 keygrips = list(decode.iter_keygrips(pubkey_bytes=self.pubkey_bytes)) log.debug('keygrips: %r', keygrips) - keyring.sendline(conn, b'D ' + util.assuan_serialize(b''.join(keygrips))) + await keyring.sendline(conn, b'D ' + util.assuan_serialize(b''.join(keygrips))) return for keygrip in keygrips: try: - self.get_identity(keygrip=keygrip) + await self.get_identity(keygrip=keygrip) break except KeyError as e: log.warning('HAVEKEY(%s) failed: %s', keygrip, e) else: raise AgentError(b'ERR 67108881 No secret key ') - def set_key(self, keygrip): + async def set_key(self, _conn, keygrip, *_): """Set hexadecimal keygrip for next operation.""" self.keygrip = keygrip - def set_hash(self, algo, digest): + async def set_hash(self, _conn, algo, digest, *_): """Set algorithm ID and hexadecimal digest for next operation.""" self.algo = algo self.digest = digest - def handle(self, conn): + async def handle(self, conn): """Handle connection from GPG binary using the ASSUAN protocol.""" - keyring.sendline(conn, b'OK') - for line in keyring.iterlines(conn): + await keyring.sendline(conn, b'OK') + async for line in keyring.iterlines(conn): parts = line.split(b' ') command = parts[0] args = tuple(parts[1:]) if command == b'BYE': + await keyring.sendline(conn, b'OK closing connection') return elif command == b'KILLAGENT': - keyring.sendline(conn, b'OK') + await keyring.sendline(conn, b'OK closing connection') raise AgentStop() if command not in self.handlers: + await keyring.sendline(conn, b'ERR 67109139 Unknown IPC command ') log.error('unknown request: %r', line) continue handler = self.handlers[command] if handler: try: - handler(conn, args) + await handler(conn, *args) except AgentError as e: msg, = e.args - keyring.sendline(conn, msg) + await keyring.sendline(conn, msg) continue - keyring.sendline(conn, b'OK') + await keyring.sendline(conn, b'OK') diff --git a/libagent/gpg/client.py b/libagent/gpg/client.py index 131ce96d..afe307d2 100644 --- a/libagent/gpg/client.py +++ b/libagent/gpg/client.py @@ -18,29 +18,29 @@ def create_identity(user_id, curve_name): class Client: """Sign messages and get public keys from a hardware device.""" - def __init__(self, device): + def __init__(self, ui): """C-tor.""" - self.device = device + self.ui = ui - def pubkey(self, identity, ecdh=False): + async def pubkey(self, identity, ecdh=False): """Return public key as VerifyingKey object.""" - with self.device: - return self.device.pubkey(ecdh=ecdh, identity=identity) + async with self.ui.device() as device: + return await device.pubkey(ecdh=ecdh, identity=identity) - def sign(self, identity, digest): + async def sign(self, identity, digest): """Sign the digest and return a serialized signature.""" log.info('please confirm GPG signature on %s for "%s"...', - self.device, identity.to_string()) + self.ui.get_device_name(), identity.to_string()) if identity.curve_name == formats.CURVE_NIST256: digest = digest[:32] # sign the first 256 bits log.debug('signing digest: %s', util.hexlify(digest)) - with self.device: - sig = self.device.sign(blob=digest, identity=identity) + async with self.ui.device() as device: + sig = await device.sign(blob=digest, identity=identity) return (util.bytes2num(sig[:32]), util.bytes2num(sig[32:])) - def ecdh(self, identity, pubkey): + async def ecdh(self, identity, pubkey): """Derive shared secret using ECDH from remote public key.""" log.info('please confirm GPG decryption on %s for "%s"...', - self.device, identity.to_string()) - with self.device: - return self.device.ecdh(pubkey=pubkey, identity=identity) + self.ui.get_device_name(), identity.to_string()) + async with self.ui.device() as device: + return await device.ecdh(pubkey=pubkey, identity=identity) diff --git a/libagent/gpg/encode.py b/libagent/gpg/encode.py index 44c3d2e6..3053a483 100644 --- a/libagent/gpg/encode.py +++ b/libagent/gpg/encode.py @@ -8,7 +8,7 @@ log = logging.getLogger(__name__) -def create_primary(user_id, pubkey, signer_func, secret_bytes=b''): +async def create_primary(user_id, pubkey, signer_func, secret_bytes=b''): """Export new primary GPG public key, ready for "gpg2 --import".""" pubkey_packet = protocol.packet(tag=(5 if secret_bytes else 6), blob=pubkey.data() + secret_bytes) @@ -36,7 +36,7 @@ def create_primary(user_id, pubkey, signer_func, secret_bytes=b''): protocol.subpacket(16, pubkey.key_id()), # issuer key id protocol.CUSTOM_SUBPACKET] - signature = protocol.make_signature( + signature = await protocol.make_signature( signer_func=signer_func, public_algo=pubkey.algo_id, data_to_sign=data_to_sign, @@ -48,7 +48,7 @@ def create_primary(user_id, pubkey, signer_func, secret_bytes=b''): return pubkey_packet + user_id_packet + sign_packet -def create_subkey(primary_bytes, subkey, signer_func, secret_bytes=b''): +async def create_subkey(primary_bytes, subkey, signer_func, secret_bytes=b''): """Export new subkey to GPG primary key.""" subkey_packet = protocol.packet(tag=(7 if secret_bytes else 14), blob=subkey.data() + secret_bytes) @@ -65,7 +65,7 @@ def create_subkey(primary_bytes, subkey, signer_func, secret_bytes=b''): protocol.subpacket_time(subkey.created)] # signature time unhashed_subpackets = [ protocol.subpacket(16, subkey.key_id())] # issuer key id - embedded_sig = protocol.make_signature( + embedded_sig = await protocol.make_signature( signer_func=signer_func, data_to_sign=data_to_sign, public_algo=subkey.algo_id, @@ -90,9 +90,9 @@ def create_subkey(primary_bytes, subkey, signer_func, secret_bytes=b''): unhashed_subpackets.append(protocol.CUSTOM_SUBPACKET) if not decode.has_custom_subpacket(signature): - signer_func = keyring.create_agent_signer(user_id['value']) + signer_func = await keyring.create_agent_signer(user_id['value']) - signature = protocol.make_signature( + signature = await protocol.make_signature( signer_func=signer_func, data_to_sign=data_to_sign, public_algo=primary['algo'], diff --git a/libagent/gpg/keyring.py b/libagent/gpg/keyring.py index 46dd00a6..be53cee6 100644 --- a/libagent/gpg/keyring.py +++ b/libagent/gpg/keyring.py @@ -7,10 +7,11 @@ import os import re import socket -import subprocess import sys import urllib.parse +import trio + from .. import util if sys.platform == 'win32': @@ -19,54 +20,53 @@ log = logging.getLogger(__name__) -def check_output(args, env=None, sp=subprocess): +async def check_output(args, env=None, run_process=trio.run_process): """Call an external binary and return its stdout.""" log.debug('calling %s with env %s', args, env) - p = sp.Popen(args=args, env=env, stdin=sp.PIPE, stdout=sp.PIPE, stderr=sp.PIPE) - (output, error) = p.communicate() - log.debug('output: %r', output) - if error: - log.debug('error: %r', error) - return output + info = await run_process(args, env=env, capture_stdout=True, capture_stderr=True) + log.debug('output: %r', info.stdout) + if info.stderr: + log.debug('error: %r', info.stderr) + return info.stdout -def get_agent_sock_path(env=None, sp=subprocess): +async def get_agent_sock_path(env=None, run_process=trio.run_process): """Parse gpgconf output to find out GPG agent UNIX socket path.""" - args = [util.which('gpgconf'), '--list-dirs', 'agent-socket'] - return check_output(args=args, env=env, sp=sp).strip() + args = [await util.which('gpgconf'), '--list-dirs', 'agent-socket'] + return (await check_output(args=args, env=env, run_process=run_process)).strip() -def connect_to_agent(env=None, sp=subprocess): +async def connect_to_agent(env=None, run_process=trio.run_process): """Connect to GPG agent's UNIX socket.""" - sock_path = get_agent_sock_path(sp=sp, env=env) + sock_path = get_agent_sock_path(run_process=run_process, env=env) # Make sure the original gpg-agent is running. - check_output(args=['gpg-connect-agent', '/bye'], sp=sp) + await check_output(args=['gpg-connect-agent', '/bye'], run_process=run_process) if sys.platform == 'win32': - sock = win_server.Client(sock_path) + sock = await win_server.Client.open(sock_path) else: sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sock.connect(sock_path) + await sock.connect(sock_path) return sock -def communicate(sock, msg): +async def communicate(sock, msg): """Send a message and receive a single line.""" - sendline(sock, msg.encode('ascii')) - return recvline(sock) + await sendline(sock, msg.encode('ascii')) + return await recvline(sock) -def sendline(sock, msg, confidential=False): +async def sendline(sock, msg, confidential=False): """Send a binary message, followed by EOL.""" log.debug('<- %r', ('' if confidential else msg)) - sock.sendall(msg + b'\n') + await util.send(sock, msg + b'\n') -def recvline(sock): +async def recvline(sock): """Receive a single line from the socket.""" reply = io.BytesIO() while True: - c = sock.recv(1) + c = await sock.recv(1) if not c: return None # socket is closed @@ -79,10 +79,10 @@ def recvline(sock): return result -def iterlines(conn): +async def iterlines(conn): """Iterate over input, split by lines.""" while True: - line = recvline(conn) + line = await recvline(conn) if line is None: break yield line @@ -153,14 +153,14 @@ def parse_sig(sig): return parser(args=sig[1:]) -def sign_digest(sock, keygrip, digest, sp=subprocess, environ=None): +async def sign_digest(sock, keygrip, digest, run_process=trio.run_process, environ=None): """Sign a digest using specified key using GPG agent.""" hash_algo = 8 # SHA256 assert len(digest) == 32 - assert communicate(sock, 'RESET').startswith(b'OK') + assert (await communicate(sock, 'RESET')).startswith(b'OK') - ttyname = check_output(args=['tty'], sp=sp).strip() + ttyname = (await check_output(args=['tty'], run_process=run_process)).strip() options = ['ttyname={}'.format(ttyname)] # set TTY for passphrase entry display = (environ or os.environ).get('DISPLAY') @@ -168,18 +168,18 @@ def sign_digest(sock, keygrip, digest, sp=subprocess, environ=None): options.append('display={}'.format(display)) for opt in options: - assert communicate(sock, 'OPTION {}'.format(opt)) == b'OK' + assert await communicate(sock, 'OPTION {}'.format(opt)) == b'OK' - assert communicate(sock, 'SIGKEY {}'.format(keygrip)) == b'OK' + assert await communicate(sock, 'SIGKEY {}'.format(keygrip)) == b'OK' hex_digest = binascii.hexlify(digest).upper().decode('ascii') - assert communicate(sock, 'SETHASH {} {}'.format(hash_algo, - hex_digest)) == b'OK' + assert await communicate(sock, 'SETHASH {} {}'.format(hash_algo, + hex_digest)) == b'OK' - assert communicate(sock, 'SETKEYDESC ' - 'Sign+a+new+TREZOR-based+subkey') == b'OK' - assert communicate(sock, 'PKSIGN') == b'OK' + assert await communicate(sock, 'SETKEYDESC ' + 'Sign+a+new+TREZOR-based+subkey') == b'OK' + assert await communicate(sock, 'PKSIGN') == b'OK' while True: - line = recvline(sock).strip() + line = (await recvline(sock)).strip() if not line.startswith(b'S PROGRESS'): break line = unescape(line) @@ -193,10 +193,10 @@ def sign_digest(sock, keygrip, digest, sp=subprocess, environ=None): return parse_sig(sig) -def get_gnupg_components(sp=subprocess): +async def get_gnupg_components(run_process=trio.run_process): """Parse GnuPG components' paths.""" - args = [util.which('gpgconf'), '--list-components'] - output = check_output(args=args, sp=sp) + args = [await util.which('gpgconf'), '--list-components'] + output = await check_output(args=args, run_process=run_process) components = {k: urllib.parse.unquote(v) for k, v in re.findall( r'(?BBBB', @@ -271,7 +271,7 @@ def make_signature(signer_func, data_to_sign, public_algo, log.debug('hashing %d bytes', len(data_to_hash)) digest = hashlib.sha256(data_to_hash).digest() log.debug('signing digest: %s', util.hexlify(digest)) - params = signer_func(digest=digest) + params = await signer_func(digest=digest) sig = b''.join(mpi(p) for p in params) return bytes(header + hashed + unhashed + diff --git a/libagent/gpg/tests/test_keyring.py b/libagent/gpg/tests/test_keyring.py index 605ba0c9..a3a18f3c 100644 --- a/libagent/gpg/tests/test_keyring.py +++ b/libagent/gpg/tests/test_keyring.py @@ -1,7 +1,7 @@ import io import subprocess -import mock +import pytest from .. import keyring @@ -47,22 +47,22 @@ def __init__(self): self.rx = io.BytesIO() self.tx = io.BytesIO() - def recv(self, n): + async def recv(self, n): return self.rx.read(n) - def sendall(self, data): + async def send(self, data): self.tx.write(data) + return len(data) -def mock_subprocess(output, error=b''): - sp = mock.Mock(spec=['Popen', 'PIPE']) - p = mock.Mock(spec=['communicate']) - sp.Popen.return_value = p - p.communicate.return_value = (output, error) - return sp +def mock_run_process(output, error=b''): + async def run_process(args, **_): + return subprocess.CompletedProcess(args, returncode=0, stdout=output, stderr=error) + return run_process -def test_sign_digest(): +@pytest.mark.trio +async def test_sign_digest(): sock = FakeSocket() sock.rx.write(b'OK Pleased to meet you, process XYZ\n') sock.rx.write(b'OK\n' * 6) @@ -70,9 +70,9 @@ def test_sign_digest(): sock.rx.seek(0) keygrip = '1234' digest = b'A' * 32 - sig = keyring.sign_digest(sock=sock, keygrip=keygrip, - digest=digest, sp=mock_subprocess('/dev/pts/0'), - environ={'DISPLAY': ':0'}) + sig = await keyring.sign_digest(sock=sock, keygrip=keygrip, + digest=digest, run_process=mock_run_process('/dev/pts/0'), + environ={'DISPLAY': ':0'}) assert sig == (0x30313233343536373839414243444546,) assert sock.tx.getvalue() == b'''RESET OPTION ttyname=/dev/pts/0 @@ -84,18 +84,23 @@ def test_sign_digest(): ''' -def test_iterlines(): +@pytest.mark.trio +async def test_iterlines(): sock = FakeSocket() sock.rx.write(b'foo\nbar\nxyz') sock.rx.seek(0) - assert list(keyring.iterlines(sock)) == [b'foo', b'bar'] + assert [line async for line in keyring.iterlines(sock)] == [b'foo', b'bar'] -def test_get_agent_sock_path(): +@pytest.mark.trio +async def test_get_agent_sock_path(): expected_prefix = b'/run/user/' expected_suffix = b'/gnupg/S.gpg-agent' expected_infix = b'0123456789' - value = keyring.get_agent_sock_path(sp=subprocess) + expected_if_root = b'/root/.gnupg/S.gpg-agent' # Use in case tox was executed as root + value = await keyring.get_agent_sock_path() + if value == expected_if_root: + return assert value.startswith(expected_prefix) assert value.endswith(expected_suffix) value = value[len(expected_prefix):-len(expected_suffix)] diff --git a/libagent/gpg/tests/test_protocol.py b/libagent/gpg/tests/test_protocol.py index 233be13d..fe1e5799 100644 --- a/libagent/gpg/tests/test_protocol.py +++ b/libagent/gpg/tests/test_protocol.py @@ -47,13 +47,14 @@ def test_armor(): ''' -def test_make_signature(): - def signer_func(digest): +@pytest.mark.trio +async def test_make_signature(): + async def signer_func(digest): assert digest == (b'\xd0\xe5]|\x8bP\xe6\x91\xb3\xe8+\xf4A\xf0`(\xb1' b'\xc7\xf4;\x86\x97s\xdb\x9a\xda\xee< \xcb\x9e\x00') return (7, 8) - sig = protocol.make_signature( + sig = await protocol.make_signature( signer_func=signer_func, data_to_sign=b'Hello World!', public_algo=22, diff --git a/libagent/server.py b/libagent/server.py index 43289ce7..1fb9ede9 100644 --- a/libagent/server.py +++ b/libagent/server.py @@ -1,11 +1,16 @@ """UNIX-domain socket server for ssh-agent implementation.""" import contextlib +import functools import logging import os +import signal import socket -import subprocess import sys -import threading + +import trio +import trio.lowlevel +import trio.socket +import trio_util from . import util @@ -15,17 +20,17 @@ log = logging.getLogger(__name__) -def remove_file(path, remove=os.remove, exists=os.path.exists): +async def remove_file(path, trio_path=trio.Path): """Remove file, and raise OSError if still exists.""" try: - remove(path) + await trio_path(path).unlink() except OSError: - if exists(path): + if await trio_path(path).exists(): raise -@contextlib.contextmanager -def unix_domain_socket_server(sock_path): +@contextlib.asynccontextmanager +async def unix_domain_socket_server(sock_path): """ Create UNIX-domain socket on specified path. @@ -34,17 +39,19 @@ def unix_domain_socket_server(sock_path): log.debug('serving on %s', sock_path) if sys.platform == 'win32': # Return a named pipe emulating a socket server interface - yield win_server.Server(sock_path) + with await win_server.Server.open(sock_path) as server: + yield server return - remove_file(sock_path) - server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - server.bind(sock_path) - server.listen(1) - try: - yield server - finally: - remove_file(sock_path) + await remove_file(sock_path) + + with trio.socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as server: + await server.bind(sock_path) + server.listen(1) + try: + yield server + finally: + await remove_file(sock_path) class FDServer: @@ -53,122 +60,103 @@ class FDServer: def __init__(self, fd): """C-tor.""" self.fd = fd - self.sock = socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM) + self.sock = trio.socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM) + + def __enter__(self): + """Context manager support.""" + return self + + def __exit__(self, *args): + """Context manager support.""" + return self.sock.__exit__(*args) def accept(self): """Use the same socket for I/O.""" return self, None - def recv(self, n): + async def recv(self, n): """Forward to underlying socket.""" - return self.sock.recv(n) + return await self.sock.recv(n) - def sendall(self, data): + async def send(self, data): """Forward to underlying socket.""" - return self.sock.sendall(data) + return await self.sock.send(data) def close(self): - """Not needed.""" - - def settimeout(self, _): - """Not needed.""" + """Close the duplicated file descriptor.""" + return self.sock.close() def getsockname(self): """Simple representation.""" return ''.format(self.fd) -@contextlib.contextmanager -def unix_domain_socket_server_from_fd(fd): +@contextlib.asynccontextmanager +async def unix_domain_socket_server_from_fd(fd): """Build UDS-based socket server from a file descriptor.""" yield FDServer(fd) -def handle_connection(conn, handler, mutex): +async def handle_connection(conn, handler): """ Handle a single connection using the specified protocol handler in a loop. - Since this function may be called concurrently from server_thread, - the specified mutex is used to synchronize the device handling. - Exit when EOFError is raised. All other exceptions are logged as warnings. """ try: log.debug('welcome agent') - with contextlib.closing(conn): + with conn: while True: - msg = util.read_frame(conn) - with mutex: - reply = handler.handle(msg=msg) - util.send(conn, reply) + msg = await util.read_frame_async(conn) + reply = await handler.handle(msg=msg) + await util.send(conn, reply) except EOFError: log.debug('goodbye agent') except Exception as e: # pylint: disable=broad-except log.warning('error: %s', e, exc_info=True) -def retry(func, exception_type, quit_event): - """ - Run the function, retrying when the specified exception_type occurs. - - Poll quit_event on each iteration, to be responsive to an external - exit request. - """ - while True: - if quit_event.is_set(): - raise StopIteration - try: - return func() - except exception_type: - pass - - -def server_thread(sock, handle_conn, quit_event): +async def server_thread(sock, handle_conn, quit_event): """Run a server on the specified socket.""" log.debug('server thread started') - def accept_connection(): - conn, _ = sock.accept() - conn.settimeout(None) + async def handle(conn): + with conn: + await handle_conn(conn) return conn - while True: - log.debug('waiting for connection on %s', sock.getsockname()) - try: - conn = retry(accept_connection, socket.timeout, quit_event) - except StopIteration: - log.debug('server stopped') - break - # Handle connections from SSH concurrently. - threading.Thread(target=handle_conn, - kwargs={'conn': conn}).start() - log.debug('server thread stopped') - - -@contextlib.contextmanager -def spawn(func, kwargs): - """Spawn a thread, and join it after the context is over.""" - t = threading.Thread(target=func, kwargs=kwargs) - t.start() - yield - t.join() - - -def run_process(command, environ): + try: + signals = [getattr(signal, attr) + for attr in ['SIGINT', 'SIGBREAK', 'SIGABRT'] if hasattr(signal, attr)] + with trio.open_signal_receiver(*signals) as signal_waiter: + async with trio_util.move_on_when(signal_waiter.__anext__): + async with trio_util.move_on_when(quit_event.wait): + async with trio.open_nursery() as nursery: + while True: + log.debug('waiting for connection on %s', sock.getsockname()) + conn, _ = await sock.accept() + nursery.start_soon(handle, conn) + finally: + log.debug('server thread stopped') + + +async def run_process(command, environ): """ Run the specified process and wait until it finishes. Use environ dict for environment variables. """ - log.info('running %r with %r', command, environ) - env = dict(os.environ) - env.update(environ) - try: - p = subprocess.Popen(args=command, env=env) - except OSError as e: - raise OSError('cannot run %r: %s' % (command, e)) from e - log.debug('subprocess %d is running', p.pid) - ret = p.wait() - log.debug('subprocess %d exited: %d', p.pid, ret) - return ret + async with trio.open_nursery() as nursery: + log.info('running %r with %r', command, environ) + env = dict(os.environ) + env.update(environ) + try: + p = await nursery.start(functools.partial(trio.run_process, command, env=env, + check=False, stdin=None)) + except OSError as e: + raise OSError('cannot run %r: %s' % (command, e)) from e + log.debug('subprocess %d is running', p.pid) + ret = await p.wait() + log.debug('subprocess %d exited: %d', p.pid, ret) + return ret diff --git a/libagent/signify/__init__.py b/libagent/signify/__init__.py index a846ee90..3dcc3eac 100644 --- a/libagent/signify/__init__.py +++ b/libagent/signify/__init__.py @@ -7,14 +7,15 @@ import sys import time -from .. import util -from ..device import interface, ui +import trio + +from .. import device, util log = logging.getLogger(__name__) def _create_identity(user_id): - result = interface.Identity(identity_str='signify://', curve_name='ed25519') + result = device.interface.Identity(identity_str='signify://', curve_name='ed25519') result.identity_dict['host'] = user_id return result @@ -22,22 +23,23 @@ def _create_identity(user_id): class Client: """Sign messages and get public keys from a hardware device.""" - def __init__(self, device): + def __init__(self, ui): """C-tor.""" - self.device = device + self.ui = ui + self.ui.cached_passphrase_ack = util.ExpiringCache(seconds=float(60)) - def pubkey(self, identity): + async def pubkey(self, identity): """Return public key as VerifyingKey object.""" - with self.device: - return bytes(self.device.pubkey(ecdh=False, identity=identity)) + async with self.ui.device() as d: + return bytes(await d.pubkey(ecdh=False, identity=identity)) - def sign_with_pubkey(self, identity, data): + async def sign_with_pubkey(self, identity, data): """Sign the data and return a signature.""" log.info('please confirm Signify signature on %s for "%s"...', - self.device, identity.to_string()) + self.ui.get_device_name(), identity.to_string()) log.debug('signing data: %s', util.hexlify(data)) - with self.device: - sig, pubkey = self.device.sign_with_pubkey(blob=data, identity=identity) + async with self.ui.device() as d: + sig, pubkey = await d.sign_with_pubkey(blob=data, identity=identity) assert len(sig) == 64 assert len(pubkey) == 33 assert pubkey[:1] == b"\x00" @@ -54,21 +56,22 @@ def format_payload(pubkey, data, sig_alg): return binascii.b2a_base64(sig_alg + keynum + data).decode("ascii") -def run_pubkey(device_type, args): +async def run_pubkey(device_type, args): """Export hardware-based Signify public key.""" util.setup_logging(verbosity=args.verbose) log.warning('This Signify tool is still in EXPERIMENTAL mode, ' 'so please note that the key derivation, API, and features ' 'may change without backwards compatibility!') - identity = _create_identity(user_id=args.user_id) - pubkey = Client(device=device_type()).pubkey(identity=identity) - comment = f'untrusted comment: identity {identity.to_string()}\n' - payload = format_payload(pubkey=pubkey, data=pubkey, sig_alg=ALG_SIGNIFY) - print(comment + payload, end="") + async with await device.ui.UI.create(device_type=device_type, config=vars(args)) as ui: + identity = _create_identity(user_id=args.user_id) + pubkey = await Client(ui=ui).pubkey(identity=identity) + comment = f'untrusted comment: identity {identity.to_string()}\n' + payload = format_payload(pubkey=pubkey, data=pubkey, sig_alg=ALG_SIGNIFY) + print(comment + payload, end="") -def run_sign(device_type, args): +async def run_sign(device_type, args): """Prehash & sign an input blob using Ed25519.""" util.setup_logging(verbosity=args.verbose) identity = _create_identity(user_id=args.user_id) @@ -81,16 +84,18 @@ def run_sign(device_type, args): sig_alg = ALG_MINISIGN data_to_sign = hashlib.blake2b(data_to_sign).digest() - sig, pubkey = Client(device=device_type()).sign_with_pubkey(identity, data_to_sign) - pubkey_str = format_payload(pubkey=pubkey, data=pubkey, sig_alg=sig_alg) - sig_str = format_payload(pubkey=pubkey, data=sig, sig_alg=sig_alg) - untrusted_comment = f'untrusted comment: pubkey {pubkey_str}' - print(untrusted_comment + sig_str, end="") + async with await device.ui.UI.create(device_type=device_type, config=vars(args)) as ui: + c = Client(ui=ui) + sig, pubkey = await c.sign_with_pubkey(identity, data_to_sign) + pubkey_str = format_payload(pubkey=pubkey, data=pubkey, sig_alg=sig_alg) + sig_str = format_payload(pubkey=pubkey, data=sig, sig_alg=sig_alg) + untrusted_comment = f'untrusted comment: pubkey {pubkey_str}' + print(untrusted_comment + sig_str, end="") - comment_to_sign = sig + args.comment.encode() - sig, _ = Client(device=device_type()).sign_with_pubkey(identity, comment_to_sign) - sig_str = binascii.b2a_base64(sig).decode("ascii") - print(f'trusted comment: {args.comment}\n' + sig_str, end="") + comment_to_sign = sig + args.comment.encode() + sig, _ = await c.sign_with_pubkey(identity, comment_to_sign) + sig_str = binascii.b2a_base64(sig).decode("ascii") + print(f'trusted comment: {args.comment}\n' + sig_str, end="") def main(device_type): @@ -113,7 +118,5 @@ def main(device_type): p.set_defaults(func=run_sign) args = parser.parse_args() - device_type.ui = ui.UI(device_type=device_type, config=vars(args)) - device_type.ui.cached_passphrase_ack = util.ExpiringCache(seconds=float(60)) - return args.func(device_type=device_type, args=args) + return trio.run(args.func, device_type, args) diff --git a/libagent/ssh/__init__.py b/libagent/ssh/__init__.py index dee3ee24..35fe3cc2 100644 --- a/libagent/ssh/__init__.py +++ b/libagent/ssh/__init__.py @@ -12,9 +12,10 @@ import subprocess import sys import tempfile -import threading import configargparse +import trio +import trio_util try: # TODO: Not supported on Windows. Use daemoniker instead? @@ -28,17 +29,16 @@ log = logging.getLogger(__name__) -UNIX_SOCKET_TIMEOUT = 0.1 SOCK_TYPE = 'Windows named pipe' if sys.platform == 'win32' else 'UNIX domain socket' SOCK_TYPE_PATH = 'Windows named pipe path' if sys.platform == 'win32' else 'UNIX socket path' FILE_PREFIX = 'file:' -def ssh_args(conn): +@contextlib.asynccontextmanager +async def ssh_args(previous, conn): """Create SSH command for connecting specified server.""" I, = conn.identities identity = I.identity_dict - pubkey_tempfile, = conn.public_keys_as_files() args = [] if 'port' in identity: @@ -46,12 +46,15 @@ def ssh_args(conn): if 'user' in identity: args += ['-l', identity['user']] - args += ['-o', 'IdentityFile={}'.format(pubkey_tempfile.name)] - args += ['-o', 'IdentitiesOnly=true'] - return args + [identity['host']] + async with conn.public_keys_as_file() as pubkey_tempfile_name: + args += ['-o', 'IdentityFile={}'.format(pubkey_tempfile_name)] + args += ['-o', 'IdentitiesOnly=true'] + async with previous(conn) as command: + yield ['ssh'] + args + [identity['host']] + command -def mosh_args(conn): +@contextlib.asynccontextmanager +async def mosh_args(previous, conn): """Create SSH command for connecting specified server.""" I, = conn.identities identity = I.identity_dict @@ -64,7 +67,8 @@ def mosh_args(conn): else: args += [identity['host']] - return args + async with previous(conn) as command: + yield ['mosh'] + args + command def _to_unicode(s): @@ -93,9 +97,6 @@ def create_agent_parser(device_type): p.add_argument('-e', '--ecdsa-curve-name', metavar='CURVE', default=formats.CURVE_NIST256, help='specify ECDSA curve name: ' + curve_names) - p.add_argument('--timeout', - default=UNIX_SOCKET_TIMEOUT, type=float, - help='timeout for accepting SSH client connections') p.add_argument('--debug', default=False, action='store_true', help='log SSH protocol messages for debugging.') p.add_argument('--log-file', type=str, @@ -133,29 +134,18 @@ def create_agent_parser(device_type): return p -@contextlib.contextmanager -def serve(handler, sock_path, timeout=UNIX_SOCKET_TIMEOUT): - """ - Start the ssh-agent server on a UNIX-domain socket. - - If no connection is made during the specified timeout, - retry until the context is over. - """ +@contextlib.asynccontextmanager +async def serve(handler, sock_path): + """Start the ssh-agent server on a UNIX-domain socket.""" ssh_version = subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT) log.debug('local SSH version: %r', ssh_version) environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())} - device_mutex = threading.Lock() - with server.unix_domain_socket_server(sock_path) as sock: - sock.settimeout(timeout) - quit_event = threading.Event() + async with server.unix_domain_socket_server(sock_path) as sock: + quit_event = trio.Event() handle_conn = functools.partial(server.handle_connection, - handler=handler, - mutex=device_mutex) - kwargs = {'sock': sock, - 'handle_conn': handle_conn, - 'quit_event': quit_event} - with server.spawn(server.server_thread, kwargs): + handler=handler) + async with trio_util.move_on_when(server.server_thread, sock, handle_conn, quit_event): try: yield environ finally: @@ -163,25 +153,32 @@ def serve(handler, sock_path, timeout=UNIX_SOCKET_TIMEOUT): quit_event.set() -def run_server(conn, command, sock_path, debug, timeout): - """Common code for run_agent and run_git below.""" +async def run_server(conn_context, command_context, sock_path, debug): + """Run the SSH agent. Optionally execute a command while the agent is running.""" ret = 0 try: - handler = protocol.Handler(conn=conn, debug=debug) - with serve(handler=handler, sock_path=sock_path, - timeout=timeout) as env: - if command: - ret = server.run_process(command=command, environ=env) - else: - try: - signal.pause() # wait for signal (e.g. SIGINT) - except AttributeError: - sys.stdin.read() # Windows doesn't support signal.pause - except KeyboardInterrupt: + # override default PIN/passphrase entry tools (relevant for TREZOR/Keepkey): + async with conn_context as conn: + handler = protocol.Handler(conn=conn, debug=debug) + async with serve(handler=handler, sock_path=sock_path) as env: + async with command_context(conn) as command: + if command: + ret = await server.run_process(command=command, environ=env) + else: + await trio.sleep_forever() # Wait until the server has stopped + finally: log.info('server stopped') return ret +async def show_public_keys(conn_context): + """Command for showing public keys associated with the provided identities.""" + async with conn_context as conn: + for pk in await conn.public_keys(): + sys.stdout.write(pk) + return 0 # success exit code + + def handle_connection_error(func): """Fail with non-zero exit code.""" @functools.wraps(func) @@ -209,74 +206,49 @@ def import_public_keys(contents): yield line -class ClosableNamedTemporaryFile(): - """Creates a temporary file that is not deleted when the file is closed. - - This allows the file to be opened with an exclusive lock, but used by other programs before - it is deleted - """ - - def __init__(self): - """Create a temporary file.""" - self.file = tempfile.NamedTemporaryFile(prefix='trezor-ssh-pubkey-', mode='w', delete=False) - self.name = self.file.name - - def write(self, buf): - """Write `buf` to the file.""" - self.file.write(buf) - - def close(self): - """Closes the file, allowing it to be opened by other programs. Does not delete the file.""" - self.file.close() - - def __del__(self): - """Deletes the temporary file.""" - try: - os.unlink(self.file.name) - except OSError: - log.warning("Failed to delete temporary file: %s", self.file.name) - - class JustInTimeConnection: """Connect to the device just before the needed operation.""" - def __init__(self, conn_factory, identities, public_keys=None): + def __init__(self, conn, identities, public_keys=None): """Create a JIT connection object.""" - self.conn_factory = conn_factory + self.conn = conn self.identities = identities self.public_keys_cache = public_keys - self.public_keys_tempfiles = [] - def public_keys(self): + async def public_keys(self): """Return a list of SSH public keys (in textual format).""" if not self.public_keys_cache: - conn = self.conn_factory() - self.public_keys_cache = conn.export_public_keys(self.identities) + self.public_keys_cache = await self.conn.export_public_keys(self.identities) return self.public_keys_cache - def parse_public_keys(self): + async def parse_public_keys(self): """Parse SSH public keys into dictionaries.""" public_keys = [formats.import_public_key(pk) - for pk in self.public_keys()] + for pk in await self.public_keys()] for pk, identity in zip(public_keys, self.identities): pk['identity'] = identity return public_keys - def public_keys_as_files(self): + @contextlib.asynccontextmanager + async def public_keys_as_file(self): """Store public keys as temporary SSH identity files.""" - if not self.public_keys_tempfiles: - for pk in self.public_keys(): - f = ClosableNamedTemporaryFile() - f.write(pk) - f.close() - self.public_keys_tempfiles.append(f) + tf = tempfile.NamedTemporaryFile(prefix='trezor-ssh-pubkey-', mode='w', delete=False) + try: + async with trio.wrap_file(tf) as f: + for pk in await self.public_keys(): + await f.write(pk) + await f.write('\n') - return self.public_keys_tempfiles + yield tf.name + finally: + try: + await trio.Path(tf.name).unlink() + except OSError: + log.warning("Failed to delete temporary file: %s", tf.name) - def sign(self, blob, identity): + async def sign(self, blob, identity): """Sign a given blob using the specified identity on the device.""" - conn = self.conn_factory() - return conn.sign_ssh_challenge(blob=blob, identity=identity) + return await self.conn.sign_ssh_challenge(blob=blob, identity=identity) @contextlib.contextmanager @@ -297,10 +269,25 @@ def _get_sock_path(args): return tempfile.mktemp(prefix='trezor-ssh-agent-') +@contextlib.asynccontextmanager +async def _command_context_constant(command, *_): + yield command + + +@contextlib.asynccontextmanager +async def _just_in_time_conection(device_type, config, identities, public_keys): + # override default PIN/passphrase entry tools (relevant for TREZOR/Keepkey): + async with await device.ui.UI.create(device_type=device_type, config=config) as ui: + conn = JustInTimeConnection( + conn=client.Client(ui), + identities=identities, public_keys=public_keys) + yield conn + + @handle_connection_error def main(device_type): """Run ssh-agent using given hardware client factory.""" - args = create_agent_parser(device_type=device_type).parse_args() + args = create_agent_parser(device_type=device_type).parse_intermixed_args() util.setup_logging(verbosity=args.verbose, filename=args.log_file) public_keys = None @@ -321,39 +308,37 @@ def main(device_type): identity.identity_dict['proto'] = 'ssh' log.info('identity #%d: %s', index, identity.to_string()) - # override default PIN/passphrase entry tools (relevant for TREZOR/Keepkey): - device_type.ui = device.ui.UI(device_type=device_type, config=vars(args)) - - conn = JustInTimeConnection( - conn_factory=lambda: client.Client(device_type()), - identities=identities, public_keys=public_keys) + conn_context = _just_in_time_conection(device_type, vars(args), identities, public_keys) sock_path = _get_sock_path(args) - command = args.command + command_context = functools.partial(_command_context_constant, args.command) + show_pks = not args.command context = _dummy_context() if args.connect: - command = ['ssh'] + ssh_args(conn) + args.command + show_pks = False + command_context = functools.partial(ssh_args, command_context) elif sys.platform != 'win32' and args.mosh: - command = ['mosh'] + mosh_args(conn) + args.command + show_pks = False + command_context = functools.partial(mosh_args, command_context) elif daemon and args.daemonize: + show_pks = False out = 'SSH_AUTH_SOCK={0}; export SSH_AUTH_SOCK;\n'.format(sock_path) sys.stdout.write(out) sys.stdout.flush() context = daemon.DaemonContext() log.info('running the agent as a daemon on %s', sock_path) elif args.foreground: + show_pks = False log.info('running the agent on %s', sock_path) use_shell = bool(args.shell) if use_shell: - command = os.environ['SHELL'] + show_pks = False + command_context = functools.partial(_command_context_constant, [os.environ['SHELL']]) sys.stdin.close() - if command or (daemon and args.daemonize) or args.foreground: + if not show_pks: with context: - return run_server(conn=conn, command=command, sock_path=sock_path, - debug=args.debug, timeout=args.timeout) + return trio.run(run_server, conn_context, command_context, sock_path, args.debug) else: - for pk in conn.public_keys(): - sys.stdout.write(pk) - return 0 # success exit code + return trio.run(show_public_keys, conn_context) diff --git a/libagent/ssh/client.py b/libagent/ssh/client.py index aa3b47cc..27990293 100644 --- a/libagent/ssh/client.py +++ b/libagent/ssh/client.py @@ -14,28 +14,28 @@ class Client: """Client wrapper for SSH authentication device.""" - def __init__(self, device): + def __init__(self, ui): """Connect to hardware device.""" - self.device = device + self.ui = ui - def export_public_keys(self, identities): + async def export_public_keys(self, identities): """Export SSH public keys from the device.""" pubkeys = [] - with self.device: + async with self.ui.device() as device: for i in identities: - vk = self.device.pubkey(identity=i) + vk = await device.pubkey(identity=i) label = i.to_string() pubkey = formats.export_public_key(vk=vk, label=label) pubkeys.append(pubkey) return pubkeys - def sign_ssh_challenge(self, blob, identity): + async def sign_ssh_challenge(self, blob, identity): """Sign given blob using a private key on the device.""" log.debug('blob: %r', blob) msg = parse_ssh_blob(blob) if msg['sshsig']: log.info('please confirm "%s" signature for "%s" using %s...', - msg['namespace'], identity.to_string(), self.device) + msg['namespace'], identity.to_string(), self.ui.get_device_name()) else: log.debug('%s: user %r via %r (%r)', msg['conn'], msg['user'], msg['auth'], msg['key_type']) @@ -46,10 +46,10 @@ def sign_ssh_challenge(self, blob, identity): log.info('please confirm user "%s" login to "%s" using %s...', msg['user'].decode('ascii'), identity.to_string(), - self.device) + self.ui.get_device_name()) - with self.device: - return self.device.sign(blob=blob, identity=identity) + async with self.ui.device() as device: + return await device.sign(blob=blob, identity=identity) def parse_ssh_blob(data): diff --git a/libagent/ssh/protocol.py b/libagent/ssh/protocol.py index 020c8f7b..986a5dda 100644 --- a/libagent/ssh/protocol.py +++ b/libagent/ssh/protocol.py @@ -62,7 +62,7 @@ def failure(): return util.frame(error_msg) -def _legacy_pubs(buf): +async def _legacy_pubs(buf): """SSH v1 public keys are not supported.""" leftover = buf.read() if leftover: @@ -91,7 +91,7 @@ def __init__(self, conn, debug=False): msg_code('SSH_AGENTC_EXTENSION'): _unsupported_extension, } - def handle(self, msg): + async def handle(self, msg): """Handle SSH message from the SSH client and return the response.""" debug_msg = ': {!r}'.format(msg) if self.debug else '' log.debug('request: %d bytes%s', len(msg), debug_msg) @@ -103,15 +103,15 @@ def handle(self, msg): method = self.methods[code] log.debug('calling %s()', method.__name__) - reply = method(buf=buf) + reply = await method(buf=buf) debug_reply = ': {!r}'.format(reply) if self.debug else '' log.debug('reply: %d bytes%s', len(reply), debug_reply) return reply - def list_pubs(self, buf): + async def list_pubs(self, buf): """SSH v2 public keys are serialized and returned.""" assert not buf.read() - keys = self.conn.parse_public_keys() + keys = await self.conn.parse_public_keys() code = util.pack('B', msg_code('SSH2_AGENT_IDENTITIES_ANSWER')) num = util.pack('L', len(keys)) log.debug('available keys: %s', [k['name'] for k in keys]) @@ -120,7 +120,7 @@ def list_pubs(self, buf): pubs = [util.frame(k['blob']) + util.frame(k['name']) for k in keys] return util.frame(code, num, *pubs) - def sign_message(self, buf): + async def sign_message(self, buf): """ SSH v2 public key authentication is performed. @@ -133,7 +133,7 @@ def sign_message(self, buf): assert util.read_frame(buf) == b'' assert not buf.read() - for k in self.conn.parse_public_keys(): + for k in await self.conn.parse_public_keys(): if (k['fingerprint']) == (key['fingerprint']): log.debug('using key %r (%s)', k['name'], k['fingerprint']) key = k @@ -144,7 +144,7 @@ def sign_message(self, buf): label = key['name'].decode('utf-8') log.debug('signing %d-byte blob with "%s" key', len(blob), label) try: - signature = self.conn.sign(blob=blob, identity=key['identity']) + signature = await self.conn.sign(blob=blob, identity=key['identity']) except IOError: return failure() except Exception: @@ -167,6 +167,6 @@ def sign_message(self, buf): return util.frame(code, data) -def _unsupported_extension(buf): # pylint: disable=unused-argument +async def _unsupported_extension(buf): # pylint: disable=unused-argument code = util.pack('B', msg_code('SSH_AGENT_EXTENSION_FAILURE')) return util.frame(code) diff --git a/libagent/ssh/tests/test_client.py b/libagent/ssh/tests/test_client.py index 9c982244..8910f0be 100644 --- a/libagent/ssh/tests/test_client.py +++ b/libagent/ssh/tests/test_client.py @@ -17,6 +17,8 @@ class MockDevice(device.interface.Device): # pylint: disable=abstract-method + fail_sign = False + @classmethod def package_name(cls): return 'fake-device-agent' @@ -30,6 +32,8 @@ def pubkey(self, identity, ecdh=False): # pylint: disable=unused-argument def sign(self, identity, blob): """Sign given blob and return the signature (as bytes).""" + if MockDevice.fail_sign: + raise IOError(42, 'ERROR') assert self.conn assert blob == BLOB return SIG @@ -49,31 +53,29 @@ def sign(self, identity, blob): b'\xdc\xf0H\xab\xa8\xac\xa7? \x8f=C\x88N\xe2') -def test_ssh_agent(): +@pytest.mark.trio +async def test_ssh_agent(): identity = device.interface.Identity(identity_str='localhost:22', curve_name=CURVE) - c = client.Client(device=MockDevice()) - assert c.export_public_keys([identity]) == [PUBKEY_TEXT] - signature = c.sign_ssh_challenge(blob=BLOB, identity=identity) - - key = formats.import_public_key(PUBKEY_TEXT) - serialized_sig = key['verifier'](sig=signature, msg=BLOB) - - stream = io.BytesIO(serialized_sig) - r = util.read_frame(stream) - s = util.read_frame(stream) - assert not stream.read() - assert r[:1] == b'\x00' - assert s[:1] == b'\x00' - assert r[1:] + s[1:] == SIG - - # pylint: disable=unused-argument - def cancel_sign(identity, blob): - raise IOError(42, 'ERROR') - - c.device.sign = cancel_sign - with pytest.raises(IOError): - c.sign_ssh_challenge(blob=BLOB, identity=identity) + async with await device.ui.UI.create(device_type=MockDevice, config={}) as ui: + c = client.Client(ui) + assert await c.export_public_keys([identity]) == [PUBKEY_TEXT] + signature = await c.sign_ssh_challenge(blob=BLOB, identity=identity) + + key = formats.import_public_key(PUBKEY_TEXT) + serialized_sig = key['verifier'](sig=signature, msg=BLOB) + + stream = io.BytesIO(serialized_sig) + r = util.read_frame(stream) + s = util.read_frame(stream) + assert not stream.read() + assert r[:1] == b'\x00' + assert s[:1] == b'\x00' + assert r[1:] + s[1:] == SIG + + MockDevice.fail_sign = True + with pytest.raises(IOError): + await c.sign_ssh_challenge(blob=BLOB, identity=identity) CHALLENGE_BLOB = ( diff --git a/libagent/ssh/tests/test_protocol.py b/libagent/ssh/tests/test_protocol.py index e226b809..19906828 100644 --- a/libagent/ssh/tests/test_protocol.py +++ b/libagent/ssh/tests/test_protocol.py @@ -1,4 +1,3 @@ -import mock import pytest from .. import device, formats, protocol @@ -16,31 +15,41 @@ NIST256_SIGN_REPLY = b'\x00\x00\x00j\x0e\x00\x00\x00e\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00J\x00\x00\x00!\x00\x88G!\x0c\n\x16:\xbeF\xbe\xb9\xd2\xa9&e\x89\xad\xc4}\x10\xf8\xbc\xdc\xef\x0e\x8d_\x8a6.\xb6\x1f\x00\x00\x00!\x00q\xf0\x16>,\x9a\xde\xe7(\xd6\xd7\x93\x1f\xed\xf9\x94ddw\xfe\xbdq\x13\xbb\xfc\xa9K\xea\x9dC\xa1\xe9' # nopep8 -def fake_connection(keys, signer): - c = mock.Mock() - c.parse_public_keys.return_value = keys - c.sign = signer - return c +class FakeConnection: + def __init__(self, keys, signer): + self.keys = keys + self.signer = signer + async def parse_public_keys(self): + return self.keys -def test_list(): + async def sign(self, blob, identity): + if self.signer: + return self.signer(blob=blob, identity=identity) + return b'' + + +@pytest.mark.trio +async def test_list(): key = formats.import_public_key(NIST256_KEY) key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1') - h = protocol.Handler(fake_connection(keys=[key], signer=None)) - reply = h.handle(LIST_MSG) + h = protocol.Handler(FakeConnection(keys=[key], signer=None)) + reply = await h.handle(LIST_MSG) assert reply == LIST_NIST256_REPLY -def test_list_legacy_pubs_with_suffix(): - h = protocol.Handler(fake_connection(keys=[], signer=None)) +@pytest.mark.trio +async def test_list_legacy_pubs_with_suffix(): + h = protocol.Handler(FakeConnection(keys=[], signer=None)) suffix = b'\x00\x00\x00\x06foobar' - reply = h.handle(b'\x01' + suffix) + reply = await h.handle(b'\x01' + suffix) assert reply == b'\x00\x00\x00\x05\x02\x00\x00\x00\x00' # no legacy keys -def test_unsupported(): - h = protocol.Handler(fake_connection(keys=[], signer=None)) - reply = h.handle(b'\x09') +@pytest.mark.trio +async def test_unsupported(): + h = protocol.Handler(FakeConnection(keys=[], signer=None)) + reply = await h.handle(b'\x09') assert reply == b'\x00\x00\x00\x01\x05' @@ -50,21 +59,24 @@ def ecdsa_signer(identity, blob): return NIST256_SIG -def test_ecdsa_sign(): +@pytest.mark.trio +async def test_ecdsa_sign(): key = formats.import_public_key(NIST256_KEY) key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1') - h = protocol.Handler(fake_connection(keys=[key], signer=ecdsa_signer)) - reply = h.handle(NIST256_SIGN_MSG) + h = protocol.Handler(FakeConnection(keys=[key], signer=ecdsa_signer)) + reply = await h.handle(NIST256_SIGN_MSG) assert reply == NIST256_SIGN_REPLY -def test_sign_missing(): - h = protocol.Handler(fake_connection(keys=[], signer=ecdsa_signer)) +@pytest.mark.trio +async def test_sign_missing(): + h = protocol.Handler(FakeConnection(keys=[], signer=ecdsa_signer)) with pytest.raises(KeyError): - h.handle(NIST256_SIGN_MSG) + await h.handle(NIST256_SIGN_MSG) -def test_sign_wrong(): +@pytest.mark.trio +async def test_sign_wrong(): def wrong_signature(identity, blob): assert identity.to_string() == '' assert blob == NIST256_BLOB @@ -72,19 +84,20 @@ def wrong_signature(identity, blob): key = formats.import_public_key(NIST256_KEY) key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1') - h = protocol.Handler(fake_connection(keys=[key], signer=wrong_signature)) + h = protocol.Handler(FakeConnection(keys=[key], signer=wrong_signature)) with pytest.raises(ValueError): - h.handle(NIST256_SIGN_MSG) + await h.handle(NIST256_SIGN_MSG) -def test_sign_cancel(): +@pytest.mark.trio +async def test_sign_cancel(): def cancel_signature(identity, blob): # pylint: disable=unused-argument raise IOError() key = formats.import_public_key(NIST256_KEY) key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1') - h = protocol.Handler(fake_connection(keys=[key], signer=cancel_signature)) - assert h.handle(NIST256_SIGN_MSG) == protocol.failure() + h = protocol.Handler(FakeConnection(keys=[key], signer=cancel_signature)) + assert await h.handle(NIST256_SIGN_MSG) == protocol.failure() ED25519_KEY = 'ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFBdF2tjfSO8nLIi736is+f0erq28RTc7CkM11NZtTKR ssh://localhost' # nopep8 @@ -101,9 +114,10 @@ def ed25519_signer(identity, blob): return ED25519_SIG -def test_ed25519_sign(): +@pytest.mark.trio +async def test_ed25519_sign(): key = formats.import_public_key(ED25519_KEY) key['identity'] = device.interface.Identity('ssh://localhost', 'ed25519') - h = protocol.Handler(fake_connection(keys=[key], signer=ed25519_signer)) - reply = h.handle(ED25519_SIGN_MSG) + h = protocol.Handler(FakeConnection(keys=[key], signer=ed25519_signer)) + reply = await h.handle(ED25519_SIGN_MSG) assert reply == ED25519_SIGN_REPLY diff --git a/libagent/tests/test_server.py b/libagent/tests/test_server.py index 947160a8..243e3cce 100644 --- a/libagent/tests/test_server.py +++ b/libagent/tests/test_server.py @@ -1,135 +1,147 @@ +import functools import io import os -import socket import tempfile -import threading -import mock import pytest +import trio from .. import server, util from ..ssh import protocol -def test_socket(): +@pytest.mark.trio +async def test_socket(): path = tempfile.mktemp() - with server.unix_domain_socket_server(path): + async with server.unix_domain_socket_server(path): pass assert not os.path.isfile(path) class FakeSocket: - def __init__(self, data=b''): + def __init__(self, data=b'', recv_raises=None): self.rx = io.BytesIO(data) self.tx = io.BytesIO() + self.recv_raises = recv_raises - def sendall(self, data): + def __enter__(self): + return self + + def __exit__(self, *_): + self.close() + + async def send(self, data): self.tx.write(data) + return len(data) - def recv(self, size): + async def recv(self, size): + if self.recv_raises: + toraise = self.recv_raises[0] + self.recv_raises = self.recv_raises[1:] + raise toraise return self.rx.read(size) def close(self): pass - def settimeout(self, value): - pass - - -def empty_device(): - c = mock.Mock(spec=['parse_public_keys']) - c.parse_public_keys.return_value = [] - return c +# pylint: disable=too-few-public-methods +class EmptyDevice: + async def parse_public_keys(self): + return [] -def test_handle(): - mutex = threading.Lock() - handler = protocol.Handler(conn=empty_device()) +@pytest.mark.trio +async def test_handle(): + handler = protocol.Handler(conn=EmptyDevice()) conn = FakeSocket() - server.handle_connection(conn, handler, mutex) + await server.handle_connection(conn, handler) msg = bytearray([protocol.msg_code('SSH_AGENTC_REQUEST_RSA_IDENTITIES')]) conn = FakeSocket(util.frame(msg)) - server.handle_connection(conn, handler, mutex) + await server.handle_connection(conn, handler) assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x02\x00\x00\x00\x00' msg = bytearray([protocol.msg_code('SSH2_AGENTC_REQUEST_IDENTITIES')]) conn = FakeSocket(util.frame(msg)) - server.handle_connection(conn, handler, mutex) + await server.handle_connection(conn, handler) assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x0C\x00\x00\x00\x00' msg = bytearray([protocol.msg_code('SSH2_AGENTC_ADD_IDENTITY')]) conn = FakeSocket(util.frame(msg)) - server.handle_connection(conn, handler, mutex) + await server.handle_connection(conn, handler) conn.tx.seek(0) reply = util.read_frame(conn.tx) assert reply == util.pack('B', protocol.msg_code('SSH_AGENT_FAILURE')) - conn_mock = mock.Mock(spec=FakeSocket) - conn_mock.recv.side_effect = [Exception, EOFError] - server.handle_connection(conn=conn_mock, handler=None, mutex=mutex) + conn = FakeSocket(recv_raises=[Exception(), EOFError()]) + await server.handle_connection(conn=conn, handler=None) -def test_server_thread(): +@pytest.mark.trio +async def test_server_thread(): sock = FakeSocket() connections = [sock] - quit_event = threading.Event() + quit_event = trio.Event() class FakeServer: - def accept(self): + async def accept(self): if not connections: - raise socket.timeout() + await trio.sleep_forever() return connections.pop(), 'address' def getsockname(self): return 'fake_server' - def handle_conn(conn): + async def handle_conn(conn): assert conn is sock quit_event.set() - server.server_thread(sock=FakeServer(), - handle_conn=handle_conn, - quit_event=quit_event) - quit_event.wait() - + await server.server_thread(sock=FakeServer(), + handle_conn=handle_conn, + quit_event=quit_event) -def test_spawn(): - obj = [] - def thread(x): - obj.append(x) +@pytest.mark.trio +async def test_run(): + assert await server.run_process(['true'], environ={}) == 0 + assert await server.run_process(['false'], environ={}) == 1 + assert await server.run_process(command=['bash', '-c', 'exit $X'], + environ={'X': '42'}) == 42 - with server.spawn(thread, {'x': 1}): - pass - - assert obj == [1] + with pytest.raises(OSError): + await server.run_process([''], environ={}) -def test_run(): - assert server.run_process(['true'], environ={}) == 0 - assert server.run_process(['false'], environ={}) == 1 - assert server.run_process(command=['bash', '-c', 'exit $X'], - environ={'X': '42'}) == 42 +@pytest.mark.trio +async def test_remove(): + path = 'foo.bar' + paths = set() + force_exists_paths = set() - with pytest.raises(OSError): - server.run_process([''], environ={}) + class FakePath: + def __init__(self, paths, force_exists_paths, path): + self.path = path + self.paths = paths + self.force_exists_paths = force_exists_paths + async def unlink(self): + if self.path not in self.paths: + raise OSError('boom') + self.paths.remove(self.path) -def test_remove(): - path = 'foo.bar' + async def exists(self): + return self.path in self.paths or self.path in self.force_exists_paths - def remove(p): - assert p == path + fake_path = functools.partial(FakePath, paths, force_exists_paths) + paths.add(path) - server.remove_file(path, remove=remove) + await server.remove_file(path, trio_path=fake_path) - def remove_raise(_): - raise OSError('boom') + await server.remove_file(path, trio_path=fake_path) - server.remove_file(path, remove=remove_raise, exists=lambda _: False) + force_exists_paths.add(path) with pytest.raises(OSError): - server.remove_file(path, remove=remove_raise, exists=lambda _: True) + await server.remove_file(path, trio_path=fake_path) diff --git a/libagent/tests/test_util.py b/libagent/tests/test_util.py index e3135e84..88b8dc34 100644 --- a/libagent/tests/test_util.py +++ b/libagent/tests/test_util.py @@ -29,24 +29,27 @@ class FakeSocket: def __init__(self): self.buf = io.BytesIO() - def sendall(self, data): + async def send(self, data): self.buf.write(data) + return len(data) - def recv(self, size): + async def recv(self, size): return self.buf.read(size) -def test_send_recv(): +@pytest.mark.trio +async def test_send_recv(): s = FakeSocket() - util.send(s, b'123') - util.send(s, b'*') + await util.send(s, b'123') + await util.send(s, b'*') assert s.buf.getvalue() == b'123*' s.buf.seek(0) - assert util.recv(s, 2) == b'12' - assert util.recv(s, 2) == b'3*' + assert await util.recv_async(s, 2) == b'12' + assert await util.recv_async(s, 2) == b'3*' - pytest.raises(EOFError, util.recv, s, 1) + with pytest.raises(EOFError): + await util.recv_async(s, 1) def test_crc24(): @@ -104,16 +107,17 @@ def test_setup_logging(): util.setup_logging(verbosity=10, filename='/dev/null') -def test_memoize(): +@pytest.mark.trio +async def test_memoize(): f = mock.Mock(side_effect=lambda x: x) - def func(x): + @util.memoize + async def func(x): # mock.Mock doesn't work with functools.wraps() return f(x) - g = util.memoize(func) - assert g(1) == g(1) - assert g(1) != g(2) + assert await func(1) == await func(1) + assert await func(1) != await func(2) assert f.mock_calls == [mock.call(1), mock.call(2)] @@ -125,22 +129,23 @@ def test_assuan_serialize(): def test_cache(): timer = mock.Mock(side_effect=range(7)) - c = util.ExpiringCache(seconds=2, timer=timer) # t=0 - assert c.get() is None # t=1 + c = util.ExpiringCache(seconds=2, timer=timer) + c.set('not_the_key', 'unused') # t=0 + assert c.get('key') is None # t=1 obj = 'foo' - c.set(obj) # t=2 - assert c.get() is obj # t=3 - assert c.get() is obj # t=4 - assert c.get() is None # t=5 - assert c.get() is None # t=6 + c.set('key', obj) # t=2 + assert c.get('key') is obj # t=3 + assert c.get('key') is obj # t=4 + assert c.get('key') is None # t=5 + assert c.get('key') is None # t=6 def test_cache_inf(): timer = mock.Mock(side_effect=range(6)) c = util.ExpiringCache(seconds=float('inf'), timer=timer) obj = 'foo' - c.set(obj) - assert c.get() is obj - assert c.get() is obj - assert c.get() is obj - assert c.get() is obj + c.set('key', obj) + assert c.get('key') is obj + assert c.get('key') is obj + assert c.get('key') is obj + assert c.get('key') is obj diff --git a/libagent/util.py b/libagent/util.py index 96ccad2f..23cbf2e7 100644 --- a/libagent/util.py +++ b/libagent/util.py @@ -6,19 +6,56 @@ import logging import struct import sys -import time +import threading + +import trio log = logging.getLogger(__name__) -def send(conn, data): +async def send(conn, data): """Send data blob to connection socket.""" - conn.sendall(data) + while len(data) > 0: + sent = await conn.send(data) + if not sent: + raise IOError('Socket refused data') + data = data[sent:] + + +async def recv_async(conn, size): + """ + Receive bytes from connection socket. + + If size is struct.calcsize()-compatible format, use it to unpack the data. + Otherwise, return the plain blob as bytes. + """ + try: + fmt = size + size = struct.calcsize(fmt) + except TypeError: + fmt = None + try: + _read = conn.recv + except AttributeError: + _read = conn.read + + res = io.BytesIO() + while size > 0: + buf = await _read(size) + if not buf: + raise EOFError + size = size - len(buf) + res.write(buf) + res = res.getvalue() + if fmt: + return struct.unpack(fmt, res) + else: + return res def recv(conn, size): """ - Receive bytes from connection socket or stream. + Receive bytes from in-memory stream. If size is struct.calcsize()-compatible format, use it to unpack the data. Otherwise, return the plain blob as bytes. @@ -47,8 +84,14 @@ def recv(conn, size): return res -def read_frame(conn): +async def read_frame_async(conn): """Read size-prefixed frame from connection.""" + size, = await recv_async(conn, '>L') + return await recv_async(conn, size) + + +def read_frame(conn): + """Read size-prefixed frame from in-memory stream.""" size, = recv(conn, '>L') return recv(conn, size) @@ -204,13 +247,13 @@ def memoize(func): cache = {} @functools.wraps(func) - def wrapper(*args, **kwargs): + async def wrapper(*args, **kwargs): """Caching wrapper.""" key = (args, tuple(sorted(kwargs.items()))) if key in cache: return cache[key] else: - result = func(*args, **kwargs) + result = await func(*args, **kwargs) cache[key] = result return result @@ -222,13 +265,13 @@ def memoize_method(method): cache = {} @functools.wraps(method) - def wrapper(self, *args, **kwargs): + async def wrapper(self, *args, **kwargs): """Caching wrapper.""" key = (args, tuple(sorted(kwargs.items()))) if key in cache: return cache[key] else: - result = method(self, *args, **kwargs) + result = await method(self, *args, **kwargs) cache[key] = result return result @@ -236,7 +279,7 @@ def wrapper(self, *args, **kwargs): @memoize -def which(cmd): +async def which(cmd): """Return full path to specified command, or raise OSError if missing.""" try: # For Python 3 @@ -244,7 +287,7 @@ def which(cmd): except ImportError: # For Python 2 from backports.shutil_which import which as _which - full_path = _which(cmd) + full_path = await trio.to_thread.run_sync(_which, cmd) if full_path is None: raise OSError('Cannot find {!r} in $PATH'.format(cmd)) log.debug('which %r => %r', cmd, full_path) @@ -286,20 +329,142 @@ def escape_cmd_win(in_str): class ExpiringCache: """Simple cache with a deadline.""" - def __init__(self, seconds, timer=time.time): + def __init__(self, seconds, timer=trio.current_time): """C-tor.""" self.duration = seconds self.timer = timer - self.value = None - self.set(None) + self.values = {} - def get(self): + def get(self, key): """Returns existing value, or None if deadline has expired.""" - if self.timer() > self.deadline: - self.value = None - return self.value + curtime = self.timer() + self.values = {k: v for k, v in self.values.items() if curtime <= v[0]} + return self.values.get(key, (None, None))[1] - def set(self, value): + def set(self, key, value): """Set new value and reset the deadline for expiration.""" - self.deadline = self.timer() + self.duration - self.value = value + self.values[key] = ( + self.timer() + self.duration, + value + ) + + +@contextlib.asynccontextmanager +async def run_on_thread(): + """Allows running blocking commands from asynchronous context on a single thread.""" + # pylint: disable=too-many-statements + command_condition = threading.Condition() + command_value = () + command_in_progress = False + thread_is_running = True + + def before_resolve(): + nonlocal command_condition, command_in_progress + with command_condition: + command_in_progress = False + + async def run_command(command, *args, **kwargs): + nonlocal command_condition, command_value, thread_is_running + assert thread_is_running + res = _ResultFromThread(before_resolve) + with command_condition: + assert not command_value + command_value = (res, command, args, kwargs) + command_condition.notify() + return await res.wait() + + async def run_command_immediate(command, *args, **kwargs): + nonlocal command_condition, command_value, command_in_progress, thread_is_running + assert thread_is_running + res = _ResultFromThread(before_resolve) + bypass_thread = False + with command_condition: + if not command_value and not command_in_progress: + command_value = (res, command, args, kwargs) + command_condition.notify() + else: + bypass_thread = True + if bypass_thread: + def run_func(): + nonlocal command, args, kwargs + command(*args, **kwargs) + return await trio.to_thread.run_sync(run_func) + else: + return await res.wait() + + def thread_func(): + nonlocal command_condition, command_value, command_in_progress + while True: + with command_condition: + while not command_value: + command_condition.wait() + command = command_value + command_value = () + command_in_progress = True + res, func, args, kwargs = command + if res is None: + break + with res: + res.resolve(func(*args, **kwargs)) + + async with trio.open_nursery() as nursery: + nursery.start_soon(trio.to_thread.run_sync, thread_func) + try: + yield run_command, run_command_immediate + finally: + thread_is_running = False + with command_condition: + # Abort any in-flight commands, as closing the thread is the highest priority + if command_value: + res, = command_value + if res is not None: + res.reject(trio.Cancelled('Thread closed')) + command_value = (None, None, None, None) + command_condition.notify() + + +class _ResultFromThread: + def __init__(self, before_resolve): + self.event = trio.Event() + self.retval = None + self.retiserr = False + self.done = False + self.before_resolve = before_resolve + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type: + trio.from_thread.run_sync(self.reject, exc_val) + else: + # Last chance to resolve + trio.from_thread.run_sync(self._resolve, None) + return True + + def resolve(self, retval): + trio.from_thread.run_sync(self._resolve, retval) + + def _resolve(self, retval): + if self.done: + return + self.done = True + self.before_resolve() + self.retval = retval + self.retiserr = False + self.event.set() + + def reject(self, retval): + if self.done: + return + self.done = True + self.before_resolve() + self.retval = retval + self.retiserr = True + self.event.set() + + async def wait(self): + await self.event.wait() + if self.retiserr: + raise self.retval + return self.retval diff --git a/libagent/win_server.py b/libagent/win_server.py index c0594029..bf0bedd1 100644 --- a/libagent/win_server.py +++ b/libagent/win_server.py @@ -1,10 +1,10 @@ """Windows named pipe server simulating a UNIX socket.""" -import contextlib -import ctypes import io import os -import socket +import trio +import trio.lowlevel +import trio.socket import win32api import win32event import win32file @@ -13,31 +13,11 @@ from . import util -kernel32 = ctypes.WinDLL('kernel32', use_last_error=True) - PIPE_BUFFER_SIZE = 64 * 1024 CTRL_C_EVENT = 0 THREAD_SET_CONTEXT = 0x0010 -# Workaround for Ctrl+C not stopping IO on Windows -# See https://github.com/python/cpython/issues/85609 -@contextlib.contextmanager -def ctrl_cancel_async_io(file_handle): - """Listen for SIGINT and translate it to interrupting IO on the specified file handle.""" - @ctypes.WINFUNCTYPE(ctypes.c_uint, ctypes.c_uint) - def ctrl_handler(ctrl_event): - if ctrl_event == CTRL_C_EVENT: - kernel32.CancelIoEx(file_handle, None) - return False - - try: - kernel32.SetConsoleCtrlHandler(ctrl_handler, True) - yield - finally: - kernel32.SetConsoleCtrlHandler(ctrl_handler, False) - - # Based loosely on https://docs.microsoft.com/en-us/windows/win32/ipc/multithreaded-pipe-server class NamedPipe: """A Windows named pipe. @@ -46,8 +26,8 @@ class NamedPipe: or as a client connecting to a listener. """ - @staticmethod - def __close(handle, disconnect): + @classmethod + def __close(cls, handle, disconnect): """Closes a named pipe handle.""" if handle == win32file.INVALID_HANDLE_VALUE: return @@ -56,8 +36,8 @@ def __close(handle, disconnect): win32pipe.DisconnectNamedPipe(handle) win32api.CloseHandle(handle) - @staticmethod - def create(name): + @classmethod + def create(cls, name): """Opens a named pipe server for receiving connections.""" handle = win32pipe.CreateNamedPipe( name, @@ -83,14 +63,14 @@ def create(name): win32event.SetEvent(overlapped.hEvent) if error_code != winerror.ERROR_PIPE_CONNECTED: raise IOError('ConnectNamedPipe failed ({0})'.format(error_code)) - ret = NamedPipe(name, handle, overlapped, pending_io, True) + ret = cls(name, handle, overlapped, pending_io, True) handle = win32file.INVALID_HANDLE_VALUE return ret finally: NamedPipe.__close(handle, True) - @staticmethod - def open(name): + @classmethod + def open(cls, name): """Opens a named pipe server for receiving connections.""" handle = win32file.CreateFile( name, @@ -108,12 +88,20 @@ def open(name): overlapped = win32file.OVERLAPPED() overlapped.hEvent = win32event.CreateEvent(None, True, True, None) win32pipe.SetNamedPipeHandleState(handle, win32pipe.PIPE_READMODE_BYTE, None, None) - ret = NamedPipe(name, handle, overlapped, False, False) + ret = cls(name, handle, overlapped, False, False) handle = win32file.INVALID_HANDLE_VALUE return ret finally: NamedPipe.__close(handle, False) + def __enter__(self): + """Context manager support.""" + return self + + def __exit__(self, *_): + """Context manager support.""" + self.close() + def __init__(self, name, handle, overlapped, pending_io, created): """Should not be called directly. @@ -125,28 +113,19 @@ def __init__(self, name, handle, overlapped, pending_io, created): self.overlapped = overlapped self.pending_io = pending_io self.created = created - self.retain_buf = bytes() - self.timeout = win32event.INFINITE def __del__(self): """Close the named pipe.""" self.close() - def settimeout(self, timeout): - """Sets the timeout for IO operations on the named pipe in milliseconds.""" - self.timeout = win32event.INFINITE if timeout is None else int(timeout * 1000) - def close(self): """Close the named pipe.""" NamedPipe.__close(self.handle, self.created) self.handle = win32file.INVALID_HANDLE_VALUE - def connect(self): - """Connect to a named pipe with the specified timeout.""" - with ctrl_cancel_async_io(self.handle): - waitHandle = win32event.WaitForSingleObject(self.overlapped.hEvent, self.timeout) - if waitHandle == win32event.WAIT_TIMEOUT: - raise TimeoutError('Timed out waiting for client on pipe {0}'.format(self.name)) + async def connect(self): + """Connect to a named pipe.""" + await trio.lowlevel.WaitForSingleObject(int(self.overlapped.hEvent)) if not self.pending_io: return win32pipe.GetOverlappedResult( @@ -158,7 +137,7 @@ def connect(self): return raise IOError('Connection to named pipe {0} failed ({1})'.format(self.name, error_code)) - def recv(self, size): + async def recv(self, size): """Read data from the pipe.""" rbuf = win32file.AllocateReadBuffer(min(size, PIPE_BUFFER_SIZE)) try: @@ -171,8 +150,7 @@ def recv(self, size): if e.winerror == winerror.ERROR_NO_DATA: return None raise - with ctrl_cancel_async_io(self.handle): - win32event.WaitForSingleObject(self.overlapped.hEvent, self.timeout) + await trio.lowlevel.WaitForSingleObject(int(self.overlapped.hEvent)) try: chunk_size = win32pipe.GetOverlappedResult(self.handle, self.overlapped, False) error_code = win32api.GetLastError() @@ -184,67 +162,20 @@ def recv(self, size): return None raise - def send(self, data): + async def send(self, data): """Write from the specified buffer to the pipe.""" error_code, _ = win32file.WriteFile(self.handle, data, self.overlapped) if error_code not in (winerror.NO_ERROR, winerror.ERROR_IO_PENDING, winerror.ERROR_MORE_DATA): raise IOError('WriteFile failed ({0})'.format(error_code)) - with ctrl_cancel_async_io(self.handle): - win32event.WaitForSingleObject(self.overlapped.hEvent, self.timeout) + await trio.lowlevel.WaitForSingleObject(int(self.overlapped.hEvent)) written = win32pipe.GetOverlappedResult(self.handle, self.overlapped, False) error_code = win32api.GetLastError() if error_code != winerror.NO_ERROR: raise IOError('WriteFile failed ({0})'.format(error_code)) return written - def sendall(self, data): - """Send the specified reply to the pipe.""" - while len(data) > 0: - written = self.send(data) - data = data[written:] - - -class InterruptibleSocket: - """A wrapper for sockets which allows IO operations to be interrupted by SIGINT.""" - - def __init__(self, sock): - """Wraps the socket object ``sock``.""" - self.sock = sock - - def __del__(self): - """Close the wrapped socket. It should not outlive the wrapper.""" - self.close() - - def settimeout(self, timeout): - """Forward to underlying socket.""" - self.sock.settimeout(timeout) - - def recv(self, size): - """Forward to underlying socket, while monitoring for SIGINT.""" - try: - with ctrl_cancel_async_io(self.sock.fileno()): - return self.sock.recv(size) - except OSError as e: - if e.winerror == 10054: - # Convert socket close to end of file - return None - raise - - def sendall(self, reply): - """Forward to underlying socket, while monitoring for SIGINT.""" - with ctrl_cancel_async_io(self.sock.fileno()): - return self.sock.sendall(reply) - - def close(self): - """Forward to underlying socket.""" - return self.sock.close() - - def getsockname(self): - """Forward to underlying socket.""" - return self.sock.getsockname() - class Server: """Listend on an emulated AF_UNIX socket on Windows. @@ -252,7 +183,8 @@ class Server: Supports both Gpg4win-style AF_UNIX emulation and OpenSSH-style AF_UNIX emulation """ - def __init__(self, pipe_name): + @classmethod + async def open(cls, pipe_name): """Opens a socket or named pipe. If ``pipe_name`` is a byte string, it is interpreted as a Gpg4win-style socket. @@ -263,44 +195,58 @@ def __init__(self, pipe_name): If it is a string, it is interpreted as an OpenSSH-style socket. The string contains the name of a Windows named pipe. """ - self.timeout = None - self.pipe_name = pipe_name - self.sock = None - self.pipe = None - if not isinstance(self.pipe_name, str): - # GPG simulated socket via localhost socket - self.key = os.urandom(16) - self.sock = socket.socket() - self.sock.bind(('127.0.0.1', 0)) - _, port = self.sock.getsockname() - self.sock.listen(1) + if isinstance(pipe_name, str): + return Server(pipe_name, None, None) + # GPG simulated socket via localhost socket + key = os.urandom(16) + sock_close = sock = trio.socket.socket() + try: + await sock.bind(('127.0.0.1', 0)) + _, port = sock.getsockname() + sock.listen(1) # Write key to file - with open(self.pipe_name, 'wb') as f: - with ctrl_cancel_async_io(f.fileno()): - f.write(str(port).encode()) - f.write(b'\n') - f.write(self.key) + async with await trio.open_file(pipe_name, 'wb') as f: + await f.write(str(port).encode()) + await f.write(b'\n') + await f.write(key) + sock_close = None + return Server(pipe_name, sock, key) + finally: + if sock_close: + sock_close.close() + + def __enter__(self): + """Context manager support.""" + return self + + def __exit__(self, *_): + """Context manager support.""" + self.close() + + def __init__(self, pipe_name, sock, key): + """Should not be called directly. + + Use ``Server.open`` instead. + """ + self.pipe_name = pipe_name + self.sock = sock + self.key = key def __del__(self): """Close the underlying socket or pipe.""" - if self.pipe is not None: - self.pipe.close() - self.pipe = None + self.close() + + def close(self): + """Close the underlying socket or pipe.""" if self.sock is not None: self.sock.close() self.sock = None - def settimeout(self, timeout): - """Set the timeout in seconds.""" - if self.sock: - self.sock.settimeout(timeout) - self.timeout = timeout - def getsockname(self): """Return the file path or pipe name used for creating this named pipe.""" return self.pipe_name - def accept(self): + async def accept(self, retry_invalid_client=True): """Listens for incoming connections on the socket. Returns a pair ``(pipe, address)`` where ``pipe`` is a connected socket-like object @@ -309,28 +255,19 @@ def accept(self): When a named pipe is used, the client's address is the same as the pipe name. """ if self.sock: - with ctrl_cancel_async_io(self.sock.fileno()): - sock, addr = self.sock.accept() - sock = InterruptibleSocket(sock) - sock.settimeout(self.timeout) - if self.key != util.recv(sock, 16): + while True: + sock, addr = await self.sock.accept() + if self.key == await util.recv_async(sock, 16): + break sock.close() - # Simulate timeout on failed connection to allow the caller to retry - raise TimeoutError('Illegitimate client tried to connect to pipe {0}' - .format(self.pipe_name)) - sock.settimeout(None) + if not retry_invalid_client: + raise IOError('Illegitimate client tried to connect to pipe {0}' + .format(self.pipe_name)) return (sock, addr) else: # Named pipe based server - if self.pipe is None: - self.pipe = NamedPipe.create(self.pipe_name) - self.pipe.settimeout(self.timeout) - self.pipe.connect() - self.pipe.settimeout(None) - # A named pipe can only accept a single connection - # It must be recreated if a new connection is to be made - pipe = self.pipe - self.pipe = None + pipe = NamedPipe.create(self.pipe_name) + await pipe.connect() return (pipe, self.pipe_name) @@ -340,7 +277,8 @@ class Client: Supports both Gpg4win-style AF_UNIX emulation and OpenSSH-style AF_UNIX emulation """ - def __init__(self, pipe_name): + @classmethod + async def open(cls, pipe_name): """Connects to a socket or named pipe. If ``pipe_name`` is a byte string, it is interpreted as a Gpg4win-style socket. @@ -350,46 +288,67 @@ def __init__(self, pipe_name): If it is a string, it is interpreted as an OpenSSH-style socket. The string contains the name of a Windows named pipe. """ + if isinstance(pipe_name, str): + return Client(pipe_name, None, NamedPipe.open(pipe_name)) + # Read key from file + async with await trio.open_file(pipe_name, 'rb') as f: + port = io.BytesIO() + while True: + c = await f.read(1) + if not c: + raise OSError('Could not read port for socket {0}'.format(pipe_name)) + if c == b'\n': + break + if c < b'0' or c > b'9': + raise OSError('Could not read port for socket {0}'.format(pipe_name)) + port.write(c) + port = int(port.getvalue()) + key_len = 0 + key = io.BytesIO() + while key_len < 16: + c = await f.read(16-key_len) + if not c: + raise OSError('Could not read nonce for socket {0}'.format(pipe_name)) + key.write(c) + key_len += len(c) + key = key.getvalue() + # Verify end of file + c = await f.read(1) + if c: + raise OSError('Corrupt socket {0}'.format(pipe_name)) + # GPG simulated socket via localhost socket + sock_close = sock = trio.socket.socket() + try: + await sock.connect(('127.0.0.1', port)) + await util.send(sock, key) + sock_close = None + return Client(pipe_name, sock, None) + finally: + if sock_close: + sock_close.close() + + def __enter__(self): + """Context manager support.""" + return self + + def __exit__(self, *_): + """Context manager support.""" + self.close() + + def __init__(self, pipe_name, sock, pipe): + """Should not be called directly. + + Use ``Client.open`` instead. + """ self.pipe_name = pipe_name - self.sock = None - self.pipe = None - if not isinstance(self.pipe_name, str): - # Read key from file - with open(self.pipe_name, 'rb') as f: - with ctrl_cancel_async_io(f.fileno()): - port = io.BytesIO() - while True: - c = f.read(1) - if not c: - raise OSError('Could not read port for socket {0}'.format(pipe_name)) - if c == b'\n': - break - if c < b'0' or c > b'9': - raise OSError('Could not read port for socket {0}'.format(pipe_name)) - port.write(c) - port = int(port.getvalue()) - key_len = 0 - key = io.BytesIO() - while key: - c = f.read(16-key_len) - if not c: - raise OSError('Could not read nonce for socket {0}'.format(pipe_name)) - key.write(c) - key_len += len(c) - key = key.getvalue() - # Verify end of file - c = f.read(1) - if c: - raise OSError('Corrupt socket {0}'.format(pipe_name)) - # GPG simulated socket via localhost socket - sock = socket.socket() - sock.connect(('127.0.0.1', port)) - self.sock = InterruptibleSocket(sock) - self.sock.sendall(key) - else: - self.pipe = NamedPipe.open(pipe_name) + self.sock = sock + self.pipe = pipe def __del__(self): + """Close the underlying socket or named pipe.""" + self.close() + + def close(self): """Close the underlying socket or named pipe.""" if self.pipe is not None: self.pipe.close() @@ -398,25 +357,18 @@ def __del__(self): self.sock.close() self.sock = None - def settimeout(self, timeout): - """Forward to underlying socket or named pipe.""" - if self.sock: - self.sock.settimeout(timeout) - if self.pipe: - self.pipe.settimeout(timeout) - def getsockname(self): """Return the file path or pipe name used for connecting to this named pipe.""" return self.pipe_name - def recv(self, size): + async def recv(self, size): """Forward to underlying socket or named pipe.""" if self.sock is not None: - return self.sock.recv(size) - return self.pipe.recv(size) + return await self.sock.recv(size) + return await self.pipe.recv(size) - def sendall(self, reply): + async def send(self, reply): """Forward to underlying socket or named pipe.""" if self.sock is not None: - return self.sock.sendall(reply) - return self.pipe.sendall(reply) + return await self.sock.send(reply) + return await self.pipe.send(reply) diff --git a/setup.py b/setup.py index aefdcdc1..585ca52a 100755 --- a/setup.py +++ b/setup.py @@ -31,6 +31,8 @@ 'pymsgbox>=1.0.6', 'semver>=2.2', 'unidecode>=0.4.20', + 'trio>=0.22.2', + 'trio-util>=0.7.0', 'pywin32>=300;sys_platform=="win32"' ], platforms=['POSIX', 'win32'], diff --git a/tox.ini b/tox.ini index ca6f3c2d..594b3ca0 100644 --- a/tox.ini +++ b/tox.ini @@ -10,6 +10,7 @@ platform = win: win32 deps= pytest + pytest-trio mock pycodestyle coverage @@ -17,6 +18,8 @@ deps= semver pydocstyle isort + trio + trio-util pywin32;sys_platform=="win32" commands= pycodestyle libagent