Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 7 additions & 36 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,19 @@
get_tracer,
start_as_current_span,
)
from opentelemetry.context import attach
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.propagate import get_global_textmap
from opentelemetry.trace import get_tracer_provider
from pydantic import ValidationError
from pydantic.json_schema import SkipJsonSchema
from starlette.responses import JSONResponse
from super_state_machine.errors import TransitionError

from blueapi import __version__
from blueapi.config import ApplicationConfig, OIDCConfig, Tag
from blueapi.service import interface
from blueapi.service.middleware import (
ObservabilityContextPropagator,
VersionHeaders,
)
from blueapi.worker import TrackableTask, WorkerState
from blueapi.worker.event import TaskStatusEnum

Expand Down Expand Up @@ -126,8 +127,9 @@ def get_app(config: ApplicationConfig):
app.include_router(secure_router_v1, dependencies=dependencies)
app.add_exception_handler(KeyError, on_key_error_404)
app.add_exception_handler(jwt.PyJWTError, on_token_error_401)
app.middleware("http")(add_version_headers)
app.middleware("http")(inject_propagated_observability_context)

app.add_middleware(ObservabilityContextPropagator)
app.add_middleware(VersionHeaders)
app.middleware("http")(log_request_details)
if config.api.cors:
app.add_middleware(
Expand Down Expand Up @@ -607,15 +609,6 @@ def start(config: ApplicationConfig):
)


async def add_version_headers(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
):
response = await call_next(request)
response.headers["X-API-Version"] = ApplicationConfig.REST_API_VERSION
response.headers["X-BlueAPI-Version"] = __version__
return response


async def log_request_details(
request: Request, call_next: Callable[[Request], Awaitable[StreamingResponse]]
) -> Response:
Expand All @@ -637,25 +630,3 @@ async def log_request_details(
LOGGER.info(log_message, extra=extra)

return response


async def inject_propagated_observability_context(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Middleware to extract any propagated observability context from the
HTTP headers and attach it to the local one.
"""
headers = request.headers
if ApplicationConfig.CONTEXT_HEADER in headers:
carrier = {
ApplicationConfig.CONTEXT_HEADER: headers[ApplicationConfig.CONTEXT_HEADER]
}
if ApplicationConfig.VENDOR_CONTEXT_HEADER in headers:
carrier[ApplicationConfig.VENDOR_CONTEXT_HEADER] = headers[
ApplicationConfig.VENDOR_CONTEXT_HEADER
]
ctx = get_global_textmap().extract(carrier)

attach(ctx)
response = await call_next(request)
return response
58 changes: 58 additions & 0 deletions src/blueapi/service/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import logging

from opentelemetry.context import attach
from opentelemetry.propagate import get_global_textmap
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from blueapi import __version__
from blueapi.config import ApplicationConfig

OBS_LOGGER = logging.getLogger("blueapi.service.middleware.observability")

CONTEXT_HEADER = ApplicationConfig.CONTEXT_HEADER.encode()
VENDOR_CONTEXT_HEADER = ApplicationConfig.VENDOR_CONTEXT_HEADER.encode()

API_VERSION = (b"x-api-version", ApplicationConfig.REST_API_VERSION.encode("utf-8"))
VERSION = (b"x-blueapi-version", __version__.encode("utf-8"))


class VersionHeaders:
def __init__(self, app: ASGIApp):
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope.get("type") not in ("websocket", "http"):
return await self.app(scope, receive, send)

async def local_send(message: Message):
if message["type"] in ("websocket.accept", "http.response.start"):
message["headers"].append(VERSION)
message["headers"].append(API_VERSION)
await send(message)

return await self.app(scope, receive, local_send)


class ObservabilityContextPropagator:
Comment thread
tpoliaw marked this conversation as resolved.
def __init__(self, app: ASGIApp):
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope.get("type") not in ("http", "websocket"):
return await self.app(scope, receive, send)

ctx = None
v_ctx = None
for key, val in scope.get("headers", ()):
if key == CONTEXT_HEADER:
ctx = val.decode()
elif key == VENDOR_CONTEXT_HEADER:
v_ctx = val.decode()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume that CONTEXT_HEADER and VENDOR_CONTEXT_HEADER will only appear on headers once?

Copy link
Copy Markdown
Contributor Author

@tpoliaw tpoliaw Apr 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope so. And if not I think the 'last entry wins' approach is probably fine

if ctx:
OBS_LOGGER.debug("Propagating observability context: %s, %s", ctx, v_ctx)
carrier = {ApplicationConfig.CONTEXT_HEADER: ctx}
if v_ctx:
carrier[ApplicationConfig.VENDOR_CONTEXT_HEADER] = v_ctx
attach(get_global_textmap().extract(carrier))

return await self.app(scope, receive, send)
4 changes: 2 additions & 2 deletions tests/unit_tests/service/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
from blueapi import __version__
from blueapi.config import ApplicationConfig
from blueapi.service.main import (
add_version_headers,
get_passthrough_headers,
log_request_details,
)
from blueapi.service.middleware import VersionHeaders


async def test_add_version_header():
app = FastAPI()
app.middleware("http")(add_version_headers)
app.add_middleware(VersionHeaders)

@app.get("/")
async def root():
Expand Down
108 changes: 108 additions & 0 deletions tests/unit_tests/service/test_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest
from starlette.types import ASGIApp

from blueapi.config import ApplicationConfig
from blueapi.service.middleware import (
API_VERSION,
CONTEXT_HEADER,
VENDOR_CONTEXT_HEADER,
VERSION,
ObservabilityContextPropagator,
VersionHeaders,
)


@pytest.fixture
def app():
return AsyncMock(spec=ASGIApp)


@pytest.mark.parametrize(
"protocol,message_type",
[("http", "http.response.start"), ("websocket", "websocket.accept")],
)
async def test_version_headers_added(app: Mock, protocol: str, message_type: str):
vh = VersionHeaders(app)

send = AsyncMock()
scope = {"type": protocol}
await vh(scope, Mock(), send)

# the middleware wraps the send function so we need to extract the function
# the app was actually called with
local_send = app.call_args[0][2]

# Calling the wrapped send method here is equivalent to what the downstream
# framework would do after the middleware has done its thing
message = {"type": message_type, "headers": []}
await local_send(message)

# Check the headers were sent to the original send method
send.assert_called_once_with(
{"type": message_type, "headers": [VERSION, API_VERSION]}
)


async def test_version_headers_ignore_non_http_or_websockets(app: Mock):
vh = VersionHeaders(app)

scope = {"type": "other"}
send = Mock()
recv = Mock()

await vh(scope, recv, send)

# for non-http/ws requests, the original args are passed directly
app.assert_called_once_with(scope, recv, send)


async def test_obs_context_ignores_non_http_or_websockets(app: Mock):
ocp = ObservabilityContextPropagator(app)

scope = MagicMock()
scope.__getitem__.side_effect = {"type": "other"}.__getitem__

with patch("blueapi.service.middleware.attach") as att:
await ocp(scope, Mock(), Mock())

att.assert_not_called()
scope.get.assert_called_once_with("type")


@pytest.mark.parametrize("protocol", ["http", "websocket"])
async def test_obs_context_passes_context(app: Mock, protocol: str):
ocp = ObservabilityContextPropagator(app)
scope = {"type": protocol, "headers": ((CONTEXT_HEADER, b"req_context"),)}

with patch("blueapi.service.middleware.attach") as att:
with patch("blueapi.service.middleware.get_global_textmap") as get_global:
get_global.return_value.extract.side_effect = lambda x: x
await ocp(scope, Mock(), Mock())

att.assert_called_once_with({ApplicationConfig.CONTEXT_HEADER: "req_context"})


@pytest.mark.parametrize("protocol", ["http", "websocket"])
async def test_obs_context_passes_vendor_context(app: Mock, protocol: str):
ocp = ObservabilityContextPropagator(app)
scope = {
"type": protocol,
"headers": (
(CONTEXT_HEADER, b"req_context"),
(VENDOR_CONTEXT_HEADER, b"vendor_context"),
),
}

with patch("blueapi.service.middleware.attach") as att:
with patch("blueapi.service.middleware.get_global_textmap") as get_global:
get_global.return_value.extract.side_effect = lambda x: x
await ocp(scope, Mock(), Mock())

att.assert_called_once_with(
{
ApplicationConfig.CONTEXT_HEADER: "req_context",
ApplicationConfig.VENDOR_CONTEXT_HEADER: "vendor_context",
}
)
Loading