Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/blueapi/service/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,9 @@ def get_device(name: str) -> DeviceModel:
return DeviceModel.from_device(device)


def submit_task(task_request: TaskRequest) -> str:
def submit_task(task_request: TaskRequest, metadata: dict[str, Any]) -> str:
"""Submit a task to be run on begin_task"""
metadata: dict[str, Any] = {
"instrument_session": task_request.instrument_session,
}
metadata["instrument_session"] = task_request.instrument_session
if context().tiled_conf is not None:
md = config().env.metadata
# We raise an InvalidConfigError on setting tiled_conf if this isn't set
Expand Down
22 changes: 16 additions & 6 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager
from enum import StrEnum
from typing import Annotated
from typing import Annotated, Any

import jwt
from fastapi import (
Expand Down Expand Up @@ -144,7 +144,7 @@ def get_app(config: ApplicationConfig):
)
dependencies = []
if config.oidc:
dependencies.append(Depends(verify_access_token(config.oidc)))
dependencies.append(Depends(decode_access_token(config.oidc)))
app.swagger_ui_init_oauth = {
"clientId": "NOT_SUPPORTED",
}
Expand All @@ -166,24 +166,25 @@ def get_app(config: ApplicationConfig):
return app


def verify_access_token(config: OIDCConfig):
def decode_access_token(config: OIDCConfig):
jwkclient = jwt.PyJWKClient(config.jwks_uri)
oauth_scheme = OAuth2AuthorizationCodeBearer(
authorizationUrl=config.authorization_endpoint,
tokenUrl=config.token_endpoint,
refreshUrl=config.token_endpoint,
)

def inner(access_token: str = Depends(oauth_scheme)):
def inner(request: Request, access_token: str = Depends(oauth_scheme)):
signing_key = jwkclient.get_signing_key_from_jwt(access_token)
jwt.decode(
decoded: dict[str, Any] = jwt.decode(
access_token,
signing_key.key,
algorithms=config.id_token_signing_alg_values_supported,
verify=True,
audience=config.client_audience,
issuer=config.issuer,
)
request.state.decoded_access_token = decoded

return inner

Expand Down Expand Up @@ -312,7 +313,16 @@ def submit_task(
) -> TaskResponse:
"""Submit a task to the worker."""
try:
task_id: str = runner.run(interface.submit_task, task_request)
# Extract user from jwt if using OIDC (if jwt exists)
access_token: dict[str, Any] | None = getattr(
request.state, "decoded_access_token", None
)
if access_token:
user: str = access_token.get("fedid", "Unknown")
else:
user = "Unknown"

task_id: str = runner.run(interface.submit_task, task_request, {"user": user})
response.headers["Location"] = f"{request.url}/{task_id}"
return TaskResponse(task_id=task_id)
except ValidationError as e:
Expand Down
4 changes: 2 additions & 2 deletions tests/system_tests/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ env:
metadata:
instrument: adsim
sources:
- kind: deviceManager
module: dodal.beamlines.adsim
# - kind: deviceManager
# module: dodal.beamlines.adsim
- kind: planFunctions
module: dodal.plans
- kind: planFunctions
Expand Down
10 changes: 5 additions & 5 deletions tests/unit_tests/service/test_authentication.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from pathlib import Path
from typing import Any
from unittest.mock import patch
from unittest.mock import Mock, patch

import jwt
import pytest
Expand Down Expand Up @@ -117,18 +117,18 @@ def test_poll_for_token_timeout(
def test_server_raises_exception_for_invalid_token(
oidc_config: OIDCConfig, mock_authn_server: responses.RequestsMock
):
inner = main.verify_access_token(oidc_config)
inner = main.decode_access_token(oidc_config)
with pytest.raises(jwt.PyJWTError):
inner(access_token="Invalid Token")
inner(Mock(), access_token="Invalid Token")


def test_processes_valid_token(
oidc_config: OIDCConfig,
mock_authn_server: responses.RequestsMock,
valid_token_with_jwt,
):
inner = main.verify_access_token(oidc_config)
inner(access_token=valid_token_with_jwt["access_token"])
inner = main.decode_access_token(oidc_config)
inner(Mock(), access_token=valid_token_with_jwt["access_token"])


def test_session_cache_manager_returns_writable_file_path(tmp_path):
Expand Down
37 changes: 34 additions & 3 deletions tests/unit_tests/service/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def test_submit_task(context_mock: MagicMock):
mock_uuid_value = "8dfbb9c2-7a15-47b6-bea8-b6b77c31d3d9"
with patch.object(uuid, "uuid4") as uuid_mock:
uuid_mock.return_value = uuid.UUID(mock_uuid_value)
task_uuid = interface.submit_task(task)
task_uuid = interface.submit_task(task, {})
assert task_uuid == mock_uuid_value


Expand All @@ -211,7 +211,7 @@ def test_clear_task(context_mock: MagicMock):
mock_uuid_value = "3d858a62-b40a-400f-82af-8d2603a4e59a"
with patch.object(uuid, "uuid4") as uuid_mock:
uuid_mock.return_value = uuid.UUID(mock_uuid_value)
interface.submit_task(task)
interface.submit_task(task, {})

clear_task_return = interface.clear_task(mock_uuid_value)
assert clear_task_return == mock_uuid_value
Expand Down Expand Up @@ -337,7 +337,8 @@ def test_get_task_by_id(
TaskRequest(
name="my_plan",
instrument_session=FAKE_INSTRUMENT_SESSION,
)
),
{},
)

expected_metadata: dict[str, Any] = {
Expand Down Expand Up @@ -366,6 +367,36 @@ def test_get_task_by_id(
)


@patch("blueapi.service.interface.context")
def test_submit_task_inserts_metadata(context_mock: MagicMock):
context = BlueskyContext()
context.register_plan(my_plan)
context_mock.return_value = context

metadata = {"foo": "bar"}

task_id = interface.submit_task(
TaskRequest(
name="my_plan",
instrument_session=FAKE_INSTRUMENT_SESSION,
),
metadata,
)

assert interface.get_task_by_id(task_id) == TrackableTask.model_construct(
task_id=task_id,
request_id=ANY,
task=Task(
name="my_plan",
params={},
metadata=metadata,
),
is_complete=False,
is_pending=True,
errors=[],
)


@patch("blueapi.service.interface.TiledWriter")
@patch("blueapi.service.interface.from_uri")
@patch("blueapi.service.interface.context")
Expand Down
27 changes: 25 additions & 2 deletions tests/unit_tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ def client(mock_runner: Mock) -> Iterator[TestClient]:

@pytest.fixture
def client_with_auth(
mock_runner: Mock, oidc_config: OIDCConfig, valid_token_with_jwt: dict[str, Any]
mock_runner: Mock,
oidc_config: OIDCConfig,
valid_token_with_jwt: dict[str, Any],
mock_authn_server,
) -> Iterator[TestClient]:
with patch("blueapi.service.interface.worker"):
main.setup_runner(runner=mock_runner)
Expand Down Expand Up @@ -248,10 +251,30 @@ def test_create_task(mock_runner: Mock, client: TestClient) -> None:

response = client.post("/tasks", json=task.model_dump())

mock_runner.run.assert_called_with(submit_task, task)
mock_runner.run.assert_called_with(submit_task, task, {"user": "Unknown"})
assert response.json() == {"task_id": task_id}


def test_create_task_inserts_auth_metadata(
mock_runner: Mock,
client_with_auth: TestClient,
) -> None:
task = TaskRequest(
name="count",
params={"detectors": ["x"]},
instrument_session=FAKE_INSTRUMENT_SESSION,
)
client_with_auth.follow_redirects = False
task_id = str(uuid.uuid4())

# mock_runner.run.side_effect = [task_id]
mock_runner.run.return_value = [task_id]

client_with_auth.post("/tasks", json=task.model_dump())

mock_runner.run.assert_called_with(submit_task, task, {"user": "jd1"})


def test_create_task_validation_error(mock_runner: Mock, client: TestClient) -> None:
mock_runner.run.side_effect = [
ValidationError.from_exception_data(
Expand Down
Loading