From f4c12b4e6df78342a0329dd2051865ba70e8fffe Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Tue, 27 Jan 2026 16:47:12 +0000 Subject: [PATCH 01/14] WIP returning results from plans --- src/blueapi/core/context.py | 2 +- src/blueapi/worker/event.py | 2 ++ src/blueapi/worker/task.py | 5 ++++- src/blueapi/worker/task_worker.py | 7 +++++-- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index 9bf29aec3d..56fd163a2b 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -120,7 +120,7 @@ class BlueskyContext: configuration: InitVar[ApplicationConfig | None] = None run_engine: RunEngine = field( - default_factory=lambda: RunEngine(context_managers=[]) + default_factory=lambda: RunEngine(context_managers=[], call_returns_result=True) ) tiled_conf: TiledConfig | None = field(default=None, init=False, repr=False) numtracker: NumtrackerClient | None = field(default=None, init=False, repr=False) diff --git a/src/blueapi/worker/event.py b/src/blueapi/worker/event.py index 2004a1a633..ed27506fac 100644 --- a/src/blueapi/worker/event.py +++ b/src/blueapi/worker/event.py @@ -1,5 +1,6 @@ from collections.abc import Mapping from enum import StrEnum +from typing import Any from bluesky.run_engine import RunEngineStateMachine from pydantic import Field @@ -109,6 +110,7 @@ class TaskStatus(BlueapiBaseModel): task_id: str task_complete: bool task_failed: bool + result: Any = None class WorkerEvent(BlueapiBaseModel): diff --git a/src/blueapi/worker/task.py b/src/blueapi/worker/task.py index 32060bc647..669b7fb05f 100644 --- a/src/blueapi/worker/task.py +++ b/src/blueapi/worker/task.py @@ -38,7 +38,10 @@ def do_task(self, ctx: BlueskyContext) -> None: func = ctx.plan_functions[self.name] prepared_params = self.prepare_params(ctx) ctx.run_engine.md.update(self.metadata) - ctx.run_engine(func(**prepared_params)) + result = ctx.run_engine(func(**prepared_params)) + if isinstance(result, tuple): + return None + return result.plan_result def _lookup_params(ctx: BlueskyContext, task: Task) -> BaseModel: diff --git a/src/blueapi/worker/task_worker.py b/src/blueapi/worker/task_worker.py index e1ed4970a9..1ced0b2dca 100644 --- a/src/blueapi/worker/task_worker.py +++ b/src/blueapi/worker/task_worker.py @@ -69,6 +69,7 @@ class TrackableTask(BlueapiBaseModel): is_complete: bool = False is_pending: bool = True errors: list[str] = Field(default_factory=list) + result: Any | None = None class TaskWorker: @@ -423,11 +424,12 @@ def _cycle(self) -> None: next_task: TrackableTask | KillSignal = self._task_channel.get() if isinstance(next_task, TrackableTask): - def process_task(): + def process_task() -> Any: LOGGER.info(f"Got new task: {next_task}") self._current = next_task self._current.is_pending = False - self._current.task.do_task(self._ctx) + result = self._current.task.do_task(self._ctx) + self._current.result = result with plan_tag_filter_context(next_task.task.name, LOGGER): if self._current_task_otel_context is not None: @@ -528,6 +530,7 @@ def _report_status( task_id=self._current.task_id, task_complete=self._current.is_complete, task_failed=bool(self._current.errors), + result=self._current.result, ) correlation_id = self._current.task_id add_span_attributes( From 3d6ac9d26138e221c83dc54245099d63a4608658 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Wed, 28 Jan 2026 17:25:27 +0000 Subject: [PATCH 02/14] Store plan result as python dict/list/str etc Use pydantic to convert result to something that can be JSON serialized later. --- src/blueapi/worker/event.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/blueapi/worker/event.py b/src/blueapi/worker/event.py index ed27506fac..3f30a6b860 100644 --- a/src/blueapi/worker/event.py +++ b/src/blueapi/worker/event.py @@ -1,9 +1,10 @@ +import logging from collections.abc import Mapping from enum import StrEnum from typing import Any from bluesky.run_engine import RunEngineStateMachine -from pydantic import Field +from pydantic import Field, TypeAdapter, field_validator from super_state_machine.extras import PropertyMachine, ProxyString from blueapi.utils import BlueapiBaseModel @@ -12,6 +13,8 @@ # RawRunEngineState = type[PropertyMachine | ProxyString | str] RawRunEngineState = PropertyMachine | ProxyString | str +log = logging.getLogger(__name__) + # NOTE this is interim until refactor class TaskStatusEnum(StrEnum): @@ -112,6 +115,17 @@ class TaskStatus(BlueapiBaseModel): task_failed: bool result: Any = None + @field_validator("result") + @classmethod + def _serialize_result(cls, result): + try: + return TypeAdapter(type(result)).dump_python(result) + except Exception: + log.warning( + "Plan result type (%s) not serializable: %s", type(result), result + ) + pass + class WorkerEvent(BlueapiBaseModel): """ From 8bc2a01cf7cc22716b3b1ab8a760d06f380cadbb Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 29 Jan 2026 16:14:16 +0000 Subject: [PATCH 03/14] Remove return type from process task --- src/blueapi/worker/task_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blueapi/worker/task_worker.py b/src/blueapi/worker/task_worker.py index 1ced0b2dca..2999b68fa9 100644 --- a/src/blueapi/worker/task_worker.py +++ b/src/blueapi/worker/task_worker.py @@ -424,7 +424,7 @@ def _cycle(self) -> None: next_task: TrackableTask | KillSignal = self._task_channel.get() if isinstance(next_task, TrackableTask): - def process_task() -> Any: + def process_task(): LOGGER.info(f"Got new task: {next_task}") self._current = next_task self._current.is_pending = False From cf51dbd9809998705a60c3c9aacc5a14b0902b75 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 29 Jan 2026 17:09:07 +0000 Subject: [PATCH 04/14] Handle serialization of TrackableTask as well as TaskStatus --- src/blueapi/utils/base_model.py | 22 +++++++++++++++++++++- src/blueapi/worker/event.py | 17 +++-------------- src/blueapi/worker/task_worker.py | 4 ++-- 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/src/blueapi/utils/base_model.py b/src/blueapi/utils/base_model.py index 142052a8c3..daad9a0bd4 100644 --- a/src/blueapi/utils/base_model.py +++ b/src/blueapi/utils/base_model.py @@ -1,5 +1,14 @@ -from pydantic import BaseModel, ConfigDict +import logging +from typing import Annotated, Any +from pydantic import ( + BaseModel, + ConfigDict, + PlainSerializer, + TypeAdapter, +) + +logger = logging.getLogger(__name__) # Pydantic config for blueapi API models with common config. BlueapiModelConfig = ConfigDict( extra="forbid", @@ -16,6 +25,17 @@ ) +def _safe_serialize(value: Any) -> Any: + """Try serializing but skip any type that pydantic can't handle""" + try: + return TypeAdapter(type(value)).dump_python(value, mode="json") + except Exception: + logger.warning("Type '%s' not serializable: %s", type(value), value) + + +NoneFallback = Annotated[Any, PlainSerializer(_safe_serialize)] + + class BlueapiBaseModel(BaseModel): """ Base class for blueapi API models. diff --git a/src/blueapi/worker/event.py b/src/blueapi/worker/event.py index 3f30a6b860..0bba7bd800 100644 --- a/src/blueapi/worker/event.py +++ b/src/blueapi/worker/event.py @@ -1,13 +1,13 @@ import logging from collections.abc import Mapping from enum import StrEnum -from typing import Any from bluesky.run_engine import RunEngineStateMachine -from pydantic import Field, TypeAdapter, field_validator +from pydantic import Field from super_state_machine.extras import PropertyMachine, ProxyString from blueapi.utils import BlueapiBaseModel +from blueapi.utils.base_model import NoneFallback # The RunEngine can return any of these three types as its state # RawRunEngineState = type[PropertyMachine | ProxyString | str] @@ -113,18 +113,7 @@ class TaskStatus(BlueapiBaseModel): task_id: str task_complete: bool task_failed: bool - result: Any = None - - @field_validator("result") - @classmethod - def _serialize_result(cls, result): - try: - return TypeAdapter(type(result)).dump_python(result) - except Exception: - log.warning( - "Plan result type (%s) not serializable: %s", type(result), result - ) - pass + result: NoneFallback = None class WorkerEvent(BlueapiBaseModel): diff --git a/src/blueapi/worker/task_worker.py b/src/blueapi/worker/task_worker.py index 2999b68fa9..5bed1bffea 100644 --- a/src/blueapi/worker/task_worker.py +++ b/src/blueapi/worker/task_worker.py @@ -32,7 +32,7 @@ ) from blueapi.core.bluesky_event_loop import configure_bluesky_event_loop from blueapi.log import plan_tag_filter_context -from blueapi.utils.base_model import BlueapiBaseModel +from blueapi.utils.base_model import BlueapiBaseModel, NoneFallback from blueapi.utils.thread_exception import handle_all_exceptions from .event import ( @@ -69,7 +69,7 @@ class TrackableTask(BlueapiBaseModel): is_complete: bool = False is_pending: bool = True errors: list[str] = Field(default_factory=list) - result: Any | None = None + result: NoneFallback = None class TaskWorker: From cd1744fdc4a0096dda58a07df07d6d1a964f32ba Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 29 Jan 2026 17:23:37 +0000 Subject: [PATCH 05/14] Update tests --- docs/reference/openapi.yaml | 2 ++ tests/unit_tests/service/test_rest_api.py | 5 +++++ tests/unit_tests/test_cli.py | 9 ++++++--- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/docs/reference/openapi.yaml b/docs/reference/openapi.yaml index 499e8cc52d..1380e642b7 100644 --- a/docs/reference/openapi.yaml +++ b/docs/reference/openapi.yaml @@ -314,6 +314,8 @@ components: request_id: title: Request Id type: string + result: + title: Result task: $ref: '#/components/schemas/Task' task_id: diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index b0bc09dd33..12555cc784 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -336,6 +336,7 @@ def test_get_tasks(mock_runner: Mock, client: TestClient) -> None: "params": {"time": 0.0}, "metadata": {}, }, + "result": None, "task_id": "0", }, { @@ -348,6 +349,7 @@ def test_get_tasks(mock_runner: Mock, client: TestClient) -> None: "params": {}, "metadata": {}, }, + "result": None, "task_id": "1", }, ] @@ -379,6 +381,7 @@ def test_get_tasks_by_status(mock_runner: Mock, client: TestClient) -> None: "params": {}, "metadata": {}, }, + "result": None, "task_id": "3", } ] @@ -472,6 +475,7 @@ def test_get_task(mock_runner: Mock, client: TestClient): "foo": "bar", }, }, + "result": None, "task_id": f"{task_id}", } @@ -500,6 +504,7 @@ def test_get_all_tasks(mock_runner: Mock, client: TestClient): "is_complete": False, "is_pending": True, "request_id": None, + "result": None, "errors": [], } ] diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 2a49a1fe80..1aaf585581 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -929,9 +929,12 @@ def test_event_formatting(): OutputFormat.JSON, worker, ( - """{"state": "RUNNING", "task_status": """ - """{"task_id": "count", "task_complete": false, "task_failed": false}, """ - """"errors": [], "warnings": []}\n""" + '{"state": "RUNNING", "task_status": {' + '"task_id": "count", ' + '"task_complete": false, ' + '"task_failed": false, ' + '"result": null' + '}, "errors": [], "warnings": []}\n' ), ) _assert_matching_formatting(OutputFormat.COMPACT, worker, "Worker Event: RUNNING\n") From 4360291c55ed59a63bed64e2719bfa2e2a474a41 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Mon, 9 Feb 2026 11:36:31 +0000 Subject: [PATCH 06/14] Add test --- tests/unit_tests/worker/test_task_worker.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/unit_tests/worker/test_task_worker.py b/tests/unit_tests/worker/test_task_worker.py index 2793e88077..a63501154c 100644 --- a/tests/unit_tests/worker/test_task_worker.py +++ b/tests/unit_tests/worker/test_task_worker.py @@ -79,6 +79,11 @@ def failing_plan() -> MsgGenerator: raise KeyError("I failed") +def returning_plan() -> MsgGenerator[int]: + yield from [] + return 42 + + @pytest.fixture def fake_device() -> FakeDevice: return FakeDevice() @@ -95,6 +100,7 @@ def context(fake_device: FakeDevice, second_fake_device: FakeDevice) -> BlueskyC ctx_config = EnvironmentConfig() ctx_config.sources.append(DeviceSource(module="devices")) ctx.register_plan(failing_plan) + ctx.register_plan(returning_plan) ctx.register_device(fake_device) ctx.register_device(second_fake_device) ctx.with_config(ctx_config) @@ -269,6 +275,18 @@ def test_begin_task_uses_plan_name_filter( filter_mock.assert_called_once() +def test_return_value_recorded(worker: TaskWorker): + task_id = worker.submit_task(Task(name="returning_plan", params={})) + events_future = take_events( + worker.worker_events, + lambda evt: evt.task_status is not None and evt.task_status.task_complete, + ) + worker.begin_task(task_id) + events = events_future.result(timeout=2.0) + assert events[-1].task_status is not None + assert events[-1].task_status.result == 42 + + def test_plan_failure_recorded_in_active_task(worker: TaskWorker) -> None: task_id = worker.submit_task(_FAILING_TASK) events_future: Future[list[WorkerEvent]] = take_events( From 17548448c606cd9be0a40064cb4a4f71821d2edc Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Mon, 9 Feb 2026 11:42:09 +0000 Subject: [PATCH 07/14] Update schema version --- docs/reference/openapi.yaml | 2 +- src/blueapi/service/main.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/reference/openapi.yaml b/docs/reference/openapi.yaml index 1380e642b7..45e33ed31d 100644 --- a/docs/reference/openapi.yaml +++ b/docs/reference/openapi.yaml @@ -379,7 +379,7 @@ info: name: Apache 2.0 url: https://www.apache.org/licenses/LICENSE-2.0.html title: BlueAPI Control - version: 1.1.2 + version: 1.1.3 openapi: 3.1.0 paths: /config/oidc: diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 5aa44c533e..02c33e514d 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -58,7 +58,7 @@ from .runner import WorkerDispatcher #: API version to publish in OpenAPI schema -REST_API_VERSION = "1.1.2" +REST_API_VERSION = "1.1.3" LICENSE_INFO: dict[str, str] = { "name": "Apache 2.0", From 1e7faf6417c6127149e4267ee5171785695fa64d Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Mon, 9 Feb 2026 11:49:34 +0000 Subject: [PATCH 08/14] minor bump instead of patch --- docs/reference/openapi.yaml | 2 +- src/blueapi/service/main.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/reference/openapi.yaml b/docs/reference/openapi.yaml index 45e33ed31d..88a50972db 100644 --- a/docs/reference/openapi.yaml +++ b/docs/reference/openapi.yaml @@ -379,7 +379,7 @@ info: name: Apache 2.0 url: https://www.apache.org/licenses/LICENSE-2.0.html title: BlueAPI Control - version: 1.1.3 + version: 1.2.0 openapi: 3.1.0 paths: /config/oidc: diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 02c33e514d..6f2acbd156 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -58,7 +58,7 @@ from .runner import WorkerDispatcher #: API version to publish in OpenAPI schema -REST_API_VERSION = "1.1.3" +REST_API_VERSION = "1.2.0" LICENSE_INFO: dict[str, str] = { "name": "Apache 2.0", From 72922f3c0a85b2e1a252d55966c39ae8e1dcfaf9 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Mon, 9 Feb 2026 12:13:16 +0000 Subject: [PATCH 09/14] Add tests for result serialization --- tests/unit_tests/worker/test_task_worker.py | 51 ++++++++++++++++++--- 1 file changed, 44 insertions(+), 7 deletions(-) diff --git a/tests/unit_tests/worker/test_task_worker.py b/tests/unit_tests/worker/test_task_worker.py index a63501154c..45fd42c734 100644 --- a/tests/unit_tests/worker/test_task_worker.py +++ b/tests/unit_tests/worker/test_task_worker.py @@ -84,6 +84,23 @@ def returning_plan() -> MsgGenerator[int]: return 42 +@dataclasses.dataclass +class ComplexReturn: + foo: int + bar: list[str] + + +class ModelReturn(pydantic.BaseModel): + foo: int + bar: list[str] + + +class Unreturnable: + def __init__(self, foo: int, bar: list[str]): + self.foo = foo + self.bar = bar + + @pytest.fixture def fake_device() -> FakeDevice: return FakeDevice() @@ -276,17 +293,37 @@ def test_begin_task_uses_plan_name_filter( def test_return_value_recorded(worker: TaskWorker): - task_id = worker.submit_task(Task(name="returning_plan", params={})) - events_future = take_events( - worker.worker_events, - lambda evt: evt.task_status is not None and evt.task_status.task_complete, - ) - worker.begin_task(task_id) - events = events_future.result(timeout=2.0) + task_id = worker.submit_task(Task(name="returning_plan")) + events = begin_task_and_wait_until_complete(worker, task_id) assert events[-1].task_status is not None assert events[-1].task_status.result == 42 +@pytest.mark.parametrize( + "result,serial", + [ + (42, 42), + ("helloWorld", "helloWorld"), + (ComplexReturn(34, ["foo"]), {"foo": 34, "bar": ["foo"]}), + (ModelReturn(foo=42, bar=["helloWorld"]), {"foo": 42, "bar": ["helloWorld"]}), + (Unreturnable(17, ["fizzbuzz"]), None), + ], +) +def test_plan_result_serialized(worker: TaskWorker, result: Any, serial: Any): + def result_plan() -> MsgGenerator: + yield from [] + return result + + worker._ctx.register_plan(result_plan) + task_id = worker.submit_task(Task(name="result_plan")) + events = begin_task_and_wait_until_complete(worker, task_id) + ts = events[-1].task_status + assert ts is not None + assert ts.result == result + + assert ts.model_dump()["result"] == serial + + def test_plan_failure_recorded_in_active_task(worker: TaskWorker) -> None: task_id = worker.submit_task(_FAILING_TASK) events_future: Future[list[WorkerEvent]] = take_events( From 6e38429d0a2c0d1f8ea941ba41f2f8fd422e1c2a Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Mon, 9 Feb 2026 20:43:37 +0000 Subject: [PATCH 10/14] Fix TrackableTask pickling --- src/blueapi/utils/base_model.py | 14 -------------- src/blueapi/worker/event.py | 4 ++-- src/blueapi/worker/task_worker.py | 17 +++++++++++++---- tests/unit_tests/worker/test_task_worker.py | 2 +- 4 files changed, 16 insertions(+), 21 deletions(-) diff --git a/src/blueapi/utils/base_model.py b/src/blueapi/utils/base_model.py index daad9a0bd4..1842ee02dc 100644 --- a/src/blueapi/utils/base_model.py +++ b/src/blueapi/utils/base_model.py @@ -1,11 +1,8 @@ import logging -from typing import Annotated, Any from pydantic import ( BaseModel, ConfigDict, - PlainSerializer, - TypeAdapter, ) logger = logging.getLogger(__name__) @@ -25,17 +22,6 @@ ) -def _safe_serialize(value: Any) -> Any: - """Try serializing but skip any type that pydantic can't handle""" - try: - return TypeAdapter(type(value)).dump_python(value, mode="json") - except Exception: - logger.warning("Type '%s' not serializable: %s", type(value), value) - - -NoneFallback = Annotated[Any, PlainSerializer(_safe_serialize)] - - class BlueapiBaseModel(BaseModel): """ Base class for blueapi API models. diff --git a/src/blueapi/worker/event.py b/src/blueapi/worker/event.py index 0bba7bd800..ca4496d273 100644 --- a/src/blueapi/worker/event.py +++ b/src/blueapi/worker/event.py @@ -1,13 +1,13 @@ import logging from collections.abc import Mapping from enum import StrEnum +from typing import Any from bluesky.run_engine import RunEngineStateMachine from pydantic import Field from super_state_machine.extras import PropertyMachine, ProxyString from blueapi.utils import BlueapiBaseModel -from blueapi.utils.base_model import NoneFallback # The RunEngine can return any of these three types as its state # RawRunEngineState = type[PropertyMachine | ProxyString | str] @@ -113,7 +113,7 @@ class TaskStatus(BlueapiBaseModel): task_id: str task_complete: bool task_failed: bool - result: NoneFallback = None + result: Any = None class WorkerEvent(BlueapiBaseModel): diff --git a/src/blueapi/worker/task_worker.py b/src/blueapi/worker/task_worker.py index 5bed1bffea..0d092e266a 100644 --- a/src/blueapi/worker/task_worker.py +++ b/src/blueapi/worker/task_worker.py @@ -18,7 +18,7 @@ from opentelemetry.baggage import get_baggage from opentelemetry.context import Context, get_current from opentelemetry.trace import SpanKind -from pydantic import Field +from pydantic import Field, TypeAdapter from pydantic.json_schema import SkipJsonSchema from super_state_machine.errors import TransitionError @@ -32,7 +32,7 @@ ) from blueapi.core.bluesky_event_loop import configure_bluesky_event_loop from blueapi.log import plan_tag_filter_context -from blueapi.utils.base_model import BlueapiBaseModel, NoneFallback +from blueapi.utils.base_model import BlueapiBaseModel from blueapi.utils.thread_exception import handle_all_exceptions from .event import ( @@ -69,7 +69,16 @@ class TrackableTask(BlueapiBaseModel): is_complete: bool = False is_pending: bool = True errors: list[str] = Field(default_factory=list) - result: NoneFallback = None + result: Any = None + + def set_result(self, result): + try: + self.result = TypeAdapter(type(result)).dump_python(result, mode="json") + except Exception: + LOGGER.warning( + "Plan result (%s) is not serializable so will not be available", result + ) + pass class TaskWorker: @@ -429,7 +438,7 @@ def process_task(): self._current = next_task self._current.is_pending = False result = self._current.task.do_task(self._ctx) - self._current.result = result + self._current.set_result(result) with plan_tag_filter_context(next_task.task.name, LOGGER): if self._current_task_otel_context is not None: diff --git a/tests/unit_tests/worker/test_task_worker.py b/tests/unit_tests/worker/test_task_worker.py index 45fd42c734..675b0e66cf 100644 --- a/tests/unit_tests/worker/test_task_worker.py +++ b/tests/unit_tests/worker/test_task_worker.py @@ -319,7 +319,7 @@ def result_plan() -> MsgGenerator: events = begin_task_and_wait_until_complete(worker, task_id) ts = events[-1].task_status assert ts is not None - assert ts.result == result + assert ts.result == serial assert ts.model_dump()["result"] == serial From 92f2b5ff4ba039dfc9988e570ec79ddf48dd1298 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Wed, 11 Feb 2026 11:21:09 +0000 Subject: [PATCH 11/14] Remove duplicate test --- tests/unit_tests/worker/test_task_worker.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/unit_tests/worker/test_task_worker.py b/tests/unit_tests/worker/test_task_worker.py index 675b0e66cf..c2f2387c2b 100644 --- a/tests/unit_tests/worker/test_task_worker.py +++ b/tests/unit_tests/worker/test_task_worker.py @@ -79,11 +79,6 @@ def failing_plan() -> MsgGenerator: raise KeyError("I failed") -def returning_plan() -> MsgGenerator[int]: - yield from [] - return 42 - - @dataclasses.dataclass class ComplexReturn: foo: int @@ -117,7 +112,6 @@ def context(fake_device: FakeDevice, second_fake_device: FakeDevice) -> BlueskyC ctx_config = EnvironmentConfig() ctx_config.sources.append(DeviceSource(module="devices")) ctx.register_plan(failing_plan) - ctx.register_plan(returning_plan) ctx.register_device(fake_device) ctx.register_device(second_fake_device) ctx.with_config(ctx_config) @@ -292,13 +286,6 @@ def test_begin_task_uses_plan_name_filter( filter_mock.assert_called_once() -def test_return_value_recorded(worker: TaskWorker): - task_id = worker.submit_task(Task(name="returning_plan")) - events = begin_task_and_wait_until_complete(worker, task_id) - assert events[-1].task_status is not None - assert events[-1].task_status.result == 42 - - @pytest.mark.parametrize( "result,serial", [ From 6a697c401e8828e597693bbcd6ff62f81f078f83 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Wed, 11 Feb 2026 12:08:45 +0000 Subject: [PATCH 12/14] Ignore coverage on line only present to keep type-checking happy --- src/blueapi/worker/task.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/blueapi/worker/task.py b/src/blueapi/worker/task.py index 669b7fb05f..9ce373c769 100644 --- a/src/blueapi/worker/task.py +++ b/src/blueapi/worker/task.py @@ -39,7 +39,8 @@ def do_task(self, ctx: BlueskyContext) -> None: prepared_params = self.prepare_params(ctx) ctx.run_engine.md.update(self.metadata) result = ctx.run_engine(func(**prepared_params)) - if isinstance(result, tuple): + if isinstance(result, tuple): # pragma: no cover + # this is never true if the run_engine is configured correctly return None return result.plan_result From 7775315595082f2efad74f58d93d43c57003add4 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 12 Feb 2026 18:18:40 +0000 Subject: [PATCH 13/14] Split task result into failure and success states Allow callers to access both return values and exceptions raised by plans. --- docs/reference/openapi.yaml | 53 ++++++++++++++++- src/blueapi/cli/cli.py | 17 ++++-- src/blueapi/client/client.py | 22 +++---- src/blueapi/worker/event.py | 47 ++++++++++++++- src/blueapi/worker/task_worker.py | 29 +++++---- tests/system_tests/test_blueapi_system.py | 36 +++++++---- tests/unit_tests/client/test_client.py | 21 +++++-- tests/unit_tests/service/test_interface.py | 10 +++- tests/unit_tests/service/test_rest_api.py | 10 ++-- tests/unit_tests/test_cli.py | 28 ++++++--- tests/unit_tests/worker/test_task_worker.py | 66 +++++++++++++++------ 11 files changed, 257 insertions(+), 82 deletions(-) diff --git a/docs/reference/openapi.yaml b/docs/reference/openapi.yaml index 88a50972db..35146b01a6 100644 --- a/docs/reference/openapi.yaml +++ b/docs/reference/openapi.yaml @@ -246,6 +246,26 @@ components: - name title: Task type: object + TaskError: + additionalProperties: false + description: Wrapper around an exception raised by a plan + properties: + message: + title: Message + type: string + outcome: + const: error + default: error + title: Outcome + type: string + type: + title: Type + type: string + required: + - type + - message + title: TaskError + type: object TaskRequest: additionalProperties: false description: Request to run a task with related info @@ -280,6 +300,31 @@ components: - task_id title: TaskResponse type: object + TaskResult: + additionalProperties: false + description: 'Serializable wrapper around the result of a plan + + + If the result is not serializable, the result will be None but the type + + will be the name of the type. If the result is actually None, the type will + + be ''NoneType''.' + properties: + outcome: + const: success + default: success + title: Outcome + type: string + result: + title: Result + type: + title: Type + type: string + required: + - type + title: TaskResult + type: object TasksListResponse: additionalProperties: false description: Diagnostic information on the tasks @@ -311,11 +356,15 @@ components: default: true title: Is Pending type: boolean + outcome: + anyOf: + - $ref: '#/components/schemas/TaskResult' + - $ref: '#/components/schemas/TaskError' + - type: 'null' + title: Outcome request_id: title: Request Id type: string - result: - title: Result task: $ref: '#/components/schemas/Task' task_id: diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 4d974d35e2..3036bc054a 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -8,7 +8,7 @@ from functools import wraps from pathlib import Path from pprint import pprint -from typing import ParamSpec, TypeVar +from typing import ParamSpec, TypeVar, cast import click from bluesky.callbacks.best_effort import BestEffortCallback @@ -38,6 +38,7 @@ from blueapi.service.authentication import SessionCacheManager, SessionManager from blueapi.service.model import SourceInfo, TaskRequest from blueapi.worker import ProgressEvent, WorkerEvent +from blueapi.worker.event import TaskError, TaskResult from .scratch import setup_scratch from .updates import CliEventRenderer @@ -290,7 +291,7 @@ def run_plan( instrument_session: str, ) -> None: """Run a plan with parameters""" - client: BlueapiClient = obj["client"] + client: BlueapiClient = cast(BlueapiClient, obj["client"]) parameters = parameters or "{}" try: @@ -320,9 +321,15 @@ def on_event(event: AnyEvent) -> None: callback(event.name, event.doc) resp = client.run_task(task, on_event=on_event) - - if resp.task_status is not None and not resp.task_status.task_failed: - print("Plan Succeeded") + match resp.result: + case TaskResult(result=None, type="NoneType"): + print("Plan succeeded") + case TaskResult(result=None, type=t): + print(f"Plan returned unserializable result of type '{t}'") + case TaskResult(result=r): + print(f"Plan succeeded: {r}") + case TaskError(type=exc, message=m): + print(f"Plan failed: {exc}: {m}") else: server_task = client.create_and_start_task(task) click.echo(server_task.task_id) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 0930e240a9..24448bc2be 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -28,7 +28,7 @@ from blueapi.worker import TrackableTask, WorkerEvent, WorkerState from blueapi.worker.event import ProgressEvent, TaskStatus -from .event_bus import AnyEvent, BlueskyStreamingError, EventBusClient, OnAnyEvent +from .event_bus import AnyEvent, EventBusClient, OnAnyEvent from .rest import BlueapiRestClient, BlueskyRemoteControlError TRACER = get_tracer("client") @@ -201,7 +201,7 @@ def run_task( task: TaskRequest, on_event: OnAnyEvent | None = None, timeout: float | None = None, - ) -> WorkerEvent: + ) -> TaskStatus: """ Synchronously run a task, requires a message bus connection @@ -224,7 +224,7 @@ def run_task( task_response = self.create_task(task) task_id = task_response.task_id - complete: Future[WorkerEvent] = Future() + complete: Future[TaskStatus] = Future() def inner_on_event(event: AnyEvent, ctx: MessageContext) -> None: match event: @@ -239,19 +239,19 @@ def inner_on_event(event: AnyEvent, ctx: MessageContext) -> None: if relates_to_task: if on_event is not None: on_event(event) - if isinstance(event, WorkerEvent) and ( - (event.is_complete()) and (ctx.correlation_id == task_id) + if ( + isinstance(event, WorkerEvent) + and (event.is_complete()) + and (ctx.correlation_id == task_id) ): - if event.task_status is not None and event.task_status.task_failed: + if event.task_status is None: complete.set_exception( - BlueskyStreamingError( - "\n".join(event.errors) - if len(event.errors) > 0 - else "Unknown error" + BlueskyRemoteControlError( + "Server completed without task status" ) ) else: - complete.set_result(event) + complete.set_result(event.task_status) with self._events: self._events.subscribe_to_all_events(inner_on_event) diff --git a/src/blueapi/worker/event.py b/src/blueapi/worker/event.py index ca4496d273..bfba07bd62 100644 --- a/src/blueapi/worker/event.py +++ b/src/blueapi/worker/event.py @@ -1,10 +1,10 @@ import logging from collections.abc import Mapping from enum import StrEnum -from typing import Any +from typing import Any, Literal, Self from bluesky.run_engine import RunEngineStateMachine -from pydantic import Field +from pydantic import Field, PydanticSchemaGenerationError, TypeAdapter from super_state_machine.extras import PropertyMachine, ProxyString from blueapi.utils import BlueapiBaseModel @@ -56,6 +56,47 @@ def from_bluesky_state(cls, bluesky_state: RawRunEngineState) -> "WorkerState": return WorkerState(str(bluesky_state).upper()) +class TaskResult(BlueapiBaseModel): + """ + Serializable wrapper around the result of a plan + + If the result is not serializable, the result will be None but the type + will be the name of the type. If the result is actually None, the type will + be 'NoneType'. + """ + + outcome: Literal["success"] = "success" + """Discriminant for serialization""" + result: Any = Field(None) + """The serialized result (or None if it is not serializable)""" + type: str + """The type of the result""" + + @classmethod + def from_result(cls, result: Any) -> Self: + type_str = type(result).__name__ + try: + value = TypeAdapter(type(result)).dump_python(result) + except PydanticSchemaGenerationError: + value = None + return cls(result=value, type=type_str) + + +class TaskError(BlueapiBaseModel): + """Wrapper around an exception raised by a plan""" + + outcome: Literal["error"] = "error" + """Discriminant for serialization""" + type: str + """The class of exception""" + message: str + """The message of the raised exception""" + + @classmethod + def from_exception(cls, err: Exception) -> Self: + return cls(type=type(err).__name__, message=str(err)) + + class StatusView(BlueapiBaseModel): """ A snapshot of a Status of an operation, optionally representing progress @@ -111,9 +152,9 @@ class TaskStatus(BlueapiBaseModel): """ task_id: str + result: TaskResult | TaskError | None = Field(None, discriminator="outcome") task_complete: bool task_failed: bool - result: Any = None class WorkerEvent(BlueapiBaseModel): diff --git a/src/blueapi/worker/task_worker.py b/src/blueapi/worker/task_worker.py index 0d092e266a..c30fc0220d 100644 --- a/src/blueapi/worker/task_worker.py +++ b/src/blueapi/worker/task_worker.py @@ -18,7 +18,7 @@ from opentelemetry.baggage import get_baggage from opentelemetry.context import Context, get_current from opentelemetry.trace import SpanKind -from pydantic import Field, TypeAdapter +from pydantic import Field from pydantic.json_schema import SkipJsonSchema from super_state_machine.errors import TransitionError @@ -39,6 +39,8 @@ ProgressEvent, RawRunEngineState, StatusView, + TaskError, + TaskResult, TaskStatus, TaskStatusEnum, WorkerEvent, @@ -69,16 +71,13 @@ class TrackableTask(BlueapiBaseModel): is_complete: bool = False is_pending: bool = True errors: list[str] = Field(default_factory=list) - result: Any = None + outcome: TaskResult | TaskError | None = None - def set_result(self, result): - try: - self.result = TypeAdapter(type(result)).dump_python(result, mode="json") - except Exception: - LOGGER.warning( - "Plan result (%s) is not serializable so will not be available", result - ) - pass + def set_result(self, result: Any): + self.outcome = TaskResult.from_result(result) + + def set_exception(self, err: Exception): + self.outcome = TaskError.from_exception(err) class TaskWorker: @@ -437,8 +436,12 @@ def process_task(): LOGGER.info(f"Got new task: {next_task}") self._current = next_task self._current.is_pending = False - result = self._current.task.do_task(self._ctx) - self._current.set_result(result) + try: + result = self._current.task.do_task(self._ctx) + self._current.set_result(result) + except Exception as e: + self._current.set_exception(e) + self._report_error(e) with plan_tag_filter_context(next_task.task.name, LOGGER): if self._current_task_otel_context is not None: @@ -539,7 +542,7 @@ def _report_status( task_id=self._current.task_id, task_complete=self._current.is_complete, task_failed=bool(self._current.errors), - result=self._current.result, + result=self._current.outcome, ) correlation_id = self._current.task_id add_span_attributes( diff --git a/tests/system_tests/test_blueapi_system.py b/tests/system_tests/test_blueapi_system.py index f2d6395dc0..9b26e22ec0 100644 --- a/tests/system_tests/test_blueapi_system.py +++ b/tests/system_tests/test_blueapi_system.py @@ -12,7 +12,7 @@ from requests.exceptions import ConnectionError from blueapi.client import BlueapiClient -from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError +from blueapi.client.event_bus import AnyEvent from blueapi.client.rest import ( BlueskyRemoteControlError, BlueskyRequestError, @@ -31,7 +31,13 @@ TaskResponse, WorkerTask, ) -from blueapi.worker.event import TaskStatus, WorkerEvent, WorkerState +from blueapi.worker.event import ( + TaskError, + TaskResult, + TaskStatus, + WorkerEvent, + WorkerState, +) from blueapi.worker.task_worker import TrackableTask AUTHORIZED_INSTRUMENT_SESSION = "cm12345-1" @@ -418,6 +424,7 @@ def on_event(event: AnyEvent): task_id=task_id, task_complete=False, task_failed=False, + result=None, ), ), WorkerEvent( @@ -426,6 +433,7 @@ def on_event(event: AnyEvent): task_id=task_id, task_complete=False, task_failed=False, + result=None, ), ), WorkerEvent( @@ -434,6 +442,7 @@ def on_event(event: AnyEvent): task_id=task_id, task_complete=True, task_failed=False, + result=TaskResult(result=None, type="NoneType"), ), ), ] @@ -510,8 +519,9 @@ def on_event(event: AnyEvent) -> None: resource.put_nowait(event.doc) final_event = client_with_stomp.run_task(task, on_event) - assert final_event.is_complete() and not final_event.is_error() - assert final_event.state is WorkerState.IDLE + assert isinstance(final_event.result, TaskResult) + assert final_event.task_complete + assert not final_event.task_failed start_doc = start.get_nowait() assert start_doc["scan_id"] == scan_id @@ -558,8 +568,9 @@ def on_event(event: AnyEvent) -> None: ) def test_stub_runs(client_with_stomp: BlueapiClient, task: TaskRequest): final_event = client_with_stomp.run_task(task) - assert final_event.is_complete() and not final_event.is_error() - assert final_event.state is WorkerState.IDLE + assert isinstance(final_event.result, TaskResult) + assert final_event.task_complete + assert not final_event.task_failed @pytest.mark.parametrize( @@ -580,7 +591,7 @@ def test_stub_runs(client_with_stomp: BlueapiClient, task: TaskRequest): ), ], ) -def test_unauthozied_plan_run( +def test_unauthorized_plan_run( client_with_stomp: BlueapiClient, task: TaskRequest, scan_id: int ): resource = Queue(maxsize=1) @@ -593,8 +604,9 @@ def on_event(event: AnyEvent) -> None: if event.name == "stream_resource": resource.put_nowait(event.doc) - with pytest.raises( - BlueskyStreamingError, - match="404: No such entry", - ): - client_with_stomp.run_task(task, on_event) + outcome = client_with_stomp.run_task(task, on_event) + assert outcome.task_failed + assert outcome.task_complete + assert isinstance(outcome.result, TaskError) + assert outcome.result.type == "ClientError" + assert outcome.result.message.startswith("404: No such entry: [") diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index d6d2e1f22b..8bed8927da 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -10,7 +10,7 @@ ) from blueapi.client import BlueapiClient -from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient +from blueapi.client.event_bus import AnyEvent, EventBusClient from blueapi.client.rest import BlueapiRestClient, BlueskyRemoteControlError from blueapi.config import MissingStompConfigurationError from blueapi.core import DataEvent @@ -26,7 +26,7 @@ WorkerTask, ) from blueapi.worker import ProgressEvent, Task, TrackableTask, WorkerEvent, WorkerState -from blueapi.worker.event import TaskStatus +from blueapi.worker.event import TaskError, TaskResult, TaskStatus PLANS = PlanResponse( plans=[ @@ -55,6 +55,7 @@ task_id="foo", task_complete=True, task_failed=False, + result=TaskResult(type="NoneType", result=None), ), ) FAILED_EVENT = WorkerEvent( @@ -63,6 +64,7 @@ task_id="foo", task_complete=True, task_failed=True, + result=TaskError(type="PlanFailure", message="The plan failed"), ), ) @@ -431,10 +433,15 @@ def test_run_task_fails_on_failing_event( mock_events.subscribe_to_all_events = lambda on_event: on_event(FAILED_EVENT, ctx) on_event = Mock() - with pytest.raises(BlueskyStreamingError): - client_with_events.run_task( - TaskRequest(name="foo", instrument_session="cm12345-1"), on_event=on_event - ) + outcome = client_with_events.run_task( + TaskRequest(name="foo", instrument_session="cm12345-1"), + on_event=on_event, + ) + assert outcome.task_failed + assert outcome.task_complete + assert isinstance(outcome.result, TaskError) + assert outcome.result.message == "The plan failed" + assert outcome.result.type == "PlanFailure" on_event.assert_called_with(FAILED_EVENT) @@ -448,6 +455,7 @@ def test_run_task_fails_on_failing_event( task_id="foo", task_complete=False, task_failed=False, + result=TaskError(type="ValueError", message="Task failed"), ), ), ProgressEvent(task_id="foo"), @@ -489,6 +497,7 @@ def callback(on_event: Callable[[AnyEvent, MessageContext], None]): task_id="bar", task_complete=False, task_failed=False, + result=None, ), ), ProgressEvent(task_id="bar"), diff --git a/tests/unit_tests/service/test_interface.py b/tests/unit_tests/service/test_interface.py index da6f1a4d96..77581e7b12 100644 --- a/tests/unit_tests/service/test_interface.py +++ b/tests/unit_tests/service/test_interface.py @@ -43,7 +43,13 @@ from blueapi.utils.invalid_config_error import InvalidConfigError from blueapi.utils.numtracker import NumtrackerClient from blueapi.utils.path_provider import StartDocumentPathProvider -from blueapi.worker.event import TaskStatus, TaskStatusEnum, WorkerEvent, WorkerState +from blueapi.worker.event import ( + TaskResult, + TaskStatus, + TaskStatusEnum, + WorkerEvent, + WorkerState, +) from blueapi.worker.task import Task from blueapi.worker.task_worker import TrackableTask @@ -392,6 +398,7 @@ def test_remove_tiled_subscriber(worker, context, from_uri, writer): task_id="foo_bar", task_complete=False, task_failed=False, + result=None, ), ), "c_id", @@ -406,6 +413,7 @@ def test_remove_tiled_subscriber(worker, context, from_uri, writer): task_id="foo_bar", task_complete=True, task_failed=False, + result=TaskResult(result=None, type="NoneType"), ), ), "c_id", diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index 12555cc784..d9b5716928 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -336,7 +336,7 @@ def test_get_tasks(mock_runner: Mock, client: TestClient) -> None: "params": {"time": 0.0}, "metadata": {}, }, - "result": None, + "outcome": None, "task_id": "0", }, { @@ -349,7 +349,7 @@ def test_get_tasks(mock_runner: Mock, client: TestClient) -> None: "params": {}, "metadata": {}, }, - "result": None, + "outcome": None, "task_id": "1", }, ] @@ -381,7 +381,7 @@ def test_get_tasks_by_status(mock_runner: Mock, client: TestClient) -> None: "params": {}, "metadata": {}, }, - "result": None, + "outcome": None, "task_id": "3", } ] @@ -475,7 +475,7 @@ def test_get_task(mock_runner: Mock, client: TestClient): "foo": "bar", }, }, - "result": None, + "outcome": None, "task_id": f"{task_id}", } @@ -504,7 +504,7 @@ def test_get_all_tasks(mock_runner: Mock, client: TestClient): "is_complete": False, "is_pending": True, "request_id": None, - "result": None, + "outcome": None, "errors": [], } ] diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 1aaf585581..4e7c44a11e 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -52,7 +52,13 @@ TaskRequest, TaskResponse, ) -from blueapi.worker.event import ProgressEvent, TaskStatus, WorkerEvent, WorkerState +from blueapi.worker.event import ( + ProgressEvent, + TaskResult, + TaskStatus, + WorkerEvent, + WorkerState, +) @pytest.fixture(autouse=True) @@ -302,7 +308,10 @@ def mock_events(topic: MessageTopic, callback: Callable[[Any, Any], Any]): WorkerEvent( state=WorkerState.RUNNING, task_status=TaskStatus( - task_id=task_id, task_complete=False, task_failed=False + task_id=task_id, + task_complete=False, + task_failed=False, + result=None, ), ), ctx, @@ -313,7 +322,7 @@ def mock_events(topic: MessageTopic, callback: Callable[[Any, Any], Any]): WorkerEvent( state=WorkerState.IDLE, task_status=TaskStatus( - task_id=task_id, task_complete=False, task_failed=False + task_id=task_id, task_complete=False, task_failed=False, result=None ), ), ctx, @@ -322,7 +331,10 @@ def mock_events(topic: MessageTopic, callback: Callable[[Any, Any], Any]): WorkerEvent( state=WorkerState.IDLE, task_status=TaskStatus( - task_id=task_id, task_complete=True, task_failed=False + task_id=task_id, + task_complete=True, + task_failed=False, + result=TaskResult(result=None, type="NoneType"), ), ), ctx, @@ -897,7 +909,9 @@ def test_event_formatting(): ) worker = WorkerEvent( state=WorkerState.RUNNING, - task_status=TaskStatus(task_id="count", task_complete=False, task_failed=False), + task_status=TaskStatus( + task_id="count", task_complete=False, task_failed=False, result=None + ), errors=[], warnings=[], ) @@ -931,9 +945,9 @@ def test_event_formatting(): ( '{"state": "RUNNING", "task_status": {' '"task_id": "count", ' + '"result": null, ' '"task_complete": false, ' - '"task_failed": false, ' - '"result": null' + '"task_failed": false' '}, "errors": [], "warnings": []}\n' ), ) diff --git a/tests/unit_tests/worker/test_task_worker.py b/tests/unit_tests/worker/test_task_worker.py index c2f2387c2b..5b1cb6f71b 100644 --- a/tests/unit_tests/worker/test_task_worker.py +++ b/tests/unit_tests/worker/test_task_worker.py @@ -35,7 +35,7 @@ WorkerEvent, WorkerState, ) -from blueapi.worker.event import TaskStatusEnum +from blueapi.worker.event import TaskResult, TaskStatusEnum _SIMPLE_TASK = Task(name="sleep", params={"time": 0.0}) _LONG_TASK = Task(name="sleep", params={"time": 1.0}) @@ -287,16 +287,23 @@ def test_begin_task_uses_plan_name_filter( @pytest.mark.parametrize( - "result,serial", + "result,serial,type_name", [ - (42, 42), - ("helloWorld", "helloWorld"), - (ComplexReturn(34, ["foo"]), {"foo": 34, "bar": ["foo"]}), - (ModelReturn(foo=42, bar=["helloWorld"]), {"foo": 42, "bar": ["helloWorld"]}), - (Unreturnable(17, ["fizzbuzz"]), None), + (None, None, "NoneType"), + (42, 42, "int"), + ("helloWorld", "helloWorld", "str"), + (ComplexReturn(34, ["foo"]), {"foo": 34, "bar": ["foo"]}, "ComplexReturn"), + ( + ModelReturn(foo=42, bar=["helloWorld"]), + {"foo": 42, "bar": ["helloWorld"]}, + "ModelReturn", + ), + (Unreturnable(17, ["fizzbuzz"]), None, "Unreturnable"), ], ) -def test_plan_result_serialized(worker: TaskWorker, result: Any, serial: Any): +def test_plan_result_serialized( + worker: TaskWorker, result: Any, serial: Any, type_name: str +): def result_plan() -> MsgGenerator: yield from [] return result @@ -306,9 +313,13 @@ def result_plan() -> MsgGenerator: events = begin_task_and_wait_until_complete(worker, task_id) ts = events[-1].task_status assert ts is not None - assert ts.result == serial + assert ts.result == TaskResult(result=serial, type=type_name) - assert ts.model_dump()["result"] == serial + assert ts.model_dump()["result"] == { + "outcome": "success", + "result": serial, + "type": type_name, + } def test_plan_failure_recorded_in_active_task(worker: TaskWorker) -> None: @@ -355,7 +366,10 @@ def _sleep_events(task_id: str) -> list[WorkerEvent]: WorkerEvent( state=WorkerState.RUNNING, task_status=TaskStatus( - task_id=task_id, task_complete=False, task_failed=False + task_id=task_id, + task_complete=False, + task_failed=False, + result=None, ), errors=[], warnings=[], @@ -363,7 +377,10 @@ def _sleep_events(task_id: str) -> list[WorkerEvent]: WorkerEvent( state=WorkerState.IDLE, task_status=TaskStatus( - task_id=task_id, task_complete=False, task_failed=False + task_id=task_id, + task_complete=False, + task_failed=False, + result=None, ), errors=[], warnings=[], @@ -371,7 +388,10 @@ def _sleep_events(task_id: str) -> list[WorkerEvent]: WorkerEvent( state=WorkerState.IDLE, task_status=TaskStatus( - task_id=task_id, task_complete=True, task_failed=False + task_id=task_id, + task_complete=True, + task_failed=False, + result=TaskResult.from_result(None), ), errors=[], warnings=[], @@ -447,7 +467,10 @@ def test_worker_and_data_events_produce_in_order( WorkerEvent( state=WorkerState.RUNNING, task_status=TaskStatus( - task_id="count", task_complete=False, task_failed=False + task_id="count", + task_complete=False, + task_failed=False, + result=None, ), errors=[], warnings=[], @@ -459,7 +482,10 @@ def test_worker_and_data_events_produce_in_order( WorkerEvent( state=WorkerState.IDLE, task_status=TaskStatus( - task_id="count", task_complete=False, task_failed=False + task_id="count", + task_complete=False, + task_failed=False, + result=None, ), errors=[], warnings=[], @@ -467,7 +493,10 @@ def test_worker_and_data_events_produce_in_order( WorkerEvent( state=WorkerState.IDLE, task_status=TaskStatus( - task_id="count", task_complete=True, task_failed=False + task_id="count", + task_complete=True, + task_failed=False, + result=TaskResult.from_result(None), ), errors=[], warnings=[], @@ -744,13 +773,16 @@ def test_cycle_without_otel_context(mock_logger: Mock, inert_worker: TaskWorker) task = TrackableTask(task_id="0", task=_SIMPLE_TASK) inert_worker._task_channel.put_nowait(task) inert_worker._pending_tasks["0"] = task + + # task changes during cycle so need to cache log message first + expected = f"Got new task: {task}" inert_worker._cycle() assert inert_worker._current_task_otel_context is None # Bad way to tell that this branch ahs been run, but I can't think of a better way # Have to set these values to match output task.is_complete = False task.is_pending = True - mock_logger.info.assert_called_with(f"Got new task: {task}") + mock_logger.info.assert_called_with(expected) class MyComposite(BlueapiBaseModel): From 5f316d8d7b4e6c0b42ad86f46836c88bfd2cd500 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 13 Feb 2026 12:38:05 +0000 Subject: [PATCH 14/14] Add tests for plan feedback --- tests/unit_tests/test_cli.py | 46 ++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 4e7c44a11e..96e14dc3d8 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -8,6 +8,7 @@ from pathlib import Path from textwrap import dedent from typing import Any, TypeVar +from unittest import mock from unittest.mock import Mock, patch import pytest @@ -54,6 +55,7 @@ ) from blueapi.worker.event import ( ProgressEvent, + TaskError, TaskResult, TaskStatus, WorkerEvent, @@ -363,6 +365,50 @@ def mock_events(topic: MessageTopic, callback: Callable[[Any, Any], Any]): assert run_response.call_count == 1 +@pytest.mark.parametrize( + "result,failed,message", + [ + (TaskResult(result=None, type="NoneType"), False, "Plan succeeded\n"), + (TaskResult(result=32, type="int"), False, "Plan succeeded: 32\n"), + ( + TaskResult(result=None, type="CustomType"), + False, + "Plan returned unserializable result of type 'CustomType'\n", + ), + ( + TaskError(type="ValueError", message="Error with value"), + True, + "Plan failed: ValueError: Error with value\n", + ), + ], +) +@patch("blueapi.cli.cli.BlueapiClient") +def test_run_plan_feedback( + mock_client: Mock, + runner: CliRunner, + result: TaskResult | TaskError | None, + failed: bool, + message: str, +): + bc = mock_client.from_config() + bc.run_task.return_value = TaskStatus( + task_id="foo_bar", + task_complete=True, + task_failed=failed, + result=result, + ) + res = runner.invoke( + main, + ["controller", "run", "-i", "cm12345-1", "name"], + ) + bc.run_task.assert_called_once_with( + TaskRequest(name="name", params={}, instrument_session="cm12345-1"), + on_event=mock.ANY, + ) + assert res.exit_code == 0 + assert res.stdout == message + + @responses.activate def test_run_plan_background_without_stomp(runner: CliRunner): submit_response = responses.post(