From 2c2675b96199c66306336f153eba03f851002147 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Wed, 1 Apr 2026 18:37:02 +0100 Subject: [PATCH 1/4] refactor: Rewrite server middleware with websocket support The builtin fastapi support for middleware only supports http/rest requests. To enable the same middleware for websockets, a new implementation of the starlette middleware is required and it makes sense to use the same implementation for both protocols. For rest requests, there should not be any change in behaviour from this change. --- src/blueapi/service/main.py | 43 ++++---------------- src/blueapi/service/middleware.py | 58 +++++++++++++++++++++++++++ tests/unit_tests/service/test_main.py | 4 +- 3 files changed, 67 insertions(+), 38 deletions(-) create mode 100644 src/blueapi/service/middleware.py diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 679119238..e755044ba 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -25,18 +25,19 @@ get_tracer, start_as_current_span, ) -from opentelemetry.context import attach from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor -from opentelemetry.propagate import get_global_textmap from opentelemetry.trace import get_tracer_provider from pydantic import ValidationError from pydantic.json_schema import SkipJsonSchema from starlette.responses import JSONResponse from super_state_machine.errors import TransitionError -from blueapi import __version__ from blueapi.config import ApplicationConfig, OIDCConfig, Tag from blueapi.service import interface +from blueapi.service.middleware import ( + ObservabilityContextPropagator, + VersionHeaders, +) from blueapi.worker import TrackableTask, WorkerState from blueapi.worker.event import TaskStatusEnum @@ -126,8 +127,9 @@ def get_app(config: ApplicationConfig): app.include_router(secure_router_v1, dependencies=dependencies) app.add_exception_handler(KeyError, on_key_error_404) app.add_exception_handler(jwt.PyJWTError, on_token_error_401) - app.middleware("http")(add_version_headers) - app.middleware("http")(inject_propagated_observability_context) + + app.add_middleware(ObservabilityContextPropagator) + app.add_middleware(VersionHeaders) app.middleware("http")(log_request_details) if config.api.cors: app.add_middleware( @@ -607,15 +609,6 @@ def start(config: ApplicationConfig): ) -async def add_version_headers( - request: Request, call_next: Callable[[Request], Awaitable[Response]] -): - response = await call_next(request) - response.headers["X-API-Version"] = ApplicationConfig.REST_API_VERSION - response.headers["X-BlueAPI-Version"] = __version__ - return response - - async def log_request_details( request: Request, call_next: Callable[[Request], Awaitable[StreamingResponse]] ) -> Response: @@ -637,25 +630,3 @@ async def log_request_details( LOGGER.info(log_message, extra=extra) return response - - -async def inject_propagated_observability_context( - request: Request, call_next: Callable[[Request], Awaitable[Response]] -) -> Response: - """Middleware to extract any propagated observability context from the - HTTP headers and attach it to the local one. - """ - headers = request.headers - if ApplicationConfig.CONTEXT_HEADER in headers: - carrier = { - ApplicationConfig.CONTEXT_HEADER: headers[ApplicationConfig.CONTEXT_HEADER] - } - if ApplicationConfig.VENDOR_CONTEXT_HEADER in headers: - carrier[ApplicationConfig.VENDOR_CONTEXT_HEADER] = headers[ - ApplicationConfig.VENDOR_CONTEXT_HEADER - ] - ctx = get_global_textmap().extract(carrier) - - attach(ctx) - response = await call_next(request) - return response diff --git a/src/blueapi/service/middleware.py b/src/blueapi/service/middleware.py new file mode 100644 index 000000000..489fdbf7a --- /dev/null +++ b/src/blueapi/service/middleware.py @@ -0,0 +1,58 @@ +import logging + +from opentelemetry.context import attach +from opentelemetry.propagate import get_global_textmap +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +from blueapi import __version__ +from blueapi.config import ApplicationConfig + +OBS_LOGGER = logging.getLogger("blueapi.service.middleware.observability") + +CONTEXT_HEADER = ApplicationConfig.CONTEXT_HEADER.encode() +VENDOR_CONTEXT_HEADER = ApplicationConfig.VENDOR_CONTEXT_HEADER.encode() + +API_VERSION = (b"x-api-version", ApplicationConfig.REST_API_VERSION.encode("utf-8")) +VERSION = (b"x-blueapi-version", __version__.encode("utf-8")) + + +class VersionHeaders: + def __init__(self, app: ASGIApp): + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + if scope.get("type") not in ("websocket", "http"): + return await self.app(scope, receive, send) + + async def local_send(message: Message): + if message["type"] in ("websocket.accept", "http.response.start"): + message["headers"].append(VERSION) + message["headers"].append(API_VERSION) + await send(message) + + await self.app(scope, receive, local_send) + + +class ObservabilityContextPropagator: + def __init__(self, app: ASGIApp): + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + if scope["type"] not in ("http", "websocket"): + return await self.app(scope, receive, send) + + ctx = None + v_ctx = None + for key, val in scope.get("headers", ()): + if key == CONTEXT_HEADER: + ctx = val.decode() + elif key == VENDOR_CONTEXT_HEADER: + v_ctx = val.decode() + if ctx: + OBS_LOGGER.debug("Propagating observability context: %s, %s", ctx, v_ctx) + carrier = {ApplicationConfig.CONTEXT_HEADER: ctx} + if v_ctx: + carrier[ApplicationConfig.VENDOR_CONTEXT_HEADER] = v_ctx + attach(get_global_textmap().extract(carrier)) + + await self.app(scope, receive, send) diff --git a/tests/unit_tests/service/test_main.py b/tests/unit_tests/service/test_main.py index 4a4bcca63..fb689985e 100644 --- a/tests/unit_tests/service/test_main.py +++ b/tests/unit_tests/service/test_main.py @@ -8,15 +8,15 @@ from blueapi import __version__ from blueapi.config import ApplicationConfig from blueapi.service.main import ( - add_version_headers, get_passthrough_headers, log_request_details, ) +from blueapi.service.middleware import VersionHeaders async def test_add_version_header(): app = FastAPI() - app.middleware("http")(add_version_headers) + app.add_middleware(VersionHeaders) @app.get("/") async def root(): From 187e4fb5b2fce8fa0814ce03cef86ab546e32efa Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Tue, 7 Apr 2026 17:32:44 +0100 Subject: [PATCH 2/4] Add middleware tests --- tests/unit_tests/service/test_middleware.py | 108 ++++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 tests/unit_tests/service/test_middleware.py diff --git a/tests/unit_tests/service/test_middleware.py b/tests/unit_tests/service/test_middleware.py new file mode 100644 index 000000000..1ee183bc4 --- /dev/null +++ b/tests/unit_tests/service/test_middleware.py @@ -0,0 +1,108 @@ +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest +from starlette.types import ASGIApp + +from blueapi.config import ApplicationConfig +from blueapi.service.middleware import ( + API_VERSION, + CONTEXT_HEADER, + VENDOR_CONTEXT_HEADER, + VERSION, + ObservabilityContextPropagator, + VersionHeaders, +) + + +@pytest.fixture +def app(): + return AsyncMock(spec=ASGIApp) + + +@pytest.mark.parametrize( + "protocol,message_type", + [("http", "http.response.start"), ("websocket", "websocket.accept")], +) +async def test_version_headers_added(app: Mock, protocol: str, message_type: str): + vh = VersionHeaders(app) + + send = AsyncMock() + scope = {"type": protocol} + await vh(scope, Mock(), send) + + # the middleware wraps the send function so we need to extract the function + # the app was actually called with + local_send = app.call_args[0][2] + + # Calling the wrapped send method here is equivalent to what the downstream + # framework would do after the middleware has done its thing + message = {"type": message_type, "headers": []} + await local_send(message) + + # Check the headers were sent to the original send method + send.assert_called_once_with( + {"type": message_type, "headers": [VERSION, API_VERSION]} + ) + + +async def test_version_headers_ignore_non_http_or_websockets(app: Mock): + vh = VersionHeaders(app) + + scope = {"type": "other"} + send = Mock() + recv = Mock() + + await vh(scope, recv, send) + + # for non-http/ws requests, the original args are passed directly + app.assert_called_once_with(scope, recv, send) + + +async def test_obs_context_ignores_non_http_or_websockets(app: Mock): + ocp = ObservabilityContextPropagator(app) + + scope = MagicMock() + scope.__getitem__.side_effect = {"type": "other"}.__getitem__ + + with patch("blueapi.service.middleware.attach") as att: + await ocp(scope, Mock(), Mock()) + + att.assert_not_called() + scope.get.assert_not_called() + + +@pytest.mark.parametrize("protocol", ["http", "websocket"]) +async def test_obs_context_passes_context(app: Mock, protocol: str): + ocp = ObservabilityContextPropagator(app) + scope = {"type": protocol, "headers": ((CONTEXT_HEADER, b"req_context"),)} + + with patch("blueapi.service.middleware.attach") as att: + with patch("blueapi.service.middleware.get_global_textmap") as get_global: + get_global.return_value.extract.side_effect = lambda x: x + await ocp(scope, Mock(), Mock()) + + att.assert_called_once_with({ApplicationConfig.CONTEXT_HEADER: "req_context"}) + + +@pytest.mark.parametrize("protocol", ["http", "websocket"]) +async def test_obs_context_passes_vendor_context(app: Mock, protocol: str): + ocp = ObservabilityContextPropagator(app) + scope = { + "type": protocol, + "headers": ( + (CONTEXT_HEADER, b"req_context"), + (VENDOR_CONTEXT_HEADER, b"vendor_context"), + ), + } + + with patch("blueapi.service.middleware.attach") as att: + with patch("blueapi.service.middleware.get_global_textmap") as get_global: + get_global.return_value.extract.side_effect = lambda x: x + await ocp(scope, Mock(), Mock()) + + att.assert_called_once_with( + { + ApplicationConfig.CONTEXT_HEADER: "req_context", + ApplicationConfig.VENDOR_CONTEXT_HEADER: "vendor_context", + } + ) From 1ea04f6f25cc0e836dcc3f615a528aecd28d1eba Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Wed, 15 Apr 2026 14:53:08 +0100 Subject: [PATCH 3/4] Make scope attribute access consistent --- src/blueapi/service/middleware.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/blueapi/service/middleware.py b/src/blueapi/service/middleware.py index 489fdbf7a..b31fe0fb9 100644 --- a/src/blueapi/service/middleware.py +++ b/src/blueapi/service/middleware.py @@ -30,7 +30,7 @@ async def local_send(message: Message): message["headers"].append(API_VERSION) await send(message) - await self.app(scope, receive, local_send) + return await self.app(scope, receive, local_send) class ObservabilityContextPropagator: @@ -38,7 +38,7 @@ def __init__(self, app: ASGIApp): self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send): - if scope["type"] not in ("http", "websocket"): + if scope.get("type") not in ("http", "websocket"): return await self.app(scope, receive, send) ctx = None @@ -55,4 +55,4 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): carrier[ApplicationConfig.VENDOR_CONTEXT_HEADER] = v_ctx attach(get_global_textmap().extract(carrier)) - await self.app(scope, receive, send) + return await self.app(scope, receive, send) From b2308cfa666119ac98d263f4be44ab1b0151a76b Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Wed, 22 Apr 2026 15:01:51 +0100 Subject: [PATCH 4/4] Update tests for scope['type'] access --- tests/unit_tests/service/test_middleware.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/service/test_middleware.py b/tests/unit_tests/service/test_middleware.py index 1ee183bc4..5f6dbeac6 100644 --- a/tests/unit_tests/service/test_middleware.py +++ b/tests/unit_tests/service/test_middleware.py @@ -68,7 +68,7 @@ async def test_obs_context_ignores_non_http_or_websockets(app: Mock): await ocp(scope, Mock(), Mock()) att.assert_not_called() - scope.get.assert_not_called() + scope.get.assert_called_once_with("type") @pytest.mark.parametrize("protocol", ["http", "websocket"])