Skip to content

Commit e26f09c

Browse files
committed
Async Support/Wrapper for callback functions.
1 parent ed5a551 commit e26f09c

File tree

3 files changed

+210
-13
lines changed

3 files changed

+210
-13
lines changed

pyas2lib/as2.py

Lines changed: 68 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
"""Define the core functions/classes of the pyas2 package."""
2-
import logging
3-
import hashlib
2+
import asyncio
43
import binascii
4+
import hashlib
5+
import inspect
6+
import logging
57
import traceback
68
from dataclasses import dataclass
79
from email import encoders
810
from email import message as email_message
911
from email import message_from_bytes as parse_mime
1012
from email import utils as email_utils
1113
from email.mime.multipart import MIMEMultipart
14+
1215
from oscrypto import asymmetric
1316

1417
from pyas2lib.cms import (
@@ -564,7 +567,7 @@ def _decompress_data(self, payload):
564567

565568
return False, payload
566569

567-
def parse(
570+
async def aparse(
568571
self,
569572
raw_content,
570573
find_org_cb=None,
@@ -631,22 +634,41 @@ def parse(
631634
# Get the organization and partner for this transmission
632635
org_id = unquote_as2name(as2_headers["as2-to"])
633636
partner_id = unquote_as2name(as2_headers["as2-from"])
637+
634638
if find_org_partner_cb:
635-
self.receiver, self.sender = find_org_partner_cb(org_id, partner_id)
639+
if inspect.iscoroutinefunction(find_org_partner_cb):
640+
self.receiver, self.sender = await find_org_partner_cb(
641+
org_id, partner_id
642+
)
643+
else:
644+
self.receiver, self.sender = find_org_partner_cb(org_id, partner_id)
645+
636646
elif find_org_cb and find_partner_cb:
637-
self.receiver = find_org_cb(org_id)
638-
self.sender = find_partner_cb(partner_id)
647+
if inspect.iscoroutinefunction(find_org_cb):
648+
self.receiver = await find_org_cb(org_id)
649+
else:
650+
self.receiver = find_org_cb(org_id)
651+
652+
if inspect.iscoroutinefunction(find_partner_cb):
653+
self.sender = await find_partner_cb(partner_id)
654+
else:
655+
self.sender = find_partner_cb(partner_id)
639656

640657
if not self.receiver:
641658
raise PartnerNotFound(f"Unknown AS2 organization with id {org_id}")
642659

643660
if not self.sender:
644661
raise PartnerNotFound(f"Unknown AS2 partner with id {partner_id}")
645662

646-
if find_message_cb and find_message_cb(self.message_id, partner_id):
647-
raise DuplicateDocument(
648-
"Duplicate message received, message with this ID already processed."
649-
)
663+
if find_message_cb:
664+
if inspect.iscoroutinefunction(find_message_cb):
665+
message_exists = await find_message_cb(self.message_id, partner_id)
666+
else:
667+
message_exists = find_message_cb(self.message_id, partner_id)
668+
if message_exists:
669+
raise DuplicateDocument(
670+
"Duplicate message received, message with this ID already processed."
671+
)
650672

651673
if (
652674
self.sender.encrypt
@@ -767,6 +789,18 @@ def parse(
767789

768790
return status, exception, mdn
769791

792+
def parse(self, *args, **kwargs):
793+
"""
794+
A synchronous wrapper for the asynchronous parse method.
795+
It runs the parse coroutine in an event loop and returns the result.
796+
"""
797+
loop = asyncio.get_event_loop()
798+
if loop.is_running():
799+
raise RuntimeError(
800+
"Cannot run synchronous parse within an already running event loop, use aparse."
801+
)
802+
return loop.run_until_complete(self.aparse(*args, **kwargs))
803+
770804

771805
class Mdn:
772806
"""Class for handling AS2 MDNs. Includes functions for both
@@ -945,7 +979,7 @@ def build(
945979
f"content:\n {mime_to_bytes(self.payload)}"
946980
)
947981

948-
def parse(self, raw_content, find_message_cb):
982+
async def aparse(self, raw_content, find_message_cb):
949983
"""Function parses the RAW AS2 MDN, verifies it and extracts the
950984
processing status of the orginal AS2 message.
951985
@@ -970,7 +1004,17 @@ def parse(self, raw_content, find_message_cb):
9701004
self.orig_message_id, orig_recipient = self.detect_mdn()
9711005

9721006
# Call the find message callback which should return a Message instance
973-
orig_message = find_message_cb(self.orig_message_id, orig_recipient)
1007+
if inspect.iscoroutinefunction(find_message_cb):
1008+
orig_message = await find_message_cb(
1009+
self.orig_message_id, orig_recipient
1010+
)
1011+
else:
1012+
orig_message = find_message_cb(self.orig_message_id, orig_recipient)
1013+
1014+
if not orig_message:
1015+
status = "failed/Failure"
1016+
details_status = "original-message-not-found"
1017+
return status, details_status
9741018

9751019
if not orig_message:
9761020
status = "failed/Failure"
@@ -1053,6 +1097,18 @@ def parse(self, raw_content, find_message_cb):
10531097
logger.error(f"Failed to parse AS2 MDN\n: {traceback.format_exc()}")
10541098
return status, detailed_status
10551099

1100+
def parse(self, *args, **kwargs):
1101+
"""
1102+
A synchronous wrapper for the asynchronous parse method.
1103+
It runs the parse coroutine in an event loop and returns the result.
1104+
"""
1105+
loop = asyncio.get_event_loop()
1106+
if loop.is_running():
1107+
raise RuntimeError(
1108+
"Cannot run synchronous parse within an already running event loop, use aparse."
1109+
)
1110+
return loop.run_until_complete(self.aparse(*args, **kwargs))
1111+
10561112
def detect_mdn(self):
10571113
"""Function checks if the received raw message is an AS2 MDN or not.
10581114

pyas2lib/tests/test_async.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import os
2+
3+
import pytest
4+
5+
from pyas2lib import as2
6+
from pyas2lib.tests import TEST_DIR
7+
8+
with open(os.path.join(TEST_DIR, "payload.txt"), "rb") as fp:
9+
test_data = fp.read()
10+
11+
with open(os.path.join(TEST_DIR, "cert_test.p12"), "rb") as fp:
12+
private_key = fp.read()
13+
14+
with open(os.path.join(TEST_DIR, "cert_test_public.pem"), "rb") as fp:
15+
public_key = fp.read()
16+
17+
org = as2.Organization(
18+
as2_name="some_organization",
19+
sign_key=private_key,
20+
sign_key_pass="test",
21+
decrypt_key=private_key,
22+
decrypt_key_pass="test",
23+
)
24+
partner = as2.Partner(
25+
as2_name="some_partner",
26+
verify_cert=public_key,
27+
encrypt_cert=public_key,
28+
)
29+
30+
31+
async def afind_org(headers):
32+
return org
33+
34+
35+
async def afind_partner(headers):
36+
return partner
37+
38+
39+
async def afind_duplicate_message(message_id, message_recipient):
40+
return True
41+
42+
43+
async def afind_org_partner(as2_org, as2_partner):
44+
return org, partner
45+
46+
47+
@pytest.mark.asyncio
48+
async def test_async_callbacks_with_duplicate_message():
49+
"""Test case where async callbacks are used and a duplicate message is sent to the partner"""
50+
51+
# Build an As2 message to be transmitted to partner
52+
partner.sign = True
53+
partner.encrypt = True
54+
partner.mdn_mode = as2.SYNCHRONOUS_MDN
55+
out_message = as2.Message(org, partner)
56+
out_message.build(test_data)
57+
58+
async def afind_message(message_id, message_recipient):
59+
return out_message
60+
61+
# Parse the generated AS2 message as the partner
62+
raw_out_message = out_message.headers_str + b"\r\n" + out_message.content
63+
in_message = as2.Message()
64+
_, _, mdn = await in_message.aparse(
65+
raw_out_message,
66+
find_org_cb=afind_org,
67+
find_partner_cb=afind_partner,
68+
find_message_cb=afind_duplicate_message,
69+
)
70+
71+
out_mdn = as2.Mdn()
72+
status, detailed_status = await out_mdn.aparse(
73+
mdn.headers_str + b"\r\n" + mdn.content,
74+
find_message_cb=afind_message,
75+
)
76+
assert status == "processed/Warning"
77+
assert detailed_status == "duplicate-document"
78+
79+
80+
@pytest.mark.asyncio
81+
async def test_async_partnership():
82+
"""Test Async Partnership callback"""
83+
84+
# Build an As2 message to be transmitted to partner
85+
out_message = as2.Message(org, partner)
86+
out_message.build(test_data)
87+
raw_out_message = out_message.headers_str + b"\r\n" + out_message.content
88+
89+
# Parse the generated AS2 message as the partner
90+
in_message = as2.Message()
91+
status, _, _ = await in_message.aparse(
92+
raw_out_message, find_org_partner_cb=afind_org_partner
93+
)
94+
95+
# Compare contents of the input and output messages
96+
assert status == "processed"
97+
98+
99+
@pytest.mark.asyncio
100+
async def test_runtime_error():
101+
"""Test to get Runtime error when calling parse instead of aparse from Async Context"""
102+
103+
with pytest.raises(
104+
RuntimeError,
105+
match="Cannot run synchronous parse within an already running event loop, use aparse.",
106+
):
107+
out_message = as2.Message(org, partner)
108+
out_message.build(test_data)
109+
raw_out_message = out_message.headers_str + b"\r\n" + out_message.content
110+
111+
in_message = as2.Message()
112+
status, _, _ = in_message.parse(
113+
raw_out_message, find_org_partner_cb=afind_org_partner
114+
)
115+
116+
with pytest.raises(
117+
RuntimeError,
118+
match="Cannot run synchronous parse within an already running event loop, use aparse.",
119+
):
120+
partner.sign = True
121+
partner.encrypt = True
122+
partner.mdn_mode = as2.SYNCHRONOUS_MDN
123+
out_message = as2.Message(org, partner)
124+
out_message.build(test_data)
125+
126+
# Parse the generated AS2 message as the partner
127+
raw_out_message = out_message.headers_str + b"\r\n" + out_message.content
128+
in_message = as2.Message()
129+
_, _, mdn = await in_message.aparse(
130+
raw_out_message,
131+
find_org_cb=afind_org,
132+
find_partner_cb=afind_partner,
133+
find_message_cb=afind_duplicate_message,
134+
)
135+
136+
out_mdn = as2.Mdn()
137+
_, _ = out_mdn.parse(
138+
mdn.headers_str + b"\r\n" + mdn.content,
139+
find_message_cb=afind_duplicate_message,
140+
)

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
]
88

99
tests_require = [
10-
"pytest==6.2.5",
10+
"pytest==7.4.4",
11+
"pytest-asyncio==0.21.1",
1112
"toml==0.10.2",
1213
"pytest-cov==2.8.1",
1314
"coverage==5.0.4",

0 commit comments

Comments
 (0)