diff --git a/lambdas/id_sync/src/pds_details.py b/lambdas/id_sync/src/pds_details.py index 62ef6c247..27492ceb7 100644 --- a/lambdas/id_sync/src/pds_details.py +++ b/lambdas/id_sync/src/pds_details.py @@ -2,37 +2,6 @@ Operations related to PDS (Patient Demographic Service) """ -import tempfile - -from common.api_clients.authentication import AppRestrictedAuth, Service -from common.api_clients.pds_service import PdsService -from common.cache import Cache -from common.clients import get_secrets_manager_client, logger -from exceptions.id_sync_exception import IdSyncException -from os_vars import get_pds_env - -pds_env = get_pds_env() -safe_tmp_dir = tempfile.mkdtemp(dir="/tmp") - - -# Get Patient details from external service PDS using NHS number from MNS notification -def pds_get_patient_details(nhs_number: str) -> dict: - try: - cache = Cache(directory=safe_tmp_dir) - authenticator = AppRestrictedAuth( - service=Service.PDS, - secret_manager_client=get_secrets_manager_client(), - environment=pds_env, - cache=cache, - ) - pds_service = PdsService(authenticator, pds_env) - patient = pds_service.get_patient_details(nhs_number) - return patient - except Exception as e: - msg = "Error retrieving patient details from PDS" - logger.exception(msg) - raise IdSyncException(message=msg) from e - def get_nhs_number_from_pds_resource(pds_resource: dict) -> str: """Simple helper to get the NHS Number from a PDS Resource. No handling as this is a mandatory field in the PDS diff --git a/lambdas/mns_publisher/src/process_records.py b/lambdas/mns_publisher/src/process_records.py index e55924d70..594a5e5d7 100644 --- a/lambdas/mns_publisher/src/process_records.py +++ b/lambdas/mns_publisher/src/process_records.py @@ -1,6 +1,5 @@ import json import os -from typing import Tuple from aws_lambda_typing.events.sqs import SQSMessage @@ -11,7 +10,15 @@ from create_notification import create_mns_notification mns_env = os.getenv("MNS_ENV", "int") -MNS_TEST_QUEUE_URL = os.getenv("MNS_TEST_QUEUE_URL") +_mns_service: MnsService | MockMnsService | None = None + + +def _get_runtime_mns_service() -> MnsService | MockMnsService: + global _mns_service + if _mns_service is None: + _mns_service = get_mns_service(mns_env=mns_env) + + return _mns_service def process_records(records: list[SQSMessage]) -> dict[str, list]: @@ -21,7 +28,7 @@ def process_records(records: list[SQSMessage]) -> dict[str, list]: Returns: List of failed item identifiers for partial batch failure """ batch_item_failures = [] - mns_service = get_mns_service(mns_env=mns_env) + mns_service = _get_runtime_mns_service() for record in records: try: @@ -68,7 +75,7 @@ def process_record(record: SQSMessage, mns_service: MnsService | MockMnsService) logger.info("Successfully created MNS notification", extra={"mns_notification_id": notification_id}) -def extract_trace_ids(record: SQSMessage) -> Tuple[str, str | None]: +def extract_trace_ids(record: SQSMessage) -> tuple[str, str | None]: """ Extract identifiers for tracing from SQS record. Returns: Tuple of (message_id, immunisation_id) diff --git a/lambdas/mns_publisher/tests/test_lambda_handler.py b/lambdas/mns_publisher/tests/test_lambda_handler.py index 1602e0121..2c45682f1 100644 --- a/lambdas/mns_publisher/tests/test_lambda_handler.py +++ b/lambdas/mns_publisher/tests/test_lambda_handler.py @@ -125,7 +125,7 @@ def setUpClass(cls): cls.sample_sqs_record = load_sample_sqs_event() @patch("process_records.logger") - @patch("process_records.get_mns_service") + @patch("process_records._get_runtime_mns_service") @patch("process_records.process_record") def test_process_records_all_success(self, mock_process_record, mock_get_mns, mock_logger): """Test processing multiple records with all successes.""" @@ -145,7 +145,7 @@ def test_process_records_all_success(self, mock_process_record, mock_get_mns, mo mock_logger.info.assert_called_with("Successfully processed all 2 messages") @patch("process_records.logger") - @patch("process_records.get_mns_service") + @patch("process_records._get_runtime_mns_service") @patch("process_records.process_record") def test_process_records_partial_failure(self, mock_process_record, mock_get_mns, mock_logger): """Test processing with some failures.""" @@ -167,7 +167,7 @@ def test_process_records_partial_failure(self, mock_process_record, mock_get_mns mock_logger.warning.assert_called_with("Batch completed with 1 failures") @patch("process_records.logger") - @patch("process_records.get_mns_service") + @patch("process_records._get_runtime_mns_service") @patch("process_records.process_record") def test_process_records_empty_list(self, mock_process_record, mock_get_mns, mock_logger): """Test processing empty record list.""" @@ -181,7 +181,7 @@ def test_process_records_empty_list(self, mock_process_record, mock_get_mns, moc mock_logger.info.assert_called_with("Successfully processed all 0 messages") @patch("process_records.logger") - @patch("process_records.get_mns_service") + @patch("process_records._get_runtime_mns_service") @patch("process_records.process_record") def test_process_records_mns_service_created_once(self, mock_process_record, mock_get_mns, mock_logger): """Test that MNS service is created only once for batch.""" @@ -300,7 +300,7 @@ def test_successful_notification_creation_with_gp(self, mock_logger, mock_get_to @responses.activate @patch("common.api_clients.authentication.AppRestrictedAuth.get_access_token") - @patch("process_records.get_mns_service") + @patch("process_records._get_runtime_mns_service") @patch("process_records.logger") def test_pds_failure(self, mock_logger, mock_get_mns, mock_get_token): """ diff --git a/lambdas/shared/src/common/api_clients/authentication.py b/lambdas/shared/src/common/api_clients/authentication.py index 396d41c19..ec4a9416f 100644 --- a/lambdas/shared/src/common/api_clients/authentication.py +++ b/lambdas/shared/src/common/api_clients/authentication.py @@ -2,7 +2,7 @@ import json import time import uuid -from enum import Enum +from typing import Any import jwt import requests @@ -10,26 +10,25 @@ from common.clients import logger from common.models.errors import UnhandledResponseError -from ..cache import Cache +GRANT_TYPE_CLIENT_CREDENTIALS = "client_credentials" +CLIENT_ASSERTION_TYPE_JWT_BEARER = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" +CONTENT_TYPE_X_WWW_FORM_URLENCODED = "application/x-www-form-urlencoded" - -class Service(Enum): - PDS = "pds" - IMMUNIZATION = "imms" +JWT_EXPIRY_SECONDS = 5 * 60 +ACCESS_TOKEN_EXPIRY_SECONDS = 10 * 60 +# Throw away the cached token earlier than the exact expiry time so we have enough +# time left to use it (and to account for network latency, clock skew etc.) +ACCESS_TOKEN_MIN_ACCEPTABLE_LIFETIME_SECONDS = 30 class AppRestrictedAuth: - def __init__(self, service: Service, secret_manager_client, environment, cache: Cache): + def __init__(self, secret_manager_client: Any, environment: str, secret_name: str | None = None): self.secret_manager_client = secret_manager_client - self.cache = cache - self.cache_key = f"{service.value}_access_token" - - self.expiry = 30 - self.secret_name = ( - f"imms/pds/{environment}/jwt-secrets" - if service == Service.PDS - else f"imms/immunization/{environment}/jwt-secrets" - ) + + self.cached_access_token: str | None = None + self.cached_access_token_expiry_time: int | None = None + + self.secret_name = f"imms/pds/{environment}/jwt-secrets" if secret_name is None else secret_name self.token_url = ( f"https://{environment}.api.service.nhs.uk/oauth2/token" @@ -37,61 +36,59 @@ def __init__(self, service: Service, secret_manager_client, environment, cache: else "https://api.service.nhs.uk/oauth2/token" ) - def get_service_secrets(self): - kwargs = {"SecretId": self.secret_name} - response = self.secret_manager_client.get_secret_value(**kwargs) + def get_service_secrets(self) -> dict[str, Any]: + response = self.secret_manager_client.get_secret_value(SecretId=self.secret_name) secret_object = json.loads(response["SecretString"]) secret_object["private_key"] = base64.b64decode(secret_object["private_key_b64"]).decode() - return secret_object - def create_jwt(self, now: int): - logger.info("create_jwt") + def create_jwt(self, now: int) -> str: secret_object = self.get_service_secrets() - claims = { - "iss": secret_object["api_key"], - "sub": secret_object["api_key"], - "aud": self.token_url, - "iat": now, - "exp": now + self.expiry, - "jti": str(uuid.uuid4()), - } - return jwt.encode( - claims, + { + "iss": secret_object["api_key"], + "sub": secret_object["api_key"], + "aud": self.token_url, + "iat": now, + "exp": now + JWT_EXPIRY_SECONDS, + "jti": str(uuid.uuid4()), + }, secret_object["private_key"], algorithm="RS512", headers={"kid": secret_object["kid"]}, ) - def get_access_token(self): - logger.info("get_access_token") + def get_access_token(self) -> str: now = int(time.time()) - logger.info(f"Current time: {now}, Expiry time: {now + self.expiry}") - # Check if token is cached and not expired - logger.info(f"Cache key: {self.cache_key}") - logger.info("Checking cache for access token") - cached = self.cache.get(self.cache_key) - if cached and cached["expires_at"] > now: - logger.info("Returning cached access token") - return cached["token"] + if ( + self.cached_access_token + and self.cached_access_token_expiry_time > now + ACCESS_TOKEN_MIN_ACCEPTABLE_LIFETIME_SECONDS + ): + return self.cached_access_token - logger.info("No valid cached token found, creating new token") + logger.info("Requesting new access token") _jwt = self.create_jwt(now) - headers = {"Content-Type": "application/x-www-form-urlencoded"} - data = { - "grant_type": "client_credentials", - "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", - "client_assertion": _jwt, - } - token_response = requests.post(self.token_url, data=data, headers=headers) + try: + token_response = requests.post( + self.token_url, + data={ + "grant_type": GRANT_TYPE_CLIENT_CREDENTIALS, + "client_assertion_type": CLIENT_ASSERTION_TYPE_JWT_BEARER, + "client_assertion": _jwt, + }, + headers={"Content-Type": CONTENT_TYPE_X_WWW_FORM_URLENCODED}, + timeout=10, + ) + except requests.RequestException as error: + logger.exception("Failed to fetch access token from %s", self.token_url) + raise UnhandledResponseError(response=str(error), message="Failed to get access token") from error + if token_response.status_code != 200: raise UnhandledResponseError(response=token_response.text, message="Failed to get access token") token = token_response.json().get("access_token") - - self.cache.put(self.cache_key, {"token": token, "expires_at": now + self.expiry}) - + self.cached_access_token = token + self.cached_access_token_expiry_time = now + ACCESS_TOKEN_EXPIRY_SECONDS return token diff --git a/lambdas/shared/src/common/api_clients/get_pds_details.py b/lambdas/shared/src/common/api_clients/get_pds_details.py index 63844b3cd..7f728c81e 100644 --- a/lambdas/shared/src/common/api_clients/get_pds_details.py +++ b/lambdas/shared/src/common/api_clients/get_pds_details.py @@ -3,30 +3,33 @@ """ import os -import tempfile -from common.api_clients.authentication import AppRestrictedAuth, Service +from common.api_clients.authentication import AppRestrictedAuth from common.api_clients.errors import PdsSyncException from common.api_clients.pds_service import PdsService -from common.cache import Cache from common.clients import get_secrets_manager_client, logger PDS_ENV = os.getenv("PDS_ENV", "int") -safe_tmp_dir = tempfile.mkdtemp(dir="/tmp") # NOSONAR(S5443) +_pds_service: PdsService | None = None -# Get Patient details from external service PDS using NHS number from MNS notification -def pds_get_patient_details(nhs_number: str) -> dict: - try: - cache = Cache(directory=safe_tmp_dir) + +def get_pds_service() -> PdsService: + global _pds_service + if _pds_service is None: authenticator = AppRestrictedAuth( - service=Service.PDS, secret_manager_client=get_secrets_manager_client(), environment=PDS_ENV, - cache=cache, ) - pds_service = PdsService(authenticator, PDS_ENV) - patient = pds_service.get_patient_details(nhs_number) + _pds_service = PdsService(authenticator, PDS_ENV) + + return _pds_service + + +# Get Patient details from external service PDS using NHS number from MNS notification +def pds_get_patient_details(nhs_number: str) -> dict: + try: + patient = get_pds_service().get_patient_details(nhs_number) return patient except Exception as e: msg = "Error retrieving patient details from PDS" diff --git a/lambdas/shared/src/common/api_clients/mns_service.py b/lambdas/shared/src/common/api_clients/mns_service.py index 417fef6ad..263c58e01 100644 --- a/lambdas/shared/src/common/api_clients/mns_service.py +++ b/lambdas/shared/src/common/api_clients/mns_service.py @@ -25,7 +25,6 @@ class MnsService: def __init__(self, authenticator: AppRestrictedAuth): self.authenticator = authenticator - self.access_token = self.authenticator.get_access_token() logging.info(f"Using SQS ARN for subscription: {SQS_ARN}") def _build_subscription_payload(self, event_type: str, reason: str | None = None, status: str = "requested") -> dict: @@ -54,9 +53,10 @@ def _build_subscription_payload(self, event_type: str, reason: str | None = None def _build_headers(self, content_type: str = "application/fhir+json") -> dict: """Build request headers with authentication and correlation ID.""" + access_token = self.authenticator.get_access_token() return { "Content-Type": content_type, - "Authorization": f"Bearer {self.access_token}", + "Authorization": f"Bearer {access_token}", "X-Correlation-ID": str(uuid.uuid4()), } @@ -138,7 +138,7 @@ def check_delete_subscription(self): return f"Error deleting subscription: {str(e)}" def publish_notification(self, notification_payload: MnsNotificationPayload) -> dict | None: - response = requests.request( + response = request_with_retry_backoff( "POST", f"{MNS_BASE_URL}/events", headers=self._build_headers(content_type="application/cloudevents+json"), diff --git a/lambdas/shared/src/common/api_clients/mns_setup.py b/lambdas/shared/src/common/api_clients/mns_setup.py index 5cecd4440..8df2ec4b0 100644 --- a/lambdas/shared/src/common/api_clients/mns_setup.py +++ b/lambdas/shared/src/common/api_clients/mns_setup.py @@ -4,11 +4,10 @@ import boto3 from botocore.config import Config -from common.api_clients.authentication import AppRestrictedAuth, Service +from common.api_clients.authentication import AppRestrictedAuth from common.api_clients.constants import DEV_ENVIRONMENT from common.api_clients.mns_service import MnsService from common.api_clients.mock_mns_service import MockMnsService -from common.cache import Cache logging.basicConfig(level=logging.INFO) MNS_TEST_QUEUE_URL = os.getenv("MNS_TEST_QUEUE_URL") @@ -20,13 +19,10 @@ def get_mns_service(mns_env: str = "int"): return MockMnsService(MNS_TEST_QUEUE_URL) else: boto_config = Config(region_name="eu-west-2") - cache = Cache(directory="/tmp") logging.info("Creating authenticator...") authenticator = AppRestrictedAuth( - service=Service.PDS, secret_manager_client=boto3.client("secretsmanager", config=boto_config), environment=mns_env, - cache=cache, ) logging.info("Authentication Initiated...") return MnsService(authenticator) diff --git a/lambdas/shared/src/common/cache.py b/lambdas/shared/src/common/cache.py deleted file mode 100644 index 94fd9abbd..000000000 --- a/lambdas/shared/src/common/cache.py +++ /dev/null @@ -1,33 +0,0 @@ -import json - - -class Cache: - """Key-value file cache""" - - def __init__(self, directory): - filename = f"{directory}/cache.json" - with open(filename, "a+") as self.cache_file: - self.cache_file.seek(0) - content = self.cache_file.read() - if len(content) == 0: - self.cache_dict = {} - else: - self.cache_dict = json.loads(content) - - def put(self, key: str, value: dict): - self.cache_dict[key] = value - self._overwrite() - - def get(self, key: str) -> dict | None: - return self.cache_dict.get(key, None) - - def delete(self, key: str): - if key not in self.cache_dict: - return - del self.cache_dict[key] - - def _overwrite(self): - with open(self.cache_file.name, "w") as self.cache_file: - self.cache_file.seek(0) - self.cache_file.write(json.dumps(self.cache_dict)) - self.cache_file.truncate() diff --git a/lambdas/shared/tests/test_common/api_clients/test_authentication.py b/lambdas/shared/tests/test_common/api_clients/test_authentication.py index 11fc2e1d8..d77929705 100644 --- a/lambdas/shared/tests/test_common/api_clients/test_authentication.py +++ b/lambdas/shared/tests/test_common/api_clients/test_authentication.py @@ -7,7 +7,7 @@ import responses from responses import matchers -from common.api_clients.authentication import AppRestrictedAuth, Service +from common.api_clients.authentication import ACCESS_TOKEN_EXPIRY_SECONDS, AppRestrictedAuth from common.models.errors import UnhandledResponseError @@ -29,11 +29,8 @@ def setUp(self): self.secret_manager_client = MagicMock() self.secret_manager_client.get_secret_value.return_value = secret_response - self.cache = MagicMock() - self.cache.get.return_value = None - env = "an-env" - self.authenticator = AppRestrictedAuth(Service.PDS, self.secret_manager_client, env, self.cache) + self.authenticator = AppRestrictedAuth(self.secret_manager_client, env) self.url = f"https://{env}.api.service.nhs.uk/oauth2/token" @responses.activate @@ -89,35 +86,30 @@ def test_env_mapping(self): """it should target int environment for none-prod environment, otherwise int""" # For env=none-prod env = "some-env" - auth = AppRestrictedAuth(Service.PDS, None, env, None) + auth = AppRestrictedAuth(None, env) self.assertTrue(auth.token_url.startswith(f"https://{env}.")) # For env=prod env = "prod" - auth = AppRestrictedAuth(Service.PDS, None, env, None) - self.assertTrue(env not in auth.token_url) + auth = AppRestrictedAuth(None, env) + self.assertNotIn(env, auth.token_url) def test_returned_cached_token(self): """it should return cached token""" - cached_token = { - "token": "a-cached-access-token", - "expires_at": int(time.time()) + 99999, # make sure it's not expired - } - self.cache.get.return_value = cached_token + self.authenticator.cached_access_token = "a-cached-access-token" + self.authenticator.cached_access_token_expiry_time = int(time.time()) + 99999 # make sure it's not expired # When token = self.authenticator.get_access_token() # Then - self.assertEqual(token, cached_token["token"]) + self.assertEqual(token, "a-cached-access-token") self.secret_manager_client.assert_not_called() @responses.activate def test_update_cache(self): """it should update cached token""" - self.cache.get.return_value = None token = "a-new-access-token" - cached_token = {"token": token, "expires_at": ANY} responses.add(responses.POST, self.url, status=200, json={"access_token": token}) with patch("jwt.encode") as mock_jwt: @@ -126,18 +118,15 @@ def test_update_cache(self): self.authenticator.get_access_token() # Then - self.cache.put.assert_called_once_with(f"{Service.PDS.value}_access_token", cached_token) + self.assertEqual(self.authenticator.cached_access_token, "a-new-access-token") @responses.activate def test_expired_token_in_cache(self): """it should not return cached access token if it's expired""" now_epoch = 12345 - expires_at = now_epoch + self.authenticator.expiry - cached_token = { - "token": "an-expired-cached-access-token", - "expires_at": expires_at, - } - self.cache.get.return_value = cached_token + expires_at = now_epoch + ACCESS_TOKEN_EXPIRY_SECONDS + self.authenticator.cached_access_token = ("an-expired-cached-access-token",) + self.authenticator.cached_access_token_expiry_time = expires_at new_token = "a-new-token" responses.add(responses.POST, self.url, status=200, json={"access_token": new_token}) @@ -151,42 +140,12 @@ def test_expired_token_in_cache(self): self.authenticator.get_access_token() # Then - exp_cached_token = { - "token": new_token, - "expires_at": new_now + self.authenticator.expiry, - } - self.cache.put.assert_called_once_with(ANY, exp_cached_token) - - @responses.activate - def test_uses_cache_for_token(self): - """it should use the cache for the `Service` auth call""" - - token = "a-new-access-token" - token_call = responses.add(responses.POST, self.url, status=200, json={"access_token": token}) - values = {} - - def get_side_effect(key): - return values.get(key, None) - - def put_side_effect(key, value): - values[key] = value - - self.cache.get.side_effect = get_side_effect - self.cache.put.side_effect = put_side_effect - - with patch("common.api_clients.authentication.jwt.encode") as mock_jwt: - mock_jwt.return_value = "a-jwt" - # When - self.assertEqual(0, token_call.call_count) - self.authenticator.get_access_token() - self.assertEqual(1, token_call.call_count) - self.authenticator.get_access_token() - self.assertEqual(1, token_call.call_count) + self.assertEqual(self.authenticator.cached_access_token, new_token) + self.assertEqual(self.authenticator.cached_access_token_expiry_time, new_now + ACCESS_TOKEN_EXPIRY_SECONDS) @responses.activate def test_raise_exception(self): """it should raise exception if auth response is not 200""" - self.cache.get.return_value = None responses.add(responses.POST, self.url, status=400) with patch("common.api_clients.authentication.jwt.encode") as mock_jwt: diff --git a/lambdas/shared/tests/test_common/api_clients/test_mns_service.py b/lambdas/shared/tests/test_common/api_clients/test_mns_service.py index 3cc9daab9..b1f6191ad 100644 --- a/lambdas/shared/tests/test_common/api_clients/test_mns_service.py +++ b/lambdas/shared/tests/test_common/api_clients/test_mns_service.py @@ -50,7 +50,7 @@ def test_successful_subscription(self, mock_request): # Assert self.assertEqual(result, {"subscriptionId": "abc123"}) self.assertEqual(mock_request.call_count, 2) - self.authenticator.get_access_token.assert_called_once() + self.assertGreaterEqual(self.authenticator.get_access_token.call_count, 1) @patch("common.api_clients.mns_service.requests.request") def test_not_found_subscription(self, mock_request): @@ -293,13 +293,13 @@ def test_unhandled_status_code(self): self.assertIn("Unhandled error: 418", str(context.exception)) self.assertEqual(context.exception.response, {"resource": 1234}) - @patch("common.api_clients.mns_service.requests.request") - def test_publish_notification_success(self, mock_request): + @patch("common.api_clients.mns_service.request_with_retry_backoff") + def test_publish_notification_success(self, mock_request_with_retry_backoff): """Test successful notification publishing.""" mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = {"status": "published"} - mock_request.return_value = mock_response + mock_request_with_retry_backoff.return_value = mock_response notification_payload = { "specversion": "1.0", @@ -313,27 +313,27 @@ def test_publish_notification_success(self, mock_request): self.assertEqual(result["status"], "published") - # Verify the request was made correctly - mock_request.assert_called_once() - call_args = mock_request.call_args + # Verify the request was made correctly through retry helper + mock_request_with_retry_backoff.assert_called_once() + call_args = mock_request_with_retry_backoff.call_args headers = call_args[1]["headers"] self.assertEqual(headers["Content-Type"], "application/cloudevents+json") - mock_request.assert_called_once() - @patch("common.api_clients.mns_service.requests.request") + @patch("common.api_clients.mns_service.request_with_retry_backoff") @patch("common.api_clients.mns_service.raise_error_response") - def test_publish_notification_failure(self, mock_raise_error, mock_request): + def test_publish_notification_failure(self, mock_raise_error, mock_request_with_retry_backoff): """Test notification publishing failure.""" mock_response = Mock() mock_response.status_code = 400 - mock_request.return_value = mock_response + mock_request_with_retry_backoff.return_value = mock_response notification_payload = {"id": "test-id"} service = MnsService(self.authenticator) service.publish_notification(notification_payload) + mock_request_with_retry_backoff.assert_called_once() mock_raise_error.assert_called_once_with(mock_response) diff --git a/lambdas/shared/tests/test_common/api_clients/test_pds_details.py b/lambdas/shared/tests/test_common/api_clients/test_pds_details.py index f833c10d0..e58b430ee 100644 --- a/lambdas/shared/tests/test_common/api_clients/test_pds_details.py +++ b/lambdas/shared/tests/test_common/api_clients/test_pds_details.py @@ -2,30 +2,17 @@ from unittest.mock import MagicMock, patch from common.api_clients.errors import PdsSyncException -from common.api_clients.get_pds_details import pds_get_patient_details +from common.api_clients.get_pds_details import get_pds_service, pds_get_patient_details class TestGetPdsPatientDetails(unittest.TestCase): def setUp(self): - """Set up test fixtures and mocks""" self.test_patient_id = "9912003888" + get_pds_service.__globals__["_pds_service"] = None - # Patch all external dependencies self.logger_patcher = patch("common.api_clients.get_pds_details.logger") self.mock_logger = self.logger_patcher.start() - self.secrets_manager_patcher = patch("common.clients.global_secrets_manager_client") - self.mock_secrets_manager = self.secrets_manager_patcher.start() - - self.pds_env_patcher = patch("os.getenv") - self.mock_pds_env = self.pds_env_patcher.start() - self.mock_pds_env.return_value = "test-env" - - self.cache_patcher = patch("common.api_clients.get_pds_details.Cache") - self.mock_cache_class = self.cache_patcher.start() - self.mock_cache_instance = MagicMock() - self.mock_cache_class.return_value = self.mock_cache_instance - self.auth_patcher = patch("common.api_clients.get_pds_details.AppRestrictedAuth") self.mock_auth_class = self.auth_patcher.start() self.mock_auth_instance = MagicMock() @@ -37,12 +24,10 @@ def setUp(self): self.mock_pds_service_class.return_value = self.mock_pds_service_instance def tearDown(self): - """Clean up patches""" + get_pds_service.__globals__["_pds_service"] = None patch.stopall() def test_pds_get_patient_details_success(self): - """Test successful retrieval of patient details""" - # Arrange expected_patient_data = { "identifier": [{"value": "9912003888"}], "name": "John Doe", @@ -51,159 +36,55 @@ def test_pds_get_patient_details_success(self): } self.mock_pds_service_instance.get_patient_details.return_value = expected_patient_data - # Act result = pds_get_patient_details(self.test_patient_id) - # Assert self.assertEqual(result["identifier"][0]["value"], "9912003888") - - # Verify Cache was initialized correctly - self.mock_cache_class.assert_called_once() - - # Verify get_patient_details was called - self.mock_pds_service_instance.get_patient_details.assert_called_once() + self.mock_auth_class.assert_called_once() + self.mock_pds_service_class.assert_called_once() def test_pds_get_patient_details_no_patient_found(self): - """Test when PDS returns None (no patient found)""" - # Arrange self.mock_pds_service_instance.get_patient_details.return_value = None - # Act result = pds_get_patient_details(self.test_patient_id) - # Assert self.assertIsNone(result) - self.mock_pds_service_instance.get_patient_details.assert_called_once_with(self.test_patient_id) - def test_pds_get_patient_details_empty_response(self): - """Test when PDS returns empty dict (falsy)""" - # Arrange - self.mock_pds_service_instance.get_patient_details.return_value = None - - # Act - result = pds_get_patient_details(self.test_patient_id) - - # Assert - self.assertIsNone(result) - def test_pds_get_patient_details_pds_service_exception(self): - """Test when PdsService.get_patient_details raises an exception""" - # Arrange mock_exception = Exception("My custom error") self.mock_pds_service_instance.get_patient_details.side_effect = mock_exception - # Act with self.assertRaises(PdsSyncException) as context: pds_get_patient_details(self.test_patient_id) exception = context.exception - # Assert self.assertEqual( exception.message, "Error retrieving patient details from PDS", ) - # Verify exception was logged self.mock_logger.exception.assert_called_once_with("Error retrieving patient details from PDS") - self.mock_pds_service_instance.get_patient_details.assert_called_once_with(self.test_patient_id) - def test_pds_get_patient_details_cache_initialization_error(self): - """Test when Cache initialization fails""" - # Arrange - self.mock_cache_class.side_effect = OSError("Cannot write to /tmp") - - # Act - with self.assertRaises(PdsSyncException) as context: - pds_get_patient_details(self.test_patient_id) - - # Assert - exception = context.exception - self.assertEqual( - exception.message, - "Error retrieving patient details from PDS", - ) - - # Verify exception was logged - self.mock_logger.exception.assert_called_once_with("Error retrieving patient details from PDS") - - self.mock_cache_class.assert_called_once() - def test_pds_get_patient_details_auth_initialization_error(self): - """Test when AppRestrictedAuth initialization fails""" - # Arrange self.mock_auth_class.side_effect = ValueError("Invalid authentication parameters") - # Act with self.assertRaises(PdsSyncException) as context: pds_get_patient_details(self.test_patient_id) - # Assert exception = context.exception self.assertEqual( exception.message, "Error retrieving patient details from PDS", ) - # Verify exception was logged self.mock_logger.exception.assert_called_once_with("Error retrieving patient details from PDS") - def test_pds_get_patient_details_exception(self): - """Test when logger.info throws an exception""" - # Arrange - test_exception = Exception("some-random-error") - self.mock_pds_service_class.side_effect = test_exception - test_nhs_number = "another-nhs-number" - - # Act - with self.assertRaises(Exception) as context: - pds_get_patient_details(test_nhs_number) - - exception = context.exception - # Assert - self.assertEqual( - exception.message, - "Error retrieving patient details from PDS", - ) - # Verify logger.exception was called due to the caught exception - self.mock_logger.exception.assert_called_once_with("Error retrieving patient details from PDS") + def test_reuses_same_pds_service_instance(self): + pds_get_patient_details("1111111111") + pds_get_patient_details("2222222222") - def test_pds_get_patient_details_different_patient_ids(self): - """Test with different patient ID formats""" - test_cases = [ - ("9912003888", {"identifier": [{"value": "9912003888"}]}), - ("1234567890", {"identifier": [{"value": "1234567890"}]}), - ("0000000000", {"identifier": [{"value": "0000000000"}]}), - ] - - for patient_id, expected_response in test_cases: - with self.subTest(patient_id=patient_id): - # Reset mocks - self.mock_pds_service_instance.reset_mock() - self.mock_logger.reset_mock() - - # Arrange - self.mock_pds_service_instance.get_patient_details.return_value = expected_response - - # Act - result = pds_get_patient_details(patient_id) - - # Assert - self.assertEqual(result, expected_response) - self.mock_pds_service_instance.get_patient_details.assert_called_once_with(patient_id) - - def test_pds_get_patient_details(self): - """Test with complex identifier structure""" - # Arrange - test_nhs_number = "9912003888" - pds_id = "abcefghijkl" - mock_pds_response = {"identifier": [{"value": pds_id}]} - self.mock_pds_service_instance.get_patient_details.return_value = mock_pds_response - # Act - result = pds_get_patient_details(test_nhs_number) - - # Assert - function should extract the value from first identifier - self.assertEqual(result, mock_pds_response) - self.mock_pds_service_instance.get_patient_details.assert_called_once_with(test_nhs_number) + self.mock_auth_class.assert_called_once() + self.mock_pds_service_class.assert_called_once() + self.assertEqual(self.mock_pds_service_instance.get_patient_details.call_count, 2) diff --git a/lambdas/shared/tests/test_common/test_cache.py b/lambdas/shared/tests/test_common/test_cache.py deleted file mode 100644 index 8125099ac..000000000 --- a/lambdas/shared/tests/test_common/test_cache.py +++ /dev/null @@ -1,88 +0,0 @@ -import json -import os -import tempfile -import unittest - -from src.common.cache import Cache - - -class TestCache(unittest.TestCase): - def setUp(self): - self.cache = Cache(tempfile.gettempdir()) - - def test_cache_put(self): - """it should store cache in specified key""" - value = {"foo": "a-foo", "bar": 42} - key = "a_key" - - # When - self.cache.put(key, value) - act_value = self.cache.get(key) - - # Then - self.assertDictEqual(value, act_value) - - def test_cache_put_overwrite(self): - """it should store updated cache value""" - value = {"foo": "a-foo", "bar": 42} - key = "a_key" - self.cache.put(key, value) - - new_value = {"foo": "new-foo"} - self.cache.put(key, new_value) - - # When - updated_value = self.cache.get(key) - - # Then - self.assertDictEqual(new_value, updated_value) - - def test_key_not_found(self): - """it should return None if key doesn't exist""" - value = self.cache.get("it-does-not-exist") - self.assertIsNone(value) - - def test_delete(self): - """it should delete key""" - key = "a_key" - self.cache.put(key, {"a": "b"}) - self.cache.delete(key) - - value = self.cache.get(key) - self.assertIsNone(value) - - def test_delete_key_not_found(self): - """it should return None gracefully if key doesn't exist""" - value = self.cache.delete("it-does-not-exist") - self.assertIsNone(value) - - def test_write_to_file(self): - """it should update the cache file""" - value = {"foo": "a-long-foo-so-to-make-sure-truncate-is-working", "bar": 42} - key = "a_key" - self.cache.put(key, value) - # Add one and delete to make sure file gets updated - self.cache.put("to-delete-key", {"x": "y"}) - self.cache.delete("to-delete-key") - - # When - new_value = {"a": "b"} - self.cache.put(key, new_value) - - # Then - with open(self.cache.cache_file.name) as stored: - content = json.loads(stored.read()) - self.assertDictEqual(content[key], new_value) - - def test_cache_create_empty(self): - """it should gracefully create an empty cache""" - filename = f"{tempfile.gettempdir()}/cache.json" - os.remove(filename) - - # When - self.cache = Cache(tempfile.gettempdir()) - - # Then - with open(self.cache.cache_file.name) as stored: - content = stored.read() - self.assertEqual(len(content), 0)