Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
SHELL=/usr/bin/env bash -euo pipefail

PYTHON_PROJECT_DIRS_WITH_UNIT_TESTS = lambdas/backend lambdas/ack_backend lambdas/batch_processor_filter lambdas/delta_backend lambdas/filenameprocessor lambdas/id_sync lambdas/mesh_processor lambdas/mns_subscription lambdas/recordforwarder lambdas/recordprocessor lambdas/redis_sync lambdas/shared
PYTHON_PROJECT_DIRS = tests/e2e_automation quality_checks $(PYTHON_PROJECT_DIRS_WITH_UNIT_TESTS)
PYTHON_PROJECT_DIRS = tests/e2e_automation tests/perf_tests quality_checks $(PYTHON_PROJECT_DIRS_WITH_UNIT_TESTS)

.PHONY: install lint format format-check clean publish oas build-proxy release initialise-all-python-venvs update-all-python-dependencies run-all-python-unit-tests build-all-docker-images

Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ See https://nhsd-confluence.digital.nhs.uk/display/APM/Glossary.
| `id_sync` | **Imms Cross-cutting** – Handles [MNS](https://digital.nhs.uk/developer/api-catalogue/multicast-notification-service) NHS Number Change events and applies updates to affected records. |
| `mesh_processor` | **Imms Batch** – Triggered when new files are received via MESH. Moves them into the Imms Batch processing system. |
| `mns_subscription` | **Imms Cross-cutting** – Simple helper Lambda which sets up our required MNS subscription. Used in pipelines in DEV. |
| `perf_tests` | **Imms API** – Locust performance tests for the Immunisation API. |
| `recordforwarder` | **Imms Batch** – Consumes from the stream and applies the processed batch file row operations (CUD) to IEDS. |
| `recordprocessor` | **Imms Batch** – ECS Task - **not** a Lambda function - responsible for processing batch file rows and forwarding to the stream. |
| `redis_sync` | **Imms Cross-cutting** – Handles config file updates. E.g. disease mapping or permission files. |
Expand Down
6 changes: 1 addition & 5 deletions lambdas/id_sync/src/pds_details.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@

import tempfile

from common.api_clients.authentication import AppRestrictedAuth, Service
from common.api_clients.authentication import AppRestrictedAuth
from common.api_clients.pds_service import PdsService
from common.cache import Cache
from common.clients import get_secrets_manager_client, logger
from exceptions.id_sync_exception import IdSyncException
from os_vars import get_pds_env
Expand All @@ -18,12 +17,9 @@
# Get Patient details from external service PDS using NHS number from MNS notification
def pds_get_patient_details(nhs_number: str) -> dict:
try:
cache = Cache(directory=safe_tmp_dir)
authenticator = AppRestrictedAuth(
service=Service.PDS,
secret_manager_client=get_secrets_manager_client(),
environment=pds_env,
cache=cache,
)
pds_service = PdsService(authenticator, pds_env)
patient = pds_service.get_patient_details(nhs_number)
Expand Down
29 changes: 0 additions & 29 deletions lambdas/id_sync/tests/test_pds_details.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@ def setUp(self):
self.mock_pds_env = self.pds_env_patcher.start()
self.mock_pds_env.return_value = "test-env"

self.cache_patcher = patch("pds_details.Cache")
self.mock_cache_class = self.cache_patcher.start()
self.mock_cache_instance = MagicMock()
self.mock_cache_class.return_value = self.mock_cache_instance

self.auth_patcher = patch("pds_details.AppRestrictedAuth")
self.mock_auth_class = self.auth_patcher.start()
self.mock_auth_instance = MagicMock()
Expand Down Expand Up @@ -57,9 +52,6 @@ def test_pds_get_patient_details_success(self):
# Assert
self.assertEqual(result["identifier"][0]["value"], "9912003888")

# Verify Cache was initialized correctly
self.mock_cache_class.assert_called_once()

# Verify get_patient_details was called
self.mock_pds_service_instance.get_patient_details.assert_called_once()

Expand Down Expand Up @@ -110,27 +102,6 @@ def test_pds_get_patient_details_pds_service_exception(self):

self.mock_pds_service_instance.get_patient_details.assert_called_once_with(self.test_patient_id)

def test_pds_get_patient_details_cache_initialization_error(self):
"""Test when Cache initialization fails"""
# Arrange
self.mock_cache_class.side_effect = OSError("Cannot write to /tmp")

# Act
with self.assertRaises(IdSyncException) as context:
pds_get_patient_details(self.test_patient_id)

# Assert
exception = context.exception
self.assertEqual(
exception.message,
"Error retrieving patient details from PDS",
)

# Verify exception was logged
self.mock_logger.exception.assert_called_once_with("Error retrieving patient details from PDS")

self.mock_cache_class.assert_called_once()

def test_pds_get_patient_details_auth_initialization_error(self):
"""Test when AppRestrictedAuth initialization fails"""
# Arrange
Expand Down
10 changes: 1 addition & 9 deletions lambdas/mns_subscription/src/mns_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,16 @@
import boto3
from botocore.config import Config

from common.api_clients.authentication import AppRestrictedAuth, Service
from common.api_clients.authentication import AppRestrictedAuth
from common.api_clients.mns_service import MnsService
from common.cache import Cache

logging.basicConfig(level=logging.INFO)


def get_mns_service(mns_env: str = "int"):
boto_config = Config(region_name="eu-west-2")
cache = Cache(directory="/tmp") # NOSONAR(S5443)
logging.info("Creating authenticator...")
# TODO: MNS and PDS need separate secrets
authenticator = AppRestrictedAuth(
service=Service.PDS,
secret_manager_client=boto3.client("secretsmanager", config=boto_config),
environment=mns_env,
cache=cache,
)

logging.info("Authentication Initiated...")
return MnsService(authenticator)
98 changes: 44 additions & 54 deletions lambdas/shared/src/common/api_clients/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,32 @@
import json
import time
import uuid
from enum import Enum

import jwt
import requests

from common.clients import logger
from common.models.errors import UnhandledResponseError

from ..cache import Cache
GRANT_TYPE_CLIENT_CREDENTIALS = "client_credentials"
CLIENT_ASSERTION_TYPE_JWT_BEARER = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
CONTENT_TYPE_X_WWW_FORM_URLENCODED = "application/x-www-form-urlencoded"


class Service(Enum):
PDS = "pds"
IMMUNIZATION = "imms"
JWT_EXPIRY_SECONDS = 5 * 60
ACCESS_TOKEN_EXPIRY_SECONDS = 10 * 60
# Throw away the cached token earlier than the exact expiry time so we have enough
# time left to use it (and to account for network latency, clock skew etc.)
ACCESS_TOKEN_MIN_ACCEPTABLE_LIFETIME_SECONDS = 30


class AppRestrictedAuth:
def __init__(self, service: Service, secret_manager_client, environment, cache: Cache):
def __init__(self, secret_manager_client, environment, secret_name=None):
self.secret_manager_client = secret_manager_client
self.cache = cache
self.cache_key = f"{service.value}_access_token"

self.expiry = 30
self.secret_name = (
f"imms/pds/{environment}/jwt-secrets"
if service == Service.PDS
else f"imms/immunization/{environment}/jwt-secrets"
)

self.cached_access_token = None
self.cached_access_token_expiry_time = None

self.secret_name = f"imms/pds/{environment}/jwt-secrets" if secret_name is None else secret_name

self.token_url = (
f"https://{environment}.api.service.nhs.uk/oauth2/token"
Expand All @@ -38,60 +36,52 @@ def __init__(self, service: Service, secret_manager_client, environment, cache:
)

def get_service_secrets(self):
kwargs = {"SecretId": self.secret_name}
response = self.secret_manager_client.get_secret_value(**kwargs)
response = self.secret_manager_client.get_secret_value(SecretId=self.secret_name)
secret_object = json.loads(response["SecretString"])
secret_object["private_key"] = base64.b64decode(secret_object["private_key_b64"]).decode()

return secret_object

def create_jwt(self, now: int):
logger.info("create_jwt")
secret_object = self.get_service_secrets()
claims = {
"iss": secret_object["api_key"],
"sub": secret_object["api_key"],
"aud": self.token_url,
"iat": now,
"exp": now + self.expiry,
"jti": str(uuid.uuid4()),
}

return jwt.encode(
claims,
{
"iss": secret_object["api_key"],
"sub": secret_object["api_key"],
"aud": self.token_url,
"iat": now,
"exp": now + JWT_EXPIRY_SECONDS,
"jti": str(uuid.uuid4()),
},
secret_object["private_key"],
algorithm="RS512",
headers={"kid": secret_object["kid"]},
)

def get_access_token(self):
logger.info("get_access_token")
now = int(time.time())
logger.info(f"Current time: {now}, Expiry time: {now + self.expiry}")
# Check if token is cached and not expired
logger.info(f"Cache key: {self.cache_key}")
logger.info("Checking cache for access token")
cached = self.cache.get(self.cache_key)

if cached and cached["expires_at"] > now:
logger.info("Returning cached access token")
return cached["token"]

logger.info("No valid cached token found, creating new token")
_jwt = self.create_jwt(now)

headers = {"Content-Type": "application/x-www-form-urlencoded"}
data = {
"grant_type": "client_credentials",
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": _jwt,
}
token_response = requests.post(self.token_url, data=data, headers=headers)

if (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great Matt :)

@Akol125 it is worth noting how this works e.g. see how it is referenced in locustfile. When this merges, we will need to rebase with these changes and use this approach.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, will pick up, once merged, thanks for pointing out.

self.cached_access_token
and self.cached_access_token_expiry_time > now + ACCESS_TOKEN_MIN_ACCEPTABLE_LIFETIME_SECONDS
):
return self.cached_access_token

logger.info("Requesting new access token")
jwt = self.create_jwt(now)

token_response = requests.post(
self.token_url,
data={
"grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
"client_assertion_type": CLIENT_ASSERTION_TYPE_JWT_BEARER,
"client_assertion": jwt,
},
headers={"Content-Type": CONTENT_TYPE_X_WWW_FORM_URLENCODED},
)
if token_response.status_code != 200:
raise UnhandledResponseError(response=token_response.text, message="Failed to get access token")

token = token_response.json().get("access_token")

self.cache.put(self.cache_key, {"token": token, "expires_at": now + self.expiry})

self.cached_access_token = token
self.cached_access_token_expiry_time = now + ACCESS_TOKEN_EXPIRY_SECONDS
return token
34 changes: 0 additions & 34 deletions lambdas/shared/src/common/cache.py

This file was deleted.

Loading