diff --git a/pathology-api/lambda_handler.py b/pathology-api/lambda_handler.py index a769da89..6b82a9cb 100644 --- a/pathology-api/lambda_handler.py +++ b/pathology-api/lambda_handler.py @@ -13,6 +13,7 @@ from pathology_api.fhir.r4.resources import Bundle, OperationOutcome from pathology_api.handler import handle_request from pathology_api.logging import get_logger +from pathology_api.request_context import set_correlation_id _logger = get_logger(__name__) @@ -102,8 +103,16 @@ def status() -> Response[str]: return Response(status_code=200, body="OK", headers={"Content-Type": "text/plain"}) +_CORRELATION_ID_HEADER = "nhsd-correlation-id" + + @app.post("/FHIR/R4/Bundle") def post_result() -> Response[str]: + correlation_id = app.current_event.headers.get(_CORRELATION_ID_HEADER) + if not correlation_id: + raise ValidationError(f"Missing required header: {_CORRELATION_ID_HEADER}") + set_correlation_id(correlation_id) + _logger.debug("Post result endpoint called.") try: diff --git a/pathology-api/src/pathology_api/logging.py b/pathology-api/src/pathology_api/logging.py index d094698c..fc59087e 100644 --- a/pathology-api/src/pathology_api/logging.py +++ b/pathology-api/src/pathology_api/logging.py @@ -1,7 +1,18 @@ +import logging from typing import Any, Protocol from aws_lambda_powertools import Logger +from pathology_api.request_context import get_correlation_id + + +class _CorrelationIdFilter(logging.Filter): + """Injects the current correlation ID into every log record.""" + + def filter(self, record: logging.LogRecord) -> bool: + record.correlation_id = get_correlation_id() + return True + class LogProvider(Protocol): """Protocol defining required contract for a logger.""" @@ -19,4 +30,6 @@ def exception(self, msg: str, *args: Any, **kwargs: Any) -> None: ... def get_logger(service: str) -> LogProvider: """Get a configured logger instance.""" - return Logger(service=service, level="DEBUG", serialize_stacktrace=True) + logger = Logger(service=service, level="DEBUG", serialize_stacktrace=True) + logger.addFilter(_CorrelationIdFilter()) + return logger diff --git a/pathology-api/src/pathology_api/request_context.py b/pathology-api/src/pathology_api/request_context.py new file mode 100644 index 00000000..10422a3b --- /dev/null +++ b/pathology-api/src/pathology_api/request_context.py @@ -0,0 +1,13 @@ +from contextvars import ContextVar + +_correlation_id: ContextVar[str] = ContextVar("correlation_id", default="") + + +def set_correlation_id(value: str) -> None: + """Set the correlation ID for the current request context.""" + _correlation_id.set(value) + + +def get_correlation_id() -> str: + """Get the correlation ID for the current request context.""" + return _correlation_id.get() diff --git a/pathology-api/src/pathology_api/test_logging.py b/pathology-api/src/pathology_api/test_logging.py new file mode 100644 index 00000000..98b10714 --- /dev/null +++ b/pathology-api/src/pathology_api/test_logging.py @@ -0,0 +1,54 @@ +import logging + +from pathology_api.logging import ( + _CorrelationIdFilter, + get_logger, +) +from pathology_api.request_context import set_correlation_id + + +class TestCorrelationIdFilter: + def test_filter_injects_correlation_id_into_log_record(self) -> None: + set_correlation_id("test-abc-123") + + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=1, + msg="test message", + args=None, + exc_info=None, + ) + + f = _CorrelationIdFilter() + result = f.filter(record) + + assert result is True + assert record.correlation_id == "test-abc-123" # type: ignore[attr-defined] + + def test_filter_uses_empty_default_when_no_correlation_id_set(self) -> None: + set_correlation_id("") + + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=1, + msg="test message", + args=None, + exc_info=None, + ) + + f = _CorrelationIdFilter() + f.filter(record) + + assert record.correlation_id == "" # type: ignore[attr-defined] + + +class TestGetLogger: + def test_get_logger_attaches_correlation_id_filter(self) -> None: + logger = get_logger("test-service") + + filters = getattr(logger, "filters", []) + assert any(isinstance(f, _CorrelationIdFilter) for f in filters) diff --git a/pathology-api/src/pathology_api/test_request_context.py b/pathology-api/src/pathology_api/test_request_context.py new file mode 100644 index 00000000..9ccaf380 --- /dev/null +++ b/pathology-api/src/pathology_api/test_request_context.py @@ -0,0 +1,11 @@ +from pathology_api.request_context import get_correlation_id, set_correlation_id + + +class TestSetAndGetCorrelationId: + def test_set_and_get_correlation_id(self) -> None: + set_correlation_id("round-trip-test-123") + assert get_correlation_id() == "round-trip-test-123" + + def test_default_correlation_id_is_empty(self) -> None: + set_correlation_id("") + assert get_correlation_id() == "" diff --git a/pathology-api/test_lambda_handler.py b/pathology-api/test_lambda_handler.py index 7f867aea..ed8cd020 100644 --- a/pathology-api/test_lambda_handler.py +++ b/pathology-api/test_lambda_handler.py @@ -8,6 +8,7 @@ from pathology_api.exception import ValidationError from pathology_api.fhir.r4.elements import LogicalReference, PatientIdentifier from pathology_api.fhir.r4.resources import Bundle, Composition, OperationOutcome +from pathology_api.request_context import get_correlation_id class TestHandler: @@ -16,9 +17,11 @@ def _create_test_event( body: str | None = None, path_params: str | None = None, request_method: str | None = None, + headers: dict[str, str] | None = None, ) -> dict[str, Any]: return { "body": body, + "headers": headers or {}, "requestContext": { "http": { "path": f"/{path_params}", @@ -58,6 +61,7 @@ def test_create_test_result_success(self) -> None: body=bundle.model_dump_json(by_alias=True), path_params="FHIR/R4/Bundle", request_method="POST", + headers={"nhsd-correlation-id": "test-correlation-id"}, ) context = LambdaContext() @@ -76,9 +80,72 @@ def test_create_test_result_success(self) -> None: # A UUID value so can only check its presence. assert response_bundle.id is not None + def test_correlation_id_is_set_from_request_header(self) -> None: + correlation_id = "test-correlation-id-abc-123" + bundle = Bundle.create( + type="document", + entry=[ + Bundle.Entry( + fullUrl="composition", + resource=Composition.create( + subject=LogicalReference( + PatientIdentifier.from_nhs_number("nhs_number") + ) + ), + ) + ], + ) + event = self._create_test_event( + body=bundle.model_dump_json(by_alias=True), + path_params="FHIR/R4/Bundle", + request_method="POST", + headers={"nhsd-correlation-id": correlation_id}, + ) + context = LambdaContext() + + handler(event, context) + + assert get_correlation_id() == correlation_id + + def test_missing_correlation_id_header_returns_400(self) -> None: + bundle = Bundle.create( + type="document", + entry=[ + Bundle.Entry( + fullUrl="composition", + resource=Composition.create( + subject=LogicalReference( + PatientIdentifier.from_nhs_number("nhs_number") + ) + ), + ) + ], + ) + event = self._create_test_event( + body=bundle.model_dump_json(by_alias=True), + path_params="FHIR/R4/Bundle", + request_method="POST", + ) + context = LambdaContext() + + response = handler(event, context) + + assert response["statusCode"] == 400 + assert response["headers"] == {"Content-Type": "application/fhir+json"} + + returned_issue = self._parse_returned_issue(response["body"]) + assert returned_issue["severity"] == "error" + assert returned_issue["code"] == "invalid" + assert ( + returned_issue["diagnostics"] + == "Missing required header: nhsd-correlation-id" + ) + def test_create_test_result_no_payload(self) -> None: event = self._create_test_event( - path_params="FHIR/R4/Bundle", request_method="POST" + path_params="FHIR/R4/Bundle", + request_method="POST", + headers={"nhsd-correlation-id": "test-correlation-id"}, ) context = LambdaContext() @@ -98,7 +165,10 @@ def test_create_test_result_no_payload(self) -> None: def test_create_test_result_empty_payload(self) -> None: event = self._create_test_event( - body="{}", path_params="FHIR/R4/Bundle", request_method="POST" + body="{}", + path_params="FHIR/R4/Bundle", + request_method="POST", + headers={"nhsd-correlation-id": "test-correlation-id"}, ) context = LambdaContext() @@ -118,7 +188,10 @@ def test_create_test_result_empty_payload(self) -> None: def test_create_test_result_invalid_json(self) -> None: event = self._create_test_event( - body="invalid json", path_params="FHIR/R4/Bundle", request_method="POST" + body="invalid json", + path_params="FHIR/R4/Bundle", + request_method="POST", + headers={"nhsd-correlation-id": "test-correlation-id"}, ) context = LambdaContext() @@ -169,6 +242,7 @@ def test_create_test_result_processing_error( body=bundle.model_dump_json(by_alias=True), path_params="FHIR/R4/Bundle", request_method="POST", + headers={"nhsd-correlation-id": "test-correlation-id"}, ) context = LambdaContext() @@ -207,6 +281,7 @@ def test_create_test_result_model_validate_error( body=bundle.model_dump_json(by_alias=True), path_params="FHIR/R4/Bundle", request_method="POST", + headers={"nhsd-correlation-id": "test-correlation-id"}, ) context = LambdaContext()