diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 9bc8bcef8..72f01cac1 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -147,11 +147,9 @@ def get_device(name: str) -> DeviceModel: return DeviceModel.from_device(device) -def submit_task(task_request: TaskRequest) -> str: +def submit_task(task_request: TaskRequest, metadata: dict[str, Any]) -> str: """Submit a task to be run on begin_task""" - metadata: dict[str, Any] = { - "instrument_session": task_request.instrument_session, - } + metadata["instrument_session"] = task_request.instrument_session if context().tiled_conf is not None: md = config().env.metadata # We raise an InvalidConfigError on setting tiled_conf if this isn't set diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 5aa44c533..a48929572 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -3,7 +3,7 @@ from collections.abc import Awaitable, Callable from contextlib import asynccontextmanager from enum import StrEnum -from typing import Annotated +from typing import Annotated, Any import jwt from fastapi import ( @@ -144,7 +144,7 @@ def get_app(config: ApplicationConfig): ) dependencies = [] if config.oidc: - dependencies.append(Depends(verify_access_token(config.oidc))) + dependencies.append(Depends(decode_access_token(config.oidc))) app.swagger_ui_init_oauth = { "clientId": "NOT_SUPPORTED", } @@ -166,7 +166,7 @@ def get_app(config: ApplicationConfig): return app -def verify_access_token(config: OIDCConfig): +def decode_access_token(config: OIDCConfig): jwkclient = jwt.PyJWKClient(config.jwks_uri) oauth_scheme = OAuth2AuthorizationCodeBearer( authorizationUrl=config.authorization_endpoint, @@ -174,9 +174,9 @@ def verify_access_token(config: OIDCConfig): refreshUrl=config.token_endpoint, ) - def inner(access_token: str = Depends(oauth_scheme)): + def inner(request: Request, access_token: str = Depends(oauth_scheme)): signing_key = jwkclient.get_signing_key_from_jwt(access_token) - jwt.decode( + decoded: dict[str, Any] = jwt.decode( access_token, signing_key.key, algorithms=config.id_token_signing_alg_values_supported, @@ -184,6 +184,7 @@ def inner(access_token: str = Depends(oauth_scheme)): audience=config.client_audience, issuer=config.issuer, ) + request.state.decoded_access_token = decoded return inner @@ -312,7 +313,16 @@ def submit_task( ) -> TaskResponse: """Submit a task to the worker.""" try: - task_id: str = runner.run(interface.submit_task, task_request) + # Extract user from jwt if using OIDC (if jwt exists) + access_token: dict[str, Any] | None = getattr( + request.state, "decoded_access_token", None + ) + if access_token: + user: str = access_token.get("fedid", "Unknown") + else: + user = "Unknown" + + task_id: str = runner.run(interface.submit_task, task_request, {"user": user}) response.headers["Location"] = f"{request.url}/{task_id}" return TaskResponse(task_id=task_id) except ValidationError as e: diff --git a/tests/system_tests/config.yaml b/tests/system_tests/config.yaml index 11dcf5f05..35952c444 100644 --- a/tests/system_tests/config.yaml +++ b/tests/system_tests/config.yaml @@ -4,8 +4,8 @@ env: metadata: instrument: adsim sources: - - kind: deviceManager - module: dodal.beamlines.adsim + # - kind: deviceManager + # module: dodal.beamlines.adsim - kind: planFunctions module: dodal.plans - kind: planFunctions diff --git a/tests/unit_tests/service/test_authentication.py b/tests/unit_tests/service/test_authentication.py index e86dbc490..281c2be03 100644 --- a/tests/unit_tests/service/test_authentication.py +++ b/tests/unit_tests/service/test_authentication.py @@ -1,7 +1,7 @@ import os from pathlib import Path from typing import Any -from unittest.mock import patch +from unittest.mock import Mock, patch import jwt import pytest @@ -117,9 +117,9 @@ def test_poll_for_token_timeout( def test_server_raises_exception_for_invalid_token( oidc_config: OIDCConfig, mock_authn_server: responses.RequestsMock ): - inner = main.verify_access_token(oidc_config) + inner = main.decode_access_token(oidc_config) with pytest.raises(jwt.PyJWTError): - inner(access_token="Invalid Token") + inner(Mock(), access_token="Invalid Token") def test_processes_valid_token( @@ -127,8 +127,8 @@ def test_processes_valid_token( mock_authn_server: responses.RequestsMock, valid_token_with_jwt, ): - inner = main.verify_access_token(oidc_config) - inner(access_token=valid_token_with_jwt["access_token"]) + inner = main.decode_access_token(oidc_config) + inner(Mock(), access_token=valid_token_with_jwt["access_token"]) def test_session_cache_manager_returns_writable_file_path(tmp_path): diff --git a/tests/unit_tests/service/test_interface.py b/tests/unit_tests/service/test_interface.py index da6f1a4d9..405cf33a2 100644 --- a/tests/unit_tests/service/test_interface.py +++ b/tests/unit_tests/service/test_interface.py @@ -195,7 +195,7 @@ def test_submit_task(context_mock: MagicMock): mock_uuid_value = "8dfbb9c2-7a15-47b6-bea8-b6b77c31d3d9" with patch.object(uuid, "uuid4") as uuid_mock: uuid_mock.return_value = uuid.UUID(mock_uuid_value) - task_uuid = interface.submit_task(task) + task_uuid = interface.submit_task(task, {}) assert task_uuid == mock_uuid_value @@ -211,7 +211,7 @@ def test_clear_task(context_mock: MagicMock): mock_uuid_value = "3d858a62-b40a-400f-82af-8d2603a4e59a" with patch.object(uuid, "uuid4") as uuid_mock: uuid_mock.return_value = uuid.UUID(mock_uuid_value) - interface.submit_task(task) + interface.submit_task(task, {}) clear_task_return = interface.clear_task(mock_uuid_value) assert clear_task_return == mock_uuid_value @@ -337,7 +337,8 @@ def test_get_task_by_id( TaskRequest( name="my_plan", instrument_session=FAKE_INSTRUMENT_SESSION, - ) + ), + {}, ) expected_metadata: dict[str, Any] = { @@ -366,6 +367,36 @@ def test_get_task_by_id( ) +@patch("blueapi.service.interface.context") +def test_submit_task_inserts_metadata(context_mock: MagicMock): + context = BlueskyContext() + context.register_plan(my_plan) + context_mock.return_value = context + + metadata = {"foo": "bar"} + + task_id = interface.submit_task( + TaskRequest( + name="my_plan", + instrument_session=FAKE_INSTRUMENT_SESSION, + ), + metadata, + ) + + assert interface.get_task_by_id(task_id) == TrackableTask.model_construct( + task_id=task_id, + request_id=ANY, + task=Task( + name="my_plan", + params={}, + metadata=metadata, + ), + is_complete=False, + is_pending=True, + errors=[], + ) + + @patch("blueapi.service.interface.TiledWriter") @patch("blueapi.service.interface.from_uri") @patch("blueapi.service.interface.context") diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index b0bc09dd3..caa7425f8 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -64,7 +64,10 @@ def client(mock_runner: Mock) -> Iterator[TestClient]: @pytest.fixture def client_with_auth( - mock_runner: Mock, oidc_config: OIDCConfig, valid_token_with_jwt: dict[str, Any] + mock_runner: Mock, + oidc_config: OIDCConfig, + valid_token_with_jwt: dict[str, Any], + mock_authn_server, ) -> Iterator[TestClient]: with patch("blueapi.service.interface.worker"): main.setup_runner(runner=mock_runner) @@ -248,10 +251,30 @@ def test_create_task(mock_runner: Mock, client: TestClient) -> None: response = client.post("/tasks", json=task.model_dump()) - mock_runner.run.assert_called_with(submit_task, task) + mock_runner.run.assert_called_with(submit_task, task, {"user": "Unknown"}) assert response.json() == {"task_id": task_id} +def test_create_task_inserts_auth_metadata( + mock_runner: Mock, + client_with_auth: TestClient, +) -> None: + task = TaskRequest( + name="count", + params={"detectors": ["x"]}, + instrument_session=FAKE_INSTRUMENT_SESSION, + ) + client_with_auth.follow_redirects = False + task_id = str(uuid.uuid4()) + + # mock_runner.run.side_effect = [task_id] + mock_runner.run.return_value = [task_id] + + client_with_auth.post("/tasks", json=task.model_dump()) + + mock_runner.run.assert_called_with(submit_task, task, {"user": "jd1"}) + + def test_create_task_validation_error(mock_runner: Mock, client: TestClient) -> None: mock_runner.run.side_effect = [ ValidationError.from_exception_data(