diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 679119238..e755044ba 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -25,18 +25,19 @@ get_tracer, start_as_current_span, ) -from opentelemetry.context import attach from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor -from opentelemetry.propagate import get_global_textmap from opentelemetry.trace import get_tracer_provider from pydantic import ValidationError from pydantic.json_schema import SkipJsonSchema from starlette.responses import JSONResponse from super_state_machine.errors import TransitionError -from blueapi import __version__ from blueapi.config import ApplicationConfig, OIDCConfig, Tag from blueapi.service import interface +from blueapi.service.middleware import ( + ObservabilityContextPropagator, + VersionHeaders, +) from blueapi.worker import TrackableTask, WorkerState from blueapi.worker.event import TaskStatusEnum @@ -126,8 +127,9 @@ def get_app(config: ApplicationConfig): app.include_router(secure_router_v1, dependencies=dependencies) app.add_exception_handler(KeyError, on_key_error_404) app.add_exception_handler(jwt.PyJWTError, on_token_error_401) - app.middleware("http")(add_version_headers) - app.middleware("http")(inject_propagated_observability_context) + + app.add_middleware(ObservabilityContextPropagator) + app.add_middleware(VersionHeaders) app.middleware("http")(log_request_details) if config.api.cors: app.add_middleware( @@ -607,15 +609,6 @@ def start(config: ApplicationConfig): ) -async def add_version_headers( - request: Request, call_next: Callable[[Request], Awaitable[Response]] -): - response = await call_next(request) - response.headers["X-API-Version"] = ApplicationConfig.REST_API_VERSION - response.headers["X-BlueAPI-Version"] = __version__ - return response - - async def log_request_details( request: Request, call_next: Callable[[Request], Awaitable[StreamingResponse]] ) -> Response: @@ -637,25 +630,3 @@ async def log_request_details( LOGGER.info(log_message, extra=extra) return response - - -async def inject_propagated_observability_context( - request: Request, call_next: Callable[[Request], Awaitable[Response]] -) -> Response: - """Middleware to extract any propagated observability context from the - HTTP headers and attach it to the local one. - """ - headers = request.headers - if ApplicationConfig.CONTEXT_HEADER in headers: - carrier = { - ApplicationConfig.CONTEXT_HEADER: headers[ApplicationConfig.CONTEXT_HEADER] - } - if ApplicationConfig.VENDOR_CONTEXT_HEADER in headers: - carrier[ApplicationConfig.VENDOR_CONTEXT_HEADER] = headers[ - ApplicationConfig.VENDOR_CONTEXT_HEADER - ] - ctx = get_global_textmap().extract(carrier) - - attach(ctx) - response = await call_next(request) - return response diff --git a/src/blueapi/service/middleware.py b/src/blueapi/service/middleware.py new file mode 100644 index 000000000..b31fe0fb9 --- /dev/null +++ b/src/blueapi/service/middleware.py @@ -0,0 +1,58 @@ +import logging + +from opentelemetry.context import attach +from opentelemetry.propagate import get_global_textmap +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +from blueapi import __version__ +from blueapi.config import ApplicationConfig + +OBS_LOGGER = logging.getLogger("blueapi.service.middleware.observability") + +CONTEXT_HEADER = ApplicationConfig.CONTEXT_HEADER.encode() +VENDOR_CONTEXT_HEADER = ApplicationConfig.VENDOR_CONTEXT_HEADER.encode() + +API_VERSION = (b"x-api-version", ApplicationConfig.REST_API_VERSION.encode("utf-8")) +VERSION = (b"x-blueapi-version", __version__.encode("utf-8")) + + +class VersionHeaders: + def __init__(self, app: ASGIApp): + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + if scope.get("type") not in ("websocket", "http"): + return await self.app(scope, receive, send) + + async def local_send(message: Message): + if message["type"] in ("websocket.accept", "http.response.start"): + message["headers"].append(VERSION) + message["headers"].append(API_VERSION) + await send(message) + + return await self.app(scope, receive, local_send) + + +class ObservabilityContextPropagator: + def __init__(self, app: ASGIApp): + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + if scope.get("type") not in ("http", "websocket"): + return await self.app(scope, receive, send) + + ctx = None + v_ctx = None + for key, val in scope.get("headers", ()): + if key == CONTEXT_HEADER: + ctx = val.decode() + elif key == VENDOR_CONTEXT_HEADER: + v_ctx = val.decode() + if ctx: + OBS_LOGGER.debug("Propagating observability context: %s, %s", ctx, v_ctx) + carrier = {ApplicationConfig.CONTEXT_HEADER: ctx} + if v_ctx: + carrier[ApplicationConfig.VENDOR_CONTEXT_HEADER] = v_ctx + attach(get_global_textmap().extract(carrier)) + + return await self.app(scope, receive, send) diff --git a/tests/unit_tests/service/test_main.py b/tests/unit_tests/service/test_main.py index 4a4bcca63..fb689985e 100644 --- a/tests/unit_tests/service/test_main.py +++ b/tests/unit_tests/service/test_main.py @@ -8,15 +8,15 @@ from blueapi import __version__ from blueapi.config import ApplicationConfig from blueapi.service.main import ( - add_version_headers, get_passthrough_headers, log_request_details, ) +from blueapi.service.middleware import VersionHeaders async def test_add_version_header(): app = FastAPI() - app.middleware("http")(add_version_headers) + app.add_middleware(VersionHeaders) @app.get("/") async def root(): diff --git a/tests/unit_tests/service/test_middleware.py b/tests/unit_tests/service/test_middleware.py new file mode 100644 index 000000000..5f6dbeac6 --- /dev/null +++ b/tests/unit_tests/service/test_middleware.py @@ -0,0 +1,108 @@ +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest +from starlette.types import ASGIApp + +from blueapi.config import ApplicationConfig +from blueapi.service.middleware import ( + API_VERSION, + CONTEXT_HEADER, + VENDOR_CONTEXT_HEADER, + VERSION, + ObservabilityContextPropagator, + VersionHeaders, +) + + +@pytest.fixture +def app(): + return AsyncMock(spec=ASGIApp) + + +@pytest.mark.parametrize( + "protocol,message_type", + [("http", "http.response.start"), ("websocket", "websocket.accept")], +) +async def test_version_headers_added(app: Mock, protocol: str, message_type: str): + vh = VersionHeaders(app) + + send = AsyncMock() + scope = {"type": protocol} + await vh(scope, Mock(), send) + + # the middleware wraps the send function so we need to extract the function + # the app was actually called with + local_send = app.call_args[0][2] + + # Calling the wrapped send method here is equivalent to what the downstream + # framework would do after the middleware has done its thing + message = {"type": message_type, "headers": []} + await local_send(message) + + # Check the headers were sent to the original send method + send.assert_called_once_with( + {"type": message_type, "headers": [VERSION, API_VERSION]} + ) + + +async def test_version_headers_ignore_non_http_or_websockets(app: Mock): + vh = VersionHeaders(app) + + scope = {"type": "other"} + send = Mock() + recv = Mock() + + await vh(scope, recv, send) + + # for non-http/ws requests, the original args are passed directly + app.assert_called_once_with(scope, recv, send) + + +async def test_obs_context_ignores_non_http_or_websockets(app: Mock): + ocp = ObservabilityContextPropagator(app) + + scope = MagicMock() + scope.__getitem__.side_effect = {"type": "other"}.__getitem__ + + with patch("blueapi.service.middleware.attach") as att: + await ocp(scope, Mock(), Mock()) + + att.assert_not_called() + scope.get.assert_called_once_with("type") + + +@pytest.mark.parametrize("protocol", ["http", "websocket"]) +async def test_obs_context_passes_context(app: Mock, protocol: str): + ocp = ObservabilityContextPropagator(app) + scope = {"type": protocol, "headers": ((CONTEXT_HEADER, b"req_context"),)} + + with patch("blueapi.service.middleware.attach") as att: + with patch("blueapi.service.middleware.get_global_textmap") as get_global: + get_global.return_value.extract.side_effect = lambda x: x + await ocp(scope, Mock(), Mock()) + + att.assert_called_once_with({ApplicationConfig.CONTEXT_HEADER: "req_context"}) + + +@pytest.mark.parametrize("protocol", ["http", "websocket"]) +async def test_obs_context_passes_vendor_context(app: Mock, protocol: str): + ocp = ObservabilityContextPropagator(app) + scope = { + "type": protocol, + "headers": ( + (CONTEXT_HEADER, b"req_context"), + (VENDOR_CONTEXT_HEADER, b"vendor_context"), + ), + } + + with patch("blueapi.service.middleware.attach") as att: + with patch("blueapi.service.middleware.get_global_textmap") as get_global: + get_global.return_value.extract.side_effect = lambda x: x + await ocp(scope, Mock(), Mock()) + + att.assert_called_once_with( + { + ApplicationConfig.CONTEXT_HEADER: "req_context", + ApplicationConfig.VENDOR_CONTEXT_HEADER: "vendor_context", + } + )