11"""Define the core functions/classes of the pyas2 package."""
2- import logging
3- import hashlib
2+ import asyncio
43import binascii
4+ import hashlib
5+ import inspect
6+ import logging
57import traceback
68from dataclasses import dataclass
79from email import encoders
810from email import message as email_message
911from email import message_from_bytes as parse_mime
1012from email import utils as email_utils
1113from email .mime .multipart import MIMEMultipart
14+
1215from oscrypto import asymmetric
1316
1417from pyas2lib .cms import (
@@ -564,7 +567,7 @@ def _decompress_data(self, payload):
564567
565568 return False , payload
566569
567- def parse (
570+ async def aparse (
568571 self ,
569572 raw_content ,
570573 find_org_cb = None ,
@@ -631,22 +634,41 @@ def parse(
631634 # Get the organization and partner for this transmission
632635 org_id = unquote_as2name (as2_headers ["as2-to" ])
633636 partner_id = unquote_as2name (as2_headers ["as2-from" ])
637+
634638 if find_org_partner_cb :
635- self .receiver , self .sender = find_org_partner_cb (org_id , partner_id )
639+ if inspect .iscoroutinefunction (find_org_partner_cb ):
640+ self .receiver , self .sender = await find_org_partner_cb (
641+ org_id , partner_id
642+ )
643+ else :
644+ self .receiver , self .sender = find_org_partner_cb (org_id , partner_id )
645+
636646 elif find_org_cb and find_partner_cb :
637- self .receiver = find_org_cb (org_id )
638- self .sender = find_partner_cb (partner_id )
647+ if inspect .iscoroutinefunction (find_org_cb ):
648+ self .receiver = await find_org_cb (org_id )
649+ else :
650+ self .receiver = find_org_cb (org_id )
651+
652+ if inspect .iscoroutinefunction (find_partner_cb ):
653+ self .sender = await find_partner_cb (partner_id )
654+ else :
655+ self .sender = find_partner_cb (partner_id )
639656
640657 if not self .receiver :
641658 raise PartnerNotFound (f"Unknown AS2 organization with id { org_id } " )
642659
643660 if not self .sender :
644661 raise PartnerNotFound (f"Unknown AS2 partner with id { partner_id } " )
645662
646- if find_message_cb and find_message_cb (self .message_id , partner_id ):
647- raise DuplicateDocument (
648- "Duplicate message received, message with this ID already processed."
649- )
663+ if find_message_cb :
664+ if inspect .iscoroutinefunction (find_message_cb ):
665+ message_exists = await find_message_cb (self .message_id , partner_id )
666+ else :
667+ message_exists = find_message_cb (self .message_id , partner_id )
668+ if message_exists :
669+ raise DuplicateDocument (
670+ "Duplicate message received, message with this ID already processed."
671+ )
650672
651673 if (
652674 self .sender .encrypt
@@ -767,6 +789,18 @@ def parse(
767789
768790 return status , exception , mdn
769791
792+ def parse (self , * args , ** kwargs ):
793+ """
794+ A synchronous wrapper for the asynchronous parse method.
795+ It runs the parse coroutine in an event loop and returns the result.
796+ """
797+ loop = asyncio .get_event_loop ()
798+ if loop .is_running ():
799+ raise RuntimeError (
800+ "Cannot run synchronous parse within an already running event loop, use aparse."
801+ )
802+ return loop .run_until_complete (self .aparse (* args , ** kwargs ))
803+
770804
771805class Mdn :
772806 """Class for handling AS2 MDNs. Includes functions for both
@@ -945,7 +979,7 @@ def build(
945979 f"content:\n { mime_to_bytes (self .payload )} "
946980 )
947981
948- def parse (self , raw_content , find_message_cb ):
982+ async def aparse (self , raw_content , find_message_cb ):
949983 """Function parses the RAW AS2 MDN, verifies it and extracts the
950984 processing status of the orginal AS2 message.
951985
@@ -970,7 +1004,17 @@ def parse(self, raw_content, find_message_cb):
9701004 self .orig_message_id , orig_recipient = self .detect_mdn ()
9711005
9721006 # Call the find message callback which should return a Message instance
973- orig_message = find_message_cb (self .orig_message_id , orig_recipient )
1007+ if inspect .iscoroutinefunction (find_message_cb ):
1008+ orig_message = await find_message_cb (
1009+ self .orig_message_id , orig_recipient
1010+ )
1011+ else :
1012+ orig_message = find_message_cb (self .orig_message_id , orig_recipient )
1013+
1014+ if not orig_message :
1015+ status = "failed/Failure"
1016+ details_status = "original-message-not-found"
1017+ return status , details_status
9741018
9751019 if not orig_message :
9761020 status = "failed/Failure"
@@ -1053,6 +1097,18 @@ def parse(self, raw_content, find_message_cb):
10531097 logger .error (f"Failed to parse AS2 MDN\n : { traceback .format_exc ()} " )
10541098 return status , detailed_status
10551099
1100+ def parse (self , * args , ** kwargs ):
1101+ """
1102+ A synchronous wrapper for the asynchronous parse method.
1103+ It runs the parse coroutine in an event loop and returns the result.
1104+ """
1105+ loop = asyncio .get_event_loop ()
1106+ if loop .is_running ():
1107+ raise RuntimeError (
1108+ "Cannot run synchronous parse within an already running event loop, use aparse."
1109+ )
1110+ return loop .run_until_complete (self .aparse (* args , ** kwargs ))
1111+
10561112 def detect_mdn (self ):
10571113 """Function checks if the received raw message is an AS2 MDN or not.
10581114
0 commit comments