Skip to content

Commit

Permalink
fix: incompatibility issues
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Bluhm <[email protected]>
  • Loading branch information
dbluhm committed Nov 13, 2023
1 parent f089daf commit 91196b3
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 30 deletions.
39 changes: 23 additions & 16 deletions didcomm_messaging/crypto/backend/askar.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,26 +143,37 @@ async def ecdh_es_encrypt(
except AskarError:
raise CryptoServiceError("Error creating content encryption key")

apv = []
for recip_key in to_keys:
apv.append(recip_key.kid)
apv.sort()
apv = hashlib.sha256((".".join(apv)).encode()).digest()

for recip_key in to_keys:
try:
epk = Key.generate(recip_key.key.algorithm, ephemeral=True)
except AskarError:
raise CryptoServiceError("Error creating ephemeral key")
enc_key = ecdh.EcdhEs(alg_id, None, None).sender_wrap_key( # type: ignore
enc_key = ecdh.EcdhEs(alg_id, None, apv).sender_wrap_key( # type: ignore
wrap_alg, epk, recip_key.key, cek
)
builder.add_recipient(
JweRecipient(
encrypted_key=enc_key.ciphertext,
header={"kid": recip_key.kid, "epk": epk.get_jwk_public()},
header={
"kid": recip_key.kid,
"epk": json.loads(epk.get_jwk_public()),
},
)
)

builder.set_protected(
OrderedDict(
[
("typ", "application/didcomm-encrypted+json"),
("alg", alg_id),
("enc", enc_id),
("apv", b64url(apv)),
]
)
)
Expand Down Expand Up @@ -218,12 +229,10 @@ async def ecdh_es_decrypt(
except AskarError:
raise CryptoServiceError("Error loading ephemeral key")

apu = recip.header.get("apu")
apv = recip.header.get("apv")
# apu and apv are allowed to be None

try:
cek = ecdh.EcdhEs(alg_id, apu, apv).receiver_unwrap_key( # type: ignore
cek = ecdh.EcdhEs(
alg_id, None, wrapper.apv_bytes
).receiver_unwrap_key( # type: ignore
wrap_alg,
enc_alg,
epk,
Expand Down Expand Up @@ -273,7 +282,7 @@ async def ecdh_1pu_encrypt(
except AskarError:
raise CryptoServiceError("Error creating ephemeral key")

apu = b64url(sender_key.kid)
apu = sender_key.kid
apv = []
for recip_key in to_keys:
if agree_alg:
Expand All @@ -283,15 +292,15 @@ async def ecdh_1pu_encrypt(
agree_alg = recip_key.key.algorithm
apv.append(recip_key.kid)
apv.sort()
apv = b64url(hashlib.sha256((".".join(apv)).encode()).digest())
apv = hashlib.sha256((".".join(apv)).encode()).digest()

builder.set_protected(
OrderedDict(
[
("alg", alg_id),
("enc", enc_id),
("apu", apu),
("apv", apv),
("apu", b64url(apu)),
("apv", b64url(apv)),
("epk", json.loads(epk.get_jwk_public())),
("skid", sender_key.kid),
]
Expand Down Expand Up @@ -351,12 +360,10 @@ async def ecdh_1pu_decrypt(
except AskarError:
raise CryptoServiceError("Error loading ephemeral key")

apu = wrapper.protected.get("apu")
apv = wrapper.protected.get("apv")
# apu and apv are allowed to be None

try:
cek = ecdh.Ecdh1PU(alg_id, apu, apv).receiver_unwrap_key( # type: ignore
cek = ecdh.Ecdh1PU(
alg_id, wrapper.apu_bytes, wrapper.apv_bytes
).receiver_unwrap_key( # type: ignore
wrap_alg,
enc_alg,
epk,
Expand Down
40 changes: 34 additions & 6 deletions didcomm_messaging/crypto/backend/authlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@
from typing import Mapping, Optional, Sequence, Tuple, Union

from pydid import VerificationMethod

from didcomm_messaging.crypto.base import (
CryptoService,
CryptoServiceError,
PublicKey,
SecretKey,
)
from didcomm_messaging.multiformats import multibase, multicodec
from didcomm_messaging.multiformats.multibase import Base64UrlEncoder
from didcomm_messaging.crypto.base import CryptoService, PublicKey, SecretKey


try:
Expand Down Expand Up @@ -136,7 +142,7 @@ def verification_method_to_public_key(cls, vm: VerificationMethod) -> AuthlibKey
"""Return a PublicKey from a verification method."""
return AuthlibKey.from_verification_method(vm)

def _build_header(
def _build_header_ecdh_1pu(
self, to: Sequence[AuthlibKey], frm: AuthlibSecretKey, alg: str, enc: str
):
skid = frm.kid
Expand All @@ -155,17 +161,39 @@ def _build_header(
recipients = [{"header": {"kid": kid}} for kid in kids]
return {"protected": protected, "recipients": recipients}

def _build_header_ecdh_es(self, to: Sequence[AuthlibKey], alg: str, enc: str):
kids = [to_key.kid for to_key in to]

apv = b64url.encode(hashlib.sha256((".".join(sorted(kids))).encode()).digest())
protected = {
"typ": "application/didcomm-encrypted+json",
"alg": alg,
"enc": enc,
"apv": apv,
}
recipients = [{"header": {"kid": kid}} for kid in kids]
return {"protected": protected, "recipients": recipients}

async def ecdh_es_encrypt(
self, to_keys: Sequence[AuthlibKey], message: bytes
) -> bytes:
"""Encrypt a message using ECDH-ES."""
return await super().ecdh_es_encrypt(to_keys, message)
header = self._build_header_ecdh_es(to_keys, "ECDH-ES+A256KW", "XC20P")
jwe = JsonWebEncryption()
res = jwe.serialize_json(header, message, [value.key for value in to_keys])
return json.dumps(res).encode()

async def ecdh_es_decrypt(
self, enc_message: Union[str, bytes], recip_key: AuthlibSecretKey
) -> bytes:
"""Decrypt a message using ECDH-ES."""
return await super().ecdh_es_decrypt(enc_message, recip_key)
try:
jwe = JsonWebEncryption()
res = jwe.deserialize_json(enc_message, recip_key.key)
except Exception as err:
raise CryptoServiceError("Invalid JWE") from err

return res["payload"]

async def ecdh_1pu_encrypt(
self,
Expand All @@ -174,7 +202,7 @@ async def ecdh_1pu_encrypt(
message: bytes,
) -> bytes:
"""Encrypt a message using ECDH-1PU."""
header = self._build_header(
header = self._build_header_ecdh_1pu(
to_keys, sender_key, "ECDH-1PU+A256KW", "A256CBC-HS512"
)
jwe = JsonWebEncryption()
Expand All @@ -196,6 +224,6 @@ async def ecdh_1pu_decrypt(
enc_message, recip_key.key, sender_key=sender_key.key
)
except Exception as err:
raise ValueError("Invalid JWE") from err
raise CryptoServiceError("Invalid JWE") from err

return res["payload"]
18 changes: 16 additions & 2 deletions didcomm_messaging/crypto/jwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def get_recipients(self) -> Iterable[JweRecipient]:
outer envelope.
"""
header = self.protected.copy()
header.update(self.unprotected)
header.update(self.unprotected or {})
for recip in self.recipients:
if recip.header:
recip_h = header.copy()
Expand All @@ -393,7 +393,7 @@ def get_recipient(self, kid: str) -> JweRecipient:
for recip in self.recipients:
if recip.header and recip.header.get("kid") == kid:
header = self.protected.copy()
header.update(self.unprotected)
header.update(self.unprotected or {})
header.update(recip.header)
return JweRecipient(encrypted_key=recip.encrypted_key, header=header)
raise ValueError(f"Unknown recipient: {kid}")
Expand All @@ -405,3 +405,17 @@ def combined_aad(self) -> bytes:
if self.aad:
aad += b"." + b64url(self.aad).encode("utf-8")
return aad

@property
def apu_bytes(self) -> bytes:
"""Accessor for the Agreement PartyUInfo."""
if "apu" in self.protected:
return from_b64url(self.protected["apu"])
raise ValueError("Missing apu")

@property
def apv_bytes(self) -> bytes:
"""Accessor for the Agreement PartyVInfo."""
if "apv" in self.protected:
return from_b64url(self.protected["apv"])
raise ValueError("Missing apv")
31 changes: 30 additions & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ did_peer = [
]
authlib = [
"authlib>=1.2.1",
"pycryptodomex>=3.19.0",
]

[build-system]
Expand Down
17 changes: 17 additions & 0 deletions tests/crypto/test_askar.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,20 @@ async def test_1pu_round_trip(crypto: AskarCryptoService):

plaintext = await crypto.ecdh_1pu_decrypt(enc_message, bob_priv_key, alice_key)
assert plaintext == MESSAGE


@pytest.mark.asyncio
async def test_es_round_trip(crypto: AskarCryptoService):
"""Test ECDH-ES round trip."""
alg = KeyAlg.X25519
alice_sk = Key.generate(alg)
alice_pk = Key.from_jwk(alice_sk.get_jwk_public())
bob_sk = Key.generate(alg)
bob_pk = Key.from_jwk(bob_sk.get_jwk_public())
bob_key = AskarKey(bob_sk, BOB_KID)
bob_priv_key = AskarSecretKey(bob_sk, BOB_KID)
alice_key = AskarKey(alice_sk, ALICE_KID)
alice_priv_key = AskarSecretKey(alice_sk, ALICE_KID)
enc_message = await crypto.ecdh_es_encrypt([bob_key], MESSAGE)
plaintext = await crypto.ecdh_es_decrypt(enc_message, bob_priv_key)
assert plaintext == MESSAGE
45 changes: 40 additions & 5 deletions tests/crypto/test_askar_x_authlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def authlib():


@pytest.mark.asyncio
async def test_compat(
async def test_compat_ecdh_1pu(
askar: AskarCryptoService,
authlib: AuthlibCryptoService,
alice: tuple[AskarSecretKey, AuthlibKey],
Expand All @@ -77,11 +77,46 @@ async def test_compat(
bob_sk, bob_pk = bob

to_alice = b"Dear alice, please decrypt this"
enc_message = await authlib.ecdh_1pu_encrypt([alice_pk], bob_sk, to_alice)
plaintext = await askar.ecdh_1pu_decrypt(enc_message, alice_sk, bob_pk)
alice_enc_message = await authlib.ecdh_1pu_encrypt([alice_pk], bob_sk, to_alice)
print(alice_enc_message)

plaintext = await askar.ecdh_1pu_decrypt(alice_enc_message, alice_sk, bob_pk)
assert plaintext == to_alice

to_bob = b"Dear bob, please decrypt this"
bob_enc_message = await askar.ecdh_1pu_encrypt([bob_pk], alice_sk, to_bob)

print(bob_enc_message)

plaintext = await authlib.ecdh_1pu_decrypt(bob_enc_message, bob_sk, alice_pk)
assert plaintext == to_bob


@pytest.mark.asyncio
async def test_compat_ecdh_es(
askar: AskarCryptoService,
authlib: AuthlibCryptoService,
alice: tuple[AskarSecretKey, AuthlibKey],
bob: tuple[AuthlibSecretKey, AskarKey],
):
"""Test compabibility between Askar and Authlib.
Alice uses Askar, Bob uses Authlib.
"""
alice_sk, alice_pk = alice
bob_sk, bob_pk = bob

to_alice = b"Dear alice, please decrypt this"
alice_enc_message = await authlib.ecdh_es_encrypt([alice_pk], to_alice)
print(alice_enc_message)

plaintext = await askar.ecdh_es_decrypt(alice_enc_message, alice_sk)
assert plaintext == to_alice

to_bob = b"Dear bob, please decrypt this"
enc_message = await askar.ecdh_1pu_encrypt([bob_pk], alice_sk, to_bob)
plaintext = await authlib.ecdh_1pu_decrypt(enc_message, bob_sk, alice_pk)
bob_enc_message = await askar.ecdh_es_encrypt([bob_pk], to_bob)

print(bob_enc_message)

plaintext = await authlib.ecdh_es_decrypt(bob_enc_message, bob_sk)
assert plaintext == to_bob
20 changes: 20 additions & 0 deletions tests/crypto/test_authlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,23 @@ async def test_1pu_round_trip(crypto: AuthlibCryptoService):

plaintext = await crypto.ecdh_1pu_decrypt(enc_message, bob_priv_key, alice_key)
assert plaintext == MESSAGE


@pytest.mark.asyncio
async def test_es_round_trip(crypto: AuthlibCryptoService):
"""Test ECDH-ES round trip."""
alice_sk = OKPKey.generate_key("X25519", is_private=True)
alice_pk = alice_sk.get_public_key()
bob_sk = OKPKey.generate_key("X25519", is_private=True)
bob_pk = bob_sk.get_public_key()

bob_key = AuthlibKey(bob_sk, BOB_KID)
bob_priv_key = AuthlibSecretKey(bob_sk, BOB_KID)

alice_key = AuthlibKey(alice_sk, ALICE_KID)
alice_priv_key = AuthlibSecretKey(alice_sk, ALICE_KID)

enc_message = await crypto.ecdh_es_encrypt([bob_key], MESSAGE)

plaintext = await crypto.ecdh_es_decrypt(enc_message, bob_priv_key)
assert plaintext == MESSAGE

0 comments on commit 91196b3

Please sign in to comment.