Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 0 additions & 31 deletions lambdas/id_sync/src/pds_details.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,6 @@
Operations related to PDS (Patient Demographic Service)
"""

import tempfile

from common.api_clients.authentication import AppRestrictedAuth, Service
from common.api_clients.pds_service import PdsService
from common.cache import Cache
from common.clients import get_secrets_manager_client, logger
from exceptions.id_sync_exception import IdSyncException
from os_vars import get_pds_env

pds_env = get_pds_env()
safe_tmp_dir = tempfile.mkdtemp(dir="/tmp")


# Get Patient details from external service PDS using NHS number from MNS notification
def pds_get_patient_details(nhs_number: str) -> dict:
try:
cache = Cache(directory=safe_tmp_dir)
authenticator = AppRestrictedAuth(
service=Service.PDS,
secret_manager_client=get_secrets_manager_client(),
environment=pds_env,
cache=cache,
)
pds_service = PdsService(authenticator, pds_env)
patient = pds_service.get_patient_details(nhs_number)
return patient
except Exception as e:
msg = "Error retrieving patient details from PDS"
logger.exception(msg)
raise IdSyncException(message=msg) from e


def get_nhs_number_from_pds_resource(pds_resource: dict) -> str:
"""Simple helper to get the NHS Number from a PDS Resource. No handling as this is a mandatory field in the PDS
Expand Down
15 changes: 11 additions & 4 deletions lambdas/mns_publisher/src/process_records.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import os
from typing import Tuple

from aws_lambda_typing.events.sqs import SQSMessage

Expand All @@ -11,7 +10,15 @@
from create_notification import create_mns_notification

mns_env = os.getenv("MNS_ENV", "int")
MNS_TEST_QUEUE_URL = os.getenv("MNS_TEST_QUEUE_URL")
_mns_service: MnsService | MockMnsService | None = None


def _get_runtime_mns_service() -> MnsService | MockMnsService:
global _mns_service
if _mns_service is None:
_mns_service = get_mns_service(mns_env=mns_env)

return _mns_service


def process_records(records: list[SQSMessage]) -> dict[str, list]:
Expand All @@ -21,7 +28,7 @@ def process_records(records: list[SQSMessage]) -> dict[str, list]:
Returns: List of failed item identifiers for partial batch failure
"""
batch_item_failures = []
mns_service = get_mns_service(mns_env=mns_env)
mns_service = _get_runtime_mns_service()

for record in records:
try:
Expand Down Expand Up @@ -68,7 +75,7 @@ def process_record(record: SQSMessage, mns_service: MnsService | MockMnsService)
logger.info("Successfully created MNS notification", extra={"mns_notification_id": notification_id})


def extract_trace_ids(record: SQSMessage) -> Tuple[str, str | None]:
def extract_trace_ids(record: SQSMessage) -> tuple[str, str | None]:
"""
Extract identifiers for tracing from SQS record.
Returns: Tuple of (message_id, immunisation_id)
Expand Down
10 changes: 5 additions & 5 deletions lambdas/mns_publisher/tests/test_lambda_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def setUpClass(cls):
cls.sample_sqs_record = load_sample_sqs_event()

@patch("process_records.logger")
@patch("process_records.get_mns_service")
@patch("process_records._get_runtime_mns_service")
@patch("process_records.process_record")
def test_process_records_all_success(self, mock_process_record, mock_get_mns, mock_logger):
"""Test processing multiple records with all successes."""
Expand All @@ -145,7 +145,7 @@ def test_process_records_all_success(self, mock_process_record, mock_get_mns, mo
mock_logger.info.assert_called_with("Successfully processed all 2 messages")

@patch("process_records.logger")
@patch("process_records.get_mns_service")
@patch("process_records._get_runtime_mns_service")
@patch("process_records.process_record")
def test_process_records_partial_failure(self, mock_process_record, mock_get_mns, mock_logger):
"""Test processing with some failures."""
Expand All @@ -167,7 +167,7 @@ def test_process_records_partial_failure(self, mock_process_record, mock_get_mns
mock_logger.warning.assert_called_with("Batch completed with 1 failures")

@patch("process_records.logger")
@patch("process_records.get_mns_service")
@patch("process_records._get_runtime_mns_service")
@patch("process_records.process_record")
def test_process_records_empty_list(self, mock_process_record, mock_get_mns, mock_logger):
"""Test processing empty record list."""
Expand All @@ -181,7 +181,7 @@ def test_process_records_empty_list(self, mock_process_record, mock_get_mns, moc
mock_logger.info.assert_called_with("Successfully processed all 0 messages")

@patch("process_records.logger")
@patch("process_records.get_mns_service")
@patch("process_records._get_runtime_mns_service")
@patch("process_records.process_record")
def test_process_records_mns_service_created_once(self, mock_process_record, mock_get_mns, mock_logger):
"""Test that MNS service is created only once for batch."""
Expand Down Expand Up @@ -300,7 +300,7 @@ def test_successful_notification_creation_with_gp(self, mock_logger, mock_get_to

@responses.activate
@patch("common.api_clients.authentication.AppRestrictedAuth.get_access_token")
@patch("process_records.get_mns_service")
@patch("process_records._get_runtime_mns_service")
@patch("process_records.logger")
def test_pds_failure(self, mock_logger, mock_get_mns, mock_get_token):
"""
Expand Down
103 changes: 50 additions & 53 deletions lambdas/shared/src/common/api_clients/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,96 +2,93 @@
import json
import time
import uuid
from enum import Enum
from typing import Any

import jwt
import requests

from common.clients import logger
from common.models.errors import UnhandledResponseError

from ..cache import Cache
GRANT_TYPE_CLIENT_CREDENTIALS = "client_credentials"
CLIENT_ASSERTION_TYPE_JWT_BEARER = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
CONTENT_TYPE_X_WWW_FORM_URLENCODED = "application/x-www-form-urlencoded"


class Service(Enum):
PDS = "pds"
IMMUNIZATION = "imms"
JWT_EXPIRY_SECONDS = 5 * 60
ACCESS_TOKEN_EXPIRY_SECONDS = 10 * 60
# Throw away the cached token earlier than the exact expiry time so we have enough
# time left to use it (and to account for network latency, clock skew etc.)
ACCESS_TOKEN_MIN_ACCEPTABLE_LIFETIME_SECONDS = 30


class AppRestrictedAuth:
def __init__(self, service: Service, secret_manager_client, environment, cache: Cache):
def __init__(self, secret_manager_client: Any, environment: str, secret_name: str | None = None):
self.secret_manager_client = secret_manager_client
self.cache = cache
self.cache_key = f"{service.value}_access_token"

self.expiry = 30
self.secret_name = (
f"imms/pds/{environment}/jwt-secrets"
if service == Service.PDS
else f"imms/immunization/{environment}/jwt-secrets"
)

self.cached_access_token: str | None = None
self.cached_access_token_expiry_time: int | None = None

self.secret_name = f"imms/pds/{environment}/jwt-secrets" if secret_name is None else secret_name

self.token_url = (
f"https://{environment}.api.service.nhs.uk/oauth2/token"
if environment != "prod"
else "https://api.service.nhs.uk/oauth2/token"
)

def get_service_secrets(self):
kwargs = {"SecretId": self.secret_name}
response = self.secret_manager_client.get_secret_value(**kwargs)
def get_service_secrets(self) -> dict[str, Any]:
response = self.secret_manager_client.get_secret_value(SecretId=self.secret_name)
secret_object = json.loads(response["SecretString"])
secret_object["private_key"] = base64.b64decode(secret_object["private_key_b64"]).decode()

return secret_object

def create_jwt(self, now: int):
logger.info("create_jwt")
def create_jwt(self, now: int) -> str:
secret_object = self.get_service_secrets()
claims = {
"iss": secret_object["api_key"],
"sub": secret_object["api_key"],
"aud": self.token_url,
"iat": now,
"exp": now + self.expiry,
"jti": str(uuid.uuid4()),
}

return jwt.encode(
claims,
{
"iss": secret_object["api_key"],
"sub": secret_object["api_key"],
"aud": self.token_url,
"iat": now,
"exp": now + JWT_EXPIRY_SECONDS,
"jti": str(uuid.uuid4()),
},
secret_object["private_key"],
algorithm="RS512",
headers={"kid": secret_object["kid"]},
)

def get_access_token(self):
logger.info("get_access_token")
def get_access_token(self) -> str:
now = int(time.time())
logger.info(f"Current time: {now}, Expiry time: {now + self.expiry}")
# Check if token is cached and not expired
logger.info(f"Cache key: {self.cache_key}")
logger.info("Checking cache for access token")
cached = self.cache.get(self.cache_key)

if cached and cached["expires_at"] > now:
logger.info("Returning cached access token")
return cached["token"]
if (
self.cached_access_token
and self.cached_access_token_expiry_time > now + ACCESS_TOKEN_MIN_ACCEPTABLE_LIFETIME_SECONDS
):
return self.cached_access_token

logger.info("No valid cached token found, creating new token")
logger.info("Requesting new access token")
_jwt = self.create_jwt(now)

headers = {"Content-Type": "application/x-www-form-urlencoded"}
data = {
"grant_type": "client_credentials",
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": _jwt,
}
token_response = requests.post(self.token_url, data=data, headers=headers)
try:
token_response = requests.post(
self.token_url,
data={
"grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
"client_assertion_type": CLIENT_ASSERTION_TYPE_JWT_BEARER,
"client_assertion": _jwt,
},
headers={"Content-Type": CONTENT_TYPE_X_WWW_FORM_URLENCODED},
timeout=10,
)
except requests.RequestException as error:
logger.exception("Failed to fetch access token from %s", self.token_url)
raise UnhandledResponseError(response=str(error), message="Failed to get access token") from error

if token_response.status_code != 200:
raise UnhandledResponseError(response=token_response.text, message="Failed to get access token")

token = token_response.json().get("access_token")

self.cache.put(self.cache_key, {"token": token, "expires_at": now + self.expiry})

self.cached_access_token = token
self.cached_access_token_expiry_time = now + ACCESS_TOKEN_EXPIRY_SECONDS
return token
27 changes: 15 additions & 12 deletions lambdas/shared/src/common/api_clients/get_pds_details.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,33 @@
"""

import os
import tempfile

from common.api_clients.authentication import AppRestrictedAuth, Service
from common.api_clients.authentication import AppRestrictedAuth
from common.api_clients.errors import PdsSyncException
from common.api_clients.pds_service import PdsService
from common.cache import Cache
from common.clients import get_secrets_manager_client, logger

PDS_ENV = os.getenv("PDS_ENV", "int")
safe_tmp_dir = tempfile.mkdtemp(dir="/tmp") # NOSONAR(S5443)

_pds_service: PdsService | None = None

# Get Patient details from external service PDS using NHS number from MNS notification
def pds_get_patient_details(nhs_number: str) -> dict:
try:
cache = Cache(directory=safe_tmp_dir)

def get_pds_service() -> PdsService:
global _pds_service
if _pds_service is None:
authenticator = AppRestrictedAuth(
service=Service.PDS,
secret_manager_client=get_secrets_manager_client(),
environment=PDS_ENV,
cache=cache,
)
pds_service = PdsService(authenticator, PDS_ENV)
patient = pds_service.get_patient_details(nhs_number)
_pds_service = PdsService(authenticator, PDS_ENV)

return _pds_service


# Get Patient details from external service PDS using NHS number from MNS notification
def pds_get_patient_details(nhs_number: str) -> dict:
try:
patient = get_pds_service().get_patient_details(nhs_number)
return patient
except Exception as e:
msg = "Error retrieving patient details from PDS"
Expand Down
6 changes: 3 additions & 3 deletions lambdas/shared/src/common/api_clients/mns_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
class MnsService:
def __init__(self, authenticator: AppRestrictedAuth):
self.authenticator = authenticator
self.access_token = self.authenticator.get_access_token()
logging.info(f"Using SQS ARN for subscription: {SQS_ARN}")

def _build_subscription_payload(self, event_type: str, reason: str | None = None, status: str = "requested") -> dict:
Expand Down Expand Up @@ -54,9 +53,10 @@ def _build_subscription_payload(self, event_type: str, reason: str | None = None

def _build_headers(self, content_type: str = "application/fhir+json") -> dict:
"""Build request headers with authentication and correlation ID."""
access_token = self.authenticator.get_access_token()
return {
"Content-Type": content_type,
"Authorization": f"Bearer {self.access_token}",
"Authorization": f"Bearer {access_token}",
"X-Correlation-ID": str(uuid.uuid4()),
}

Expand Down Expand Up @@ -138,7 +138,7 @@ def check_delete_subscription(self):
return f"Error deleting subscription: {str(e)}"

def publish_notification(self, notification_payload: MnsNotificationPayload) -> dict | None:
response = requests.request(
response = request_with_retry_backoff(
"POST",
f"{MNS_BASE_URL}/events",
headers=self._build_headers(content_type="application/cloudevents+json"),
Expand Down
6 changes: 1 addition & 5 deletions lambdas/shared/src/common/api_clients/mns_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
import boto3
from botocore.config import Config

from common.api_clients.authentication import AppRestrictedAuth, Service
from common.api_clients.authentication import AppRestrictedAuth
from common.api_clients.constants import DEV_ENVIRONMENT
from common.api_clients.mns_service import MnsService
from common.api_clients.mock_mns_service import MockMnsService
from common.cache import Cache

logging.basicConfig(level=logging.INFO)
MNS_TEST_QUEUE_URL = os.getenv("MNS_TEST_QUEUE_URL")
Expand All @@ -20,13 +19,10 @@ def get_mns_service(mns_env: str = "int"):
return MockMnsService(MNS_TEST_QUEUE_URL)
else:
boto_config = Config(region_name="eu-west-2")
cache = Cache(directory="/tmp")
logging.info("Creating authenticator...")
authenticator = AppRestrictedAuth(
service=Service.PDS,
secret_manager_client=boto3.client("secretsmanager", config=boto_config),
environment=mns_env,
cache=cache,
)
logging.info("Authentication Initiated...")
return MnsService(authenticator)
Loading
Loading