From 3f9a6e48f84ea693350e183004d944e7ea56cb87 Mon Sep 17 00:00:00 2001 From: Daniel Bluhm Date: Fri, 19 Apr 2024 15:35:09 -0400 Subject: [PATCH] refactor: make services more flexible By passing in all dependencies to each method that requires it rather than capturing it in the instance state. This should better support more use cases than the previous structure. This is a backwards compatible change; the more flexible services are wrapped by the `DIDCommMessaging` class to behave exactly the same way as before. To use the new, more flexible services, use `DIDCommMessagingService` class. Signed-off-by: Daniel Bluhm --- didcomm_messaging/__init__.py | 113 ++++++++++++++++++++++++++------- didcomm_messaging/packaging.py | 67 ++++++++++--------- didcomm_messaging/routing.py | 53 ++++++++++------ tests/test_packaging.py | 20 ++++-- 4 files changed, 171 insertions(+), 82 deletions(-) diff --git a/didcomm_messaging/__init__.py b/didcomm_messaging/__init__.py index 418e675..df4a947 100644 --- a/didcomm_messaging/__init__.py +++ b/didcomm_messaging/__init__.py @@ -46,24 +46,9 @@ class UnpackResult: sender_kid: Optional[str] = None -class DIDCommMessaging(Generic[P, S]): +class DIDCommMessagingService(Generic[P, S]): """Main entrypoint for DIDComm Messaging.""" - def __init__( - self, - crypto: CryptoService[P, S], - secrets: SecretsManager[S], - resolver: DIDResolver, - packaging: PackagingService[P, S], - routing: RoutingService, - ): - """Initialize the DIDComm Messaging service.""" - self.crypto = crypto - self.secrets = secrets - self.resolver = resolver - self.packaging = packaging - self.routing = routing - def service_to_target(self, service: DIDCommV2Service) -> str: """Convert a service to a target uri. @@ -76,20 +61,49 @@ def service_to_target(self, service: DIDCommV2Service) -> str: return service_endpoint.uri - async def pack(self, message: dict, to: str, frm: Optional[str] = None, **options): + async def pack( + self, + crypto: CryptoService[P, S], + resolver: DIDResolver, + secrets: SecretsManager[S], + packaging: PackagingService[P, S], + routing: RoutingService, + message: dict, + to: str, + frm: Optional[str] = None, + **options, + ): """Pack a message.""" # TODO crypto layer permits packing to multiple recipients; should we as well? - encoded_message = await self.packaging.pack( - json.dumps(message).encode(), [to], frm, **options + encoded_message = await packaging.pack( + crypto, + resolver, + secrets, + json.dumps(message).encode(), + [to], + frm, + **options, ) - forward, services = await self.routing.prepare_forward(to, encoded_message) + forward, services = await routing.prepare_forward( + crypto, packaging, resolver, secrets, to, encoded_message + ) return PackResult(forward, services) - async def unpack(self, encoded_message: bytes, **options) -> UnpackResult: + async def unpack( + self, + crypto: CryptoService[P, S], + resolver: DIDResolver, + secrets: SecretsManager[S], + packaging: PackagingService[P, S], + encoded_message: bytes, + **options, + ) -> UnpackResult: """Unpack a message.""" - unpacked, metadata = await self.packaging.unpack(encoded_message, **options) + unpacked, metadata = await packaging.unpack( + crypto, resolver, secrets, encoded_message, **options + ) message = json.loads(unpacked.decode()) return UnpackResult( message, @@ -98,3 +112,58 @@ async def unpack(self, encoded_message: bytes, **options) -> UnpackResult: recipient_kid=metadata.recip_key.kid, sender_kid=metadata.sender_kid, ) + + +class DIDCommMessaging(Generic[P, S]): + """Main entrypoint for DIDComm Messaging.""" + + def __init__( + self, + crypto: CryptoService[P, S], + secrets: SecretsManager[S], + resolver: DIDResolver, + packaging: PackagingService[P, S], + routing: RoutingService, + ): + """Initialize the DIDComm Messaging service.""" + self.crypto = crypto + self.secrets = secrets + self.resolver = resolver + self.packaging = packaging + self.routing = routing + self.dmp = DIDCommMessagingService() + + async def pack( + self, + message: dict, + to: str, + frm: Optional[str] = None, + **options, + ) -> PackResult: + """Pack a message.""" + return await self.dmp.pack( + self.crypto, + self.resolver, + self.secrets, + self.packaging, + self.routing, + message, + to, + frm, + **options, + ) + + async def unpack( + self, + encoded_message: bytes, + **options, + ) -> UnpackResult: + """Unpack a message.""" + return await self.dmp.unpack( + self.crypto, + self.resolver, + self.secrets, + self.packaging, + encoded_message, + **options, + ) diff --git a/didcomm_messaging/packaging.py b/didcomm_messaging/packaging.py index ffd744d..a102981 100644 --- a/didcomm_messaging/packaging.py +++ b/didcomm_messaging/packaging.py @@ -28,19 +28,8 @@ class PackagingServiceError(Exception): class PackagingService(Generic[P, S]): """DIDComm Messaging interface.""" - def __init__( - self, - resolver: DIDResolver, - crypto: CryptoService[P, S], - secrets: SecretsManager[S], - ): - """Initialize the KMS.""" - self.resolver = resolver - self.crypto = crypto - self.secrets = secrets - async def extract_packed_message_metadata( # noqa: C901 - self, enc_message: Union[str, bytes] + self, enc_message: Union[str, bytes], secrets: SecretsManager[S] ) -> PackedMessageMetadata: """Extract metadata from a packed DIDComm message.""" try: @@ -61,7 +50,7 @@ async def extract_packed_message_metadata( # noqa: C901 sender_kid = None recip_key = None for kid in wrapper.recipient_key_ids: - recip_key = await self.secrets.get_secret_by_kid(kid) + recip_key = await secrets.get_secret_by_kid(kid) if recip_key: break @@ -97,40 +86,42 @@ async def extract_packed_message_metadata( # noqa: C901 return PackedMessageMetadata(wrapper, method, recip_key, sender_kid) async def unpack( - self, enc_message: Union[str, bytes] + self, + crypto: CryptoService[P, S], + resolver: DIDResolver, + secrets: SecretsManager[S], + enc_message: Union[str, bytes], ) -> Tuple[bytes, PackedMessageMetadata]: """Unpack a DIDComm message.""" - metadata = await self.extract_packed_message_metadata(enc_message) + metadata = await self.extract_packed_message_metadata(enc_message, secrets) if metadata.method == "ECDH-ES": return ( - await self.crypto.ecdh_es_decrypt(enc_message, metadata.recip_key), + await crypto.ecdh_es_decrypt(enc_message, metadata.recip_key), metadata, ) if not metadata.sender_kid: raise PackagingServiceError("Missing sender key ID") - sender_vm = await self.resolver.resolve_and_dereference_verification_method( + sender_vm = await resolver.resolve_and_dereference_verification_method( metadata.sender_kid ) - sender_key = self.crypto.verification_method_to_public_key(sender_vm) + sender_key = crypto.verification_method_to_public_key(sender_vm) return ( - await self.crypto.ecdh_1pu_decrypt( - enc_message, metadata.recip_key, sender_key - ), + await crypto.ecdh_1pu_decrypt(enc_message, metadata.recip_key, sender_key), metadata, ) - async def recip_for_kid_or_default_for_did(self, kid_or_did: str) -> P: + async def recip_for_kid_or_default_for_did( + self, crypto: CryptoService[P, S], resolver: DIDResolver, kid_or_did: str + ) -> P: """Resolve a verification method for a kid or return default recip.""" if "#" in kid_or_did: - vm = await self.resolver.resolve_and_dereference_verification_method( - kid_or_did - ) + vm = await resolver.resolve_and_dereference_verification_method(kid_or_did) else: - doc = await self.resolver.resolve_and_parse(kid_or_did) + doc = await resolver.resolve_and_parse(kid_or_did) if not doc.key_agreement: raise PackagingServiceError( "No key agreement methods found; cannot determine recipient" @@ -146,14 +137,14 @@ async def recip_for_kid_or_default_for_did(self, kid_or_did: str) -> P: else: vm = default - return self.crypto.verification_method_to_public_key(vm) + return crypto.verification_method_to_public_key(vm) - async def default_sender_kid_for_did(self, did: str) -> str: + async def default_sender_kid_for_did(self, resolver: DIDResolver, did: str) -> str: """Determine the kid of the default sender key for a DID.""" if "#" in did: return did - doc = await self.resolver.resolve_and_parse(did) + doc = await resolver.resolve_and_parse(did) if not doc.key_agreement: raise PackagingServiceError( "No key agreement methods found; cannot determine recipient" @@ -175,21 +166,27 @@ async def default_sender_kid_for_did(self, did: str) -> str: async def pack( self, + crypto: CryptoService[P, S], + resolver: DIDResolver, + secrets: SecretsManager[S], message: bytes, to: Sequence[str], frm: Optional[str] = None, **options, ): """Pack a DIDComm message.""" - recip_keys = [await self.recip_for_kid_or_default_for_did(kid) for kid in to] - sender_kid = await self.default_sender_kid_for_did(frm) if frm else None - sender_key = ( - await self.secrets.get_secret_by_kid(sender_kid) if sender_kid else None + recip_keys = [ + await self.recip_for_kid_or_default_for_did(crypto, resolver, kid) + for kid in to + ] + sender_kid = ( + await self.default_sender_kid_for_did(resolver, frm) if frm else None ) + sender_key = await secrets.get_secret_by_kid(sender_kid) if sender_kid else None if frm and not sender_key: raise PackagingServiceError("No sender key found") if sender_key: - return await self.crypto.ecdh_1pu_encrypt(recip_keys, sender_key, message) + return await crypto.ecdh_1pu_encrypt(recip_keys, sender_key, message) else: - return await self.crypto.ecdh_es_encrypt(recip_keys, message) + return await crypto.ecdh_es_encrypt(recip_keys, message) diff --git a/didcomm_messaging/routing.py b/didcomm_messaging/routing.py index c3ad43c..c3c1dc1 100644 --- a/didcomm_messaging/routing.py +++ b/didcomm_messaging/routing.py @@ -5,6 +5,7 @@ from typing import Tuple, List, Dict, Any from pydid.service import DIDCommV2Service +from didcomm_messaging.crypto.base import P, S, CryptoService, SecretsManager from didcomm_messaging.packaging import PackagingService from didcomm_messaging.resolver import DIDResolver @@ -16,15 +17,12 @@ class RoutingServiceError(Exception): class RoutingService: """RoutingService.""" - def __init__(self, packaging: PackagingService, resolver: DIDResolver): - """Initialize the RoutingService.""" - self.packaging = packaging - self.resolver = resolver - - async def _resolve_services(self, to: str) -> List[DIDCommV2Service]: - if not await self.resolver.is_resolvable(to): + async def _resolve_services( + self, resolver: DIDResolver, to: str + ) -> List[DIDCommV2Service]: + if not await resolver.is_resolvable(to): return [] - did_doc = await self.resolver.resolve_and_parse(to) + did_doc = await resolver.resolve_and_parse(to) services = [] if did_doc.service: # service is not guaranteed to exist for did_service in did_doc.service: @@ -36,14 +34,16 @@ async def _resolve_services(self, to: str) -> List[DIDCommV2Service]: return [] return services - async def is_forwardable_service(self, service: DIDCommV2Service) -> bool: + async def is_forwardable_service( + self, resolver: DIDResolver, service: DIDCommV2Service + ) -> bool: """Determine if the uri of a service is a service we should forward to.""" endpoint = service.service_endpoint.uri - found_forwardable_service = await self.resolver.is_resolvable(endpoint) + found_forwardable_service = await resolver.is_resolvable(endpoint) return found_forwardable_service def _create_forward_message( - self, to: str, next_target: str, message: str + self, to: str, next_target: str, message: bytes ) -> Dict[Any, Any]: return { "typ": "application/didcomm-plain+json", @@ -64,11 +64,21 @@ def _create_forward_message( } async def prepare_forward( - self, to: str, encoded_message: bytes + self, + crypto: CryptoService[P, S], + packaging: PackagingService, + resolver: DIDResolver, + secrets: SecretsManager[S], + to: str, + encoded_message: bytes, ) -> Tuple[bytes, DIDCommV2Service]: """Prepare a forward message, if necessary. Args: + crypto (CryptoService[P, S]): Crypto service + packaging (PackagingService): Packaging service + resolver (DIDResolver): Resolver instance + secrets (SecretsManager[S]): Secrets manager to (str): The recipient of the message. This will be a DID. encoded_message (bytes): The encoded message. @@ -77,7 +87,7 @@ async def prepare_forward( """ # Get the initial service - services = await self._resolve_services(to) + services = await self._resolve_services(resolver, to) chain = [ { "did": to, @@ -87,9 +97,11 @@ async def prepare_forward( # Loop through service DIDs until we run out of DIDs to forward to to_did = services[0].service_endpoint.uri - found_forwardable_service = await self.is_forwardable_service(services[0]) + found_forwardable_service = await self.is_forwardable_service( + resolver, services[0] + ) while found_forwardable_service: - services = await self._resolve_services(to_did) + services = await self._resolve_services(resolver, to_did) if services: chain.append( { @@ -99,7 +111,9 @@ async def prepare_forward( ) to_did = services[0].service_endpoint.uri found_forwardable_service = ( - await self.is_forwardable_service(services[0]) if services else False + await self.is_forwardable_service(resolver, services[0]) + if services + else False ) if not chain[-1]["service"]: @@ -127,10 +141,13 @@ async def prepare_forward( # Pack for each key while routing_keys: key = routing_keys.pop() # pop from end of list (reverse order) - packed_message = await self.packaging.pack( + packed_message = await packaging.pack( + crypto, + resolver, + secrets, json.dumps( self._create_forward_message(key, next_target, packed_message) - ), + ).encode(), [key], ) next_target = key diff --git a/tests/test_packaging.py b/tests/test_packaging.py index 60a8ff0..036f157 100644 --- a/tests/test_packaging.py +++ b/tests/test_packaging.py @@ -4,11 +4,12 @@ from aries_askar import Key, KeyAlg from didcomm_messaging.crypto.backend.askar import AskarCryptoService, AskarSecretKey from didcomm_messaging.crypto.backend.basic import InMemorySecretsManager +from didcomm_messaging.crypto.base import CryptoService from didcomm_messaging.packaging import PackagingService from didcomm_messaging.multiformats import multibase from didcomm_messaging.multiformats import multicodec from didcomm_messaging.resolver.peer import Peer2, Peer4 -from didcomm_messaging.resolver import PrefixResolver +from didcomm_messaging.resolver import DIDResolver, PrefixResolver from did_peer_2 import KeySpec, generate @@ -25,17 +26,22 @@ def crypto(): @pytest.fixture -def packaging(secrets, crypto): +def resolver(): + yield PrefixResolver({"did:peer:2": Peer2(), "did:peer:4": Peer4()}) + + +@pytest.fixture +def packaging(): """Fixture for packaging.""" - yield PackagingService( - PrefixResolver({"did:peer:2": Peer2(), "did:peer:4": Peer4()}), crypto, secrets - ) + yield PackagingService() # TODO More thorough tests @pytest.mark.asyncio async def test_packer_basic( + crypto: CryptoService, secrets: InMemorySecretsManager, + resolver: DIDResolver, packaging: PackagingService, ): """Test basic packaging. @@ -63,6 +69,6 @@ async def test_packer_basic( await secrets.add_secret(AskarSecretKey(verkey, f"{did}#key-1")) await secrets.add_secret(AskarSecretKey(xkey, f"{did}#key-2")) message = b"hello world" - packed = await packaging.pack(message, [did], did) - unpacked, meta = await packaging.unpack(packed) + packed = await packaging.pack(crypto, resolver, secrets, message, [did], did) + unpacked, meta = await packaging.unpack(crypto, resolver, secrets, packed) assert unpacked == message