diff --git a/didcomm_messaging/__init__.py b/didcomm_messaging/__init__.py index 3acc4ad..418e675 100644 --- a/didcomm_messaging/__init__.py +++ b/didcomm_messaging/__init__.py @@ -1,7 +1,7 @@ """DIDComm Messaging.""" from dataclasses import dataclass import json -from typing import Generic, Optional +from typing import Generic, Optional, List from pydid.service import DIDCommV2Service @@ -16,7 +16,23 @@ class PackResult: """Result of packing a message.""" message: bytes - target: str + target_services: List[DIDCommV2Service] + + def get_endpoint(self, protocol: str) -> str: + """Get the first matching endpoint to send the message to.""" + return self.get_service(protocol).service_endpoint.uri + + def get_service(self, protocol: str) -> DIDCommV2Service: + """Get the first matching service to send the message to.""" + return self.filter_services_by_protocol(protocol)[0] + + def filter_services_by_protocol(self, protocol: str) -> List[DIDCommV2Service]: + """Get all services that start with a specific uri protocol.""" + return [ + service + for service in self.target_services + if service.service_endpoint.uri.startswith(protocol) + ] @dataclass @@ -68,8 +84,8 @@ async def pack(self, message: dict, to: str, frm: Optional[str] = None, **option json.dumps(message).encode(), [to], frm, **options ) - forward, service = await self.routing.prepare_forward(to, encoded_message) - return PackResult(forward, self.service_to_target(service)) + forward, services = await self.routing.prepare_forward(to, encoded_message) + return PackResult(forward, services) async def unpack(self, encoded_message: bytes, **options) -> UnpackResult: """Unpack a message.""" diff --git a/didcomm_messaging/resolver/__init__.py b/didcomm_messaging/resolver/__init__.py index 19871fd..c4c5342 100644 --- a/didcomm_messaging/resolver/__init__.py +++ b/didcomm_messaging/resolver/__init__.py @@ -24,6 +24,10 @@ class DIDResolver(ABC): async def resolve(self, did: str) -> dict: """Resolve a DID.""" + @abstractmethod + async def is_resolvable(self, did: str) -> bool: + """Check to see if a DID is resolvable.""" + async def resolve_and_parse(self, did: str) -> DIDDocument: """Resolve a DID and parse the DID document.""" doc = await self.resolve(did) @@ -56,6 +60,13 @@ def __init__(self, resolvers: Dict[str, DIDResolver]): """Initialize the resolver.""" self.resolvers = resolvers + async def is_resolvable(self, did: str) -> bool: + """Check to see if a DID is resolvable.""" + for prefix, resolver in self.resolvers.items(): + if did.startswith(prefix): + return await resolver.is_resolvable(did) + return False + async def resolve(self, did: str) -> dict: """Resolve a DID.""" for prefix, resolver in self.resolvers.items(): diff --git a/didcomm_messaging/resolver/peer.py b/didcomm_messaging/resolver/peer.py index cedc7ba..bc0fa47 100644 --- a/didcomm_messaging/resolver/peer.py +++ b/didcomm_messaging/resolver/peer.py @@ -4,7 +4,10 @@ try: from did_peer_2 import resolve as resolve_peer_2 + from did_peer_2 import PATTERN as peer_2_pattern from did_peer_4 import resolve as resolve_peer_4 + from did_peer_4 import LONG_PATTERN as peer_4_pattern_long + from did_peer_4 import SHORT_PATTERN as peer_4_pattern_short except ImportError: raise ImportError( "did-peer-2 and did-peer-4 are required for did:peer resolution; " @@ -15,6 +18,10 @@ class Peer2(DIDResolver): """did:peer:2 resolver.""" + async def is_resolvable(self, did: str) -> bool: + """Check to see if a DID is resolvable.""" + return peer_2_pattern.match(did) + async def resolve(self, did: str) -> dict: """Resolve a did:peer:2 DID.""" return resolve_peer_2(did) @@ -23,6 +30,10 @@ async def resolve(self, did: str) -> dict: class Peer4(DIDResolver): """did:peer:4 resolver.""" + async def is_resolvable(self, did: str) -> bool: + """Check to see if a DID is resolvable.""" + return peer_4_pattern_short.match(did) or peer_4_pattern_long.match(did) + async def resolve(self, did: str) -> dict: """Resolve a did:peer:4 DID.""" return resolve_peer_4(did) diff --git a/didcomm_messaging/routing.py b/didcomm_messaging/routing.py index 85801a5..b9d9003 100644 --- a/didcomm_messaging/routing.py +++ b/didcomm_messaging/routing.py @@ -1,6 +1,9 @@ """RoutingService interface.""" -from typing import Tuple +import json +import uuid + +from typing import Tuple, List, Dict, Any from pydid.service import DIDCommV2Service from didcomm_messaging.packaging import PackagingService from didcomm_messaging.resolver import DIDResolver @@ -18,24 +21,45 @@ def __init__(self, packaging: PackagingService, resolver: DIDResolver): self.packaging = packaging self.resolver = resolver - async def _resolve_service(self, to: str) -> DIDCommV2Service: - """Resolve the service endpoint for a given DID.""" - doc = await self.resolver.resolve_and_parse(to) - if not doc.service: - raise RoutingServiceError(f"No service endpoint found for {to}") - - first_didcomm_service = next( - ( - service - for service in doc.service - if isinstance(service, DIDCommV2Service) - ), - None, - ) - if not first_didcomm_service: - raise RoutingServiceError(f"No DIDCommV2 service endpoint found for {to}") + async def _resolve_services(self, to: str) -> List[DIDCommV2Service]: + if not await self.resolver.is_resolvable(to): + return [] + did_doc = await self.resolver.resolve_and_parse(to) + services = [] + if did_doc.service: # service is not guaranteed to exist + for did_service in did_doc.service: + if "didcomm/v2" in did_service.service_endpoint.accept: + services.append(did_service) + if not services: + return [] + return services + + async def is_forwardable_service(self, service: DIDCommV2Service) -> bool: + """Determine if the uri of a service is a service we should forward to.""" + endpoint = service.service_endpoint.uri + found_forwardable_service = await self.resolver.is_resolvable(endpoint) + return found_forwardable_service - return first_didcomm_service + def _create_forward_message( + self, to: str, next_target: str, message: str + ) -> Dict[Any, Any]: + return { + "typ": "application/didcomm-plain+json", + "type": "https://didcomm.org/routing/2.0/forward", + "id": str(uuid.uuid4()), + "to": [to], + # "expires_time": 123456, # time to expire the forward message, in epoch time + "body": {"next": next_target}, + "attachments": [ + { + "id": str(uuid.uuid4()), + "media_type": "application/didcomm-encrypted+json", + "data": { + "json": json.loads(message), + }, + }, + ], + } async def prepare_forward( self, to: str, encoded_message: bytes @@ -47,8 +71,63 @@ async def prepare_forward( encoded_message (bytes): The encoded message. Returns: - The encoded message, and the service endpoint to forward to. + The encoded message, and the services to forward to. """ - service = await self._resolve_service(to) - # TODO Do the stuff - return encoded_message, service + + # Get the initial service + services = await self._resolve_services(to) + chain = [ + { + "did": to, + "service": services, + } + ] + + # Loop through service DIDs until we run out of DIDs to forward to + to_did = services[0].service_endpoint.uri + found_forwardable_service = await self.is_forwardable_service(services[0]) + while found_forwardable_service: + services = await self._resolve_services(to_did) + if services: + chain.append( + { + "did": to_did, + "service": services, + } + ) + to_did = services[0].service_endpoint.uri + found_forwardable_service = ( + await self.is_forwardable_service(services[0]) if services else False + ) + + if not chain[-1]["service"]: + raise RoutingServiceError(f"No DIDCommV2 service endpoint found for {to}") + + # Grab our target to pack the initial message to, then pack the message + # for the DID target + next_target = chain.pop(0)["did"] + packed_message = encoded_message + + # Loop through the entire services chain and pack the message for each + # layer of mediators + for service in chain: + # https://identity.foundation/didcomm-messaging/spec/#sender-process-to-enable-forwarding + # Respect routing keys by adding the current DID to the front of + # the list, then wrapping message following routing key order + routing_keys = service["service"][0].service_endpoint.routing_keys + routing_keys.insert(0, service["did"]) # prepend did + + # Pack for each key + while routing_keys: + key = routing_keys.pop() # pop from end of list (reverse order) + packed_message = await self.packaging.pack( + json.dumps( + self._create_forward_message(key, next_target, packed_message) + ), + [key], + ) + next_target = key + + # Return the forward-packed message as well as the last service in the + # chain, which is the destination of the top-level forward message. + return (packed_message, chain[-1]["service"]) diff --git a/tests/test_didresolver.py b/tests/test_didresolver.py index 494fc18..a7b27d4 100644 --- a/tests/test_didresolver.py +++ b/tests/test_didresolver.py @@ -8,6 +8,9 @@ class TestResolver(DIDResolver): + async def is_resolvable(self, did: str) -> bool: + return True + async def resolve(self, did: str) -> dict: return {"did": did}