From e2d111661d33ec81078064a2ef594e5c425d9aa4 Mon Sep 17 00:00:00 2001 From: Daniel Fernandes Date: Fri, 13 Feb 2026 10:13:55 +0000 Subject: [PATCH 1/6] Extract user from jwt and add to metadata at task submit --- src/blueapi/service/interface.py | 6 ++---- src/blueapi/service/main.py | 20 ++++++++++++++------ 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 9bc8bcef8..72f01cac1 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -147,11 +147,9 @@ def get_device(name: str) -> DeviceModel: return DeviceModel.from_device(device) -def submit_task(task_request: TaskRequest) -> str: +def submit_task(task_request: TaskRequest, metadata: dict[str, Any]) -> str: """Submit a task to be run on begin_task""" - metadata: dict[str, Any] = { - "instrument_session": task_request.instrument_session, - } + metadata["instrument_session"] = task_request.instrument_session if context().tiled_conf is not None: md = config().env.metadata # We raise an InvalidConfigError on setting tiled_conf if this isn't set diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 74b6e8193..083e53c28 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -3,7 +3,7 @@ from collections.abc import Awaitable, Callable from contextlib import asynccontextmanager from enum import Enum -from typing import Annotated +from typing import Annotated, Any import jwt from fastapi import ( @@ -144,7 +144,7 @@ def get_app(config: ApplicationConfig): ) dependencies = [] if config.oidc: - dependencies.append(Depends(verify_access_token(config.oidc))) + dependencies.append(Depends(decode_access_token(config.oidc))) app.swagger_ui_init_oauth = { "clientId": "NOT_SUPPORTED", } @@ -166,7 +166,7 @@ def get_app(config: ApplicationConfig): return app -def verify_access_token(config: OIDCConfig): +def decode_access_token(config: OIDCConfig): jwkclient = jwt.PyJWKClient(config.jwks_uri) oauth_scheme = OAuth2AuthorizationCodeBearer( authorizationUrl=config.authorization_endpoint, @@ -174,9 +174,9 @@ def verify_access_token(config: OIDCConfig): refreshUrl=config.token_endpoint, ) - def inner(access_token: str = Depends(oauth_scheme)): + def inner(request: Request, access_token: str = Depends(oauth_scheme)): signing_key = jwkclient.get_signing_key_from_jwt(access_token) - jwt.decode( + decoded: dict[str, Any] = jwt.decode( access_token, signing_key.key, algorithms=config.id_token_signing_alg_values_supported, @@ -184,6 +184,7 @@ def inner(access_token: str = Depends(oauth_scheme)): audience=config.client_audience, issuer=config.issuer, ) + request.state.decoded_jwt = decoded return inner @@ -312,7 +313,14 @@ def submit_task( ) -> TaskResponse: """Submit a task to the worker.""" try: - task_id: str = runner.run(interface.submit_task, task_request) + # Extract user from jwt if using OIDC (if jwt exists) + jwt: dict[str, Any] | None = getattr(request.state, "decoded_jwt", None) + if jwt: + user: str = getattr(jwt, "fedid", "Unknown") + else: + user = "Unknown" + + task_id: str = runner.run(interface.submit_task, task_request, {"user": user}) response.headers["Location"] = f"{request.url}/{task_id}" return TaskResponse(task_id=task_id) except ValidationError as e: From 77c292d14c7a7bd8c5544bf7f628a7e66884c8fc Mon Sep 17 00:00:00 2001 From: Daniel Fernandes Date: Fri, 13 Feb 2026 10:15:04 +0000 Subject: [PATCH 2/6] Initial test edits --- tests/system_tests/config.yaml | 4 +-- tests/unit_tests/service/test_interface.py | 37 ++++++++++++++++++++-- tests/unit_tests/service/test_rest_api.py | 23 +++++++++++++- 3 files changed, 58 insertions(+), 6 deletions(-) diff --git a/tests/system_tests/config.yaml b/tests/system_tests/config.yaml index 11dcf5f05..35952c444 100644 --- a/tests/system_tests/config.yaml +++ b/tests/system_tests/config.yaml @@ -4,8 +4,8 @@ env: metadata: instrument: adsim sources: - - kind: deviceManager - module: dodal.beamlines.adsim + # - kind: deviceManager + # module: dodal.beamlines.adsim - kind: planFunctions module: dodal.plans - kind: planFunctions diff --git a/tests/unit_tests/service/test_interface.py b/tests/unit_tests/service/test_interface.py index da6f1a4d9..405cf33a2 100644 --- a/tests/unit_tests/service/test_interface.py +++ b/tests/unit_tests/service/test_interface.py @@ -195,7 +195,7 @@ def test_submit_task(context_mock: MagicMock): mock_uuid_value = "8dfbb9c2-7a15-47b6-bea8-b6b77c31d3d9" with patch.object(uuid, "uuid4") as uuid_mock: uuid_mock.return_value = uuid.UUID(mock_uuid_value) - task_uuid = interface.submit_task(task) + task_uuid = interface.submit_task(task, {}) assert task_uuid == mock_uuid_value @@ -211,7 +211,7 @@ def test_clear_task(context_mock: MagicMock): mock_uuid_value = "3d858a62-b40a-400f-82af-8d2603a4e59a" with patch.object(uuid, "uuid4") as uuid_mock: uuid_mock.return_value = uuid.UUID(mock_uuid_value) - interface.submit_task(task) + interface.submit_task(task, {}) clear_task_return = interface.clear_task(mock_uuid_value) assert clear_task_return == mock_uuid_value @@ -337,7 +337,8 @@ def test_get_task_by_id( TaskRequest( name="my_plan", instrument_session=FAKE_INSTRUMENT_SESSION, - ) + ), + {}, ) expected_metadata: dict[str, Any] = { @@ -366,6 +367,36 @@ def test_get_task_by_id( ) +@patch("blueapi.service.interface.context") +def test_submit_task_inserts_metadata(context_mock: MagicMock): + context = BlueskyContext() + context.register_plan(my_plan) + context_mock.return_value = context + + metadata = {"foo": "bar"} + + task_id = interface.submit_task( + TaskRequest( + name="my_plan", + instrument_session=FAKE_INSTRUMENT_SESSION, + ), + metadata, + ) + + assert interface.get_task_by_id(task_id) == TrackableTask.model_construct( + task_id=task_id, + request_id=ANY, + task=Task( + name="my_plan", + params={}, + metadata=metadata, + ), + is_complete=False, + is_pending=True, + errors=[], + ) + + @patch("blueapi.service.interface.TiledWriter") @patch("blueapi.service.interface.from_uri") @patch("blueapi.service.interface.context") diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index b0bc09dd3..848670ead 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -248,10 +248,31 @@ def test_create_task(mock_runner: Mock, client: TestClient) -> None: response = client.post("/tasks", json=task.model_dump()) - mock_runner.run.assert_called_with(submit_task, task) + mock_runner.run.assert_called_with(submit_task, task, {"user": "Unknown"}) assert response.json() == {"task_id": task_id} +def test_create_task_inserts_auth_metadata( + mock_runner: Mock, + client_with_auth: TestClient, + mock_authn_server, +) -> None: + task = TaskRequest( + name="count", + params={"detectors": ["x"]}, + instrument_session=FAKE_INSTRUMENT_SESSION, + ) + client_with_auth.follow_redirects = False + task_id = str(uuid.uuid4()) + + # mock_runner.run.side_effect = [task_id] + mock_runner.run.return_value = [task_id] + + client_with_auth.post("/tasks", json=task.model_dump()) + + mock_runner.run.assert_called_with(submit_task, task, {"user": "Alice"}) + + def test_create_task_validation_error(mock_runner: Mock, client: TestClient) -> None: mock_runner.run.side_effect = [ ValidationError.from_exception_data( From 40f7d31111a68281e524aa4aa58dcb9ac7605793 Mon Sep 17 00:00:00 2001 From: Daniel Fernandes Date: Fri, 13 Feb 2026 10:45:54 +0000 Subject: [PATCH 3/6] Fix client_with_auth fixture, fix expected user in test --- tests/unit_tests/service/test_rest_api.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index 848670ead..caa7425f8 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -64,7 +64,10 @@ def client(mock_runner: Mock) -> Iterator[TestClient]: @pytest.fixture def client_with_auth( - mock_runner: Mock, oidc_config: OIDCConfig, valid_token_with_jwt: dict[str, Any] + mock_runner: Mock, + oidc_config: OIDCConfig, + valid_token_with_jwt: dict[str, Any], + mock_authn_server, ) -> Iterator[TestClient]: with patch("blueapi.service.interface.worker"): main.setup_runner(runner=mock_runner) @@ -255,7 +258,6 @@ def test_create_task(mock_runner: Mock, client: TestClient) -> None: def test_create_task_inserts_auth_metadata( mock_runner: Mock, client_with_auth: TestClient, - mock_authn_server, ) -> None: task = TaskRequest( name="count", @@ -270,7 +272,7 @@ def test_create_task_inserts_auth_metadata( client_with_auth.post("/tasks", json=task.model_dump()) - mock_runner.run.assert_called_with(submit_task, task, {"user": "Alice"}) + mock_runner.run.assert_called_with(submit_task, task, {"user": "jd1"}) def test_create_task_validation_error(mock_runner: Mock, client: TestClient) -> None: From 4762574f48a663c549996bc7402c137f2c17fd8a Mon Sep 17 00:00:00 2001 From: Daniel Fernandes Date: Fri, 13 Feb 2026 10:46:24 +0000 Subject: [PATCH 4/6] Correct getting jwt from request state --- src/blueapi/service/main.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 083e53c28..19156b853 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -314,9 +314,11 @@ def submit_task( """Submit a task to the worker.""" try: # Extract user from jwt if using OIDC (if jwt exists) - jwt: dict[str, Any] | None = getattr(request.state, "decoded_jwt", None) - if jwt: - user: str = getattr(jwt, "fedid", "Unknown") + access_token: dict[str, Any] | None = getattr( + request.state, "decoded_jwt", None + ) + if access_token: + user: str = access_token.get("fedid", "Unknown") else: user = "Unknown" From cd23160ccda88c65d0ad993e2ca1d1e0bd655850 Mon Sep 17 00:00:00 2001 From: Daniel Fernandes Date: Fri, 13 Feb 2026 10:47:15 +0000 Subject: [PATCH 5/6] Rename decoded_jwt to decoded_access_token --- src/blueapi/service/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 19156b853..1a10af5d2 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -184,7 +184,7 @@ def inner(request: Request, access_token: str = Depends(oauth_scheme)): audience=config.client_audience, issuer=config.issuer, ) - request.state.decoded_jwt = decoded + request.state.decoded_access_token = decoded return inner @@ -315,7 +315,7 @@ def submit_task( try: # Extract user from jwt if using OIDC (if jwt exists) access_token: dict[str, Any] | None = getattr( - request.state, "decoded_jwt", None + request.state, "decoded_access_token", None ) if access_token: user: str = access_token.get("fedid", "Unknown") From 166a7ca393f8605e847451ee29b66c6bb584601e Mon Sep 17 00:00:00 2001 From: Daniel Fernandes Date: Fri, 13 Feb 2026 10:53:21 +0000 Subject: [PATCH 6/6] Rename calls --- tests/unit_tests/service/test_authentication.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/unit_tests/service/test_authentication.py b/tests/unit_tests/service/test_authentication.py index e86dbc490..281c2be03 100644 --- a/tests/unit_tests/service/test_authentication.py +++ b/tests/unit_tests/service/test_authentication.py @@ -1,7 +1,7 @@ import os from pathlib import Path from typing import Any -from unittest.mock import patch +from unittest.mock import Mock, patch import jwt import pytest @@ -117,9 +117,9 @@ def test_poll_for_token_timeout( def test_server_raises_exception_for_invalid_token( oidc_config: OIDCConfig, mock_authn_server: responses.RequestsMock ): - inner = main.verify_access_token(oidc_config) + inner = main.decode_access_token(oidc_config) with pytest.raises(jwt.PyJWTError): - inner(access_token="Invalid Token") + inner(Mock(), access_token="Invalid Token") def test_processes_valid_token( @@ -127,8 +127,8 @@ def test_processes_valid_token( mock_authn_server: responses.RequestsMock, valid_token_with_jwt, ): - inner = main.verify_access_token(oidc_config) - inner(access_token=valid_token_with_jwt["access_token"]) + inner = main.decode_access_token(oidc_config) + inner(Mock(), access_token=valid_token_with_jwt["access_token"]) def test_session_cache_manager_returns_writable_file_path(tmp_path):